spark mysql 行级别控制

来源:互联网 发布:巴洛克记忆音乐知乎 编辑:程序博客网 时间:2024/05/02 04:57

spark 的save mode

spark 的saveMode在org.apache.spark.sql.SaveMode下,是一个枚举类,支持

  1. Append(在mysql中为append)
  2. Overwrite(在mysql中为先删除表,再整体将新的df存进去)
  3. ErrorIfExists(存在表则报错)
  4. Ignore(存在表则不执行任何动作的退出)

而实际业务开发中,我们可能更希望一些行级别的动作而非这种表级别的动作

新的mysqlSaveMode

总结业务开发过程中常见的需求,设计出以下枚举类:

package org.apache.spark.sql.ximautil

/**
 * Save modes for writing a DataFrame to a JDBC (MySQL) table, extending Spark's
 * table-level SaveMode with row-level behaviours:
 *  - IgnoreTable:   if the table exists, write nothing.
 *  - Append:        plain INSERT.
 *  - Overwrite:     drop and recreate the table, then INSERT.
 *  - Update:        INSERT ... ON DUPLICATE KEY UPDATE.
 *  - ErrorIfExists: fail if the table exists.
 *  - IgnoreRecord:  INSERT IGNORE INTO.
 *
 * @author todd.chen at 8/26/16 9:52 PM.
 *         email : todd.chen@ximalaya.com
 */
object JdbcSaveMode extends Enumeration {
  type SaveMode = Value
  val IgnoreTable, Append, Overwrite, Update, ErrorIfExists, IgnoreRecord = Value
}
  1. IgnoreTable 类似原来的Ignore,表存在则不执行动作
  2. Append 类似原来的Append
  3. Overwrite 类似原来的Overwrite
  4. Update 则通过ON DUPLICATE KEY UPDATE保证
  5. ErrorIfExists 则类似原来的ErrorIfExists
  6. IgnoreRecord 则通过INSERT IGNORE INTO保证

对应的执行SQL语句应该是

  /**
   * Returns a PreparedStatement that inserts one row into `table` via `conn`.
   *
   * The INSERT variant depends on `saveMode`:
   *  - `Update`: `INSERT ... ON DUPLICATE KEY UPDATE c1=?, c2=?, ...`. The statement has
   *    twice as many placeholders as `rddSchema` has fields, so the caller must bind every
   *    row value twice (once for VALUES, once for the UPDATE clause).
   *  - `Append`, `Overwrite`, `IgnoreTable`, `ErrorIfExists`: plain `INSERT INTO`.
   *    `IgnoreTable` and `ErrorIfExists` are table-level checks made before writing; when a
   *    write still goes ahead (the table was absent and has been created), rows are inserted
   *    exactly like `Append`.
   *  - `IgnoreRecord`: `INSERT IGNORE INTO` (MySQL skips rows that would raise duplicate-key errors).
   *
   * @throws IllegalArgumentException for an unknown save mode
   */
  def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect, saveMode: SaveMode)
  : PreparedStatement = {
    val columnNames = rddSchema.fields.map(field ⇒ dialect.quoteIdentifier(field.name))
    val columns = columnNames.mkString(",")
    val placeholders = rddSchema.fields.map(_ ⇒ "?").mkString(",")
    val sql = saveMode match {
      case Update ⇒
        val duplicateSetting = columnNames.map(name ⇒ s"$name=?").mkString(",")
        s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
      // IgnoreTable / ErrorIfExists must be accepted here: writing into a freshly created
      // table in those modes would otherwise always fail with IllegalArgumentException.
      case Append | Overwrite | IgnoreTable | ErrorIfExists ⇒
        s"INSERT INTO $table ($columns) VALUES ($placeholders)"
      case IgnoreRecord ⇒
        s"INSERT IGNORE INTO $table ($columns) VALUES ($placeholders)"
      case _ ⇒ throw new IllegalArgumentException(s"$saveMode is illegal")
    }
    conn.prepareStatement(sql)
  }

JDBCUtil 类的解读和满足需求下的重写

2.0之前的org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils在这里效率不高:每写一行 row,都要对每个字段重新做一次数据类型匹配,开销随"行数 × 列数"增长。2.0 之后改为先按 schema 为每一列生成一个 setter(JDBCValueSetter),相当于为 PreparedStatement 预先做好一套赋值模板。类型匹配只在构造 setter 时每列做一次,写每一行时直接调用对应的 setter,省去了逐行的类型比较。核心代码:

  /**
   * Copies one value from a `Row` into a `PreparedStatement`. The `Int` argument is the
   * 0-based field index in the `Row`; the statement parameter index is that value plus one.
   */
  private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit

  /**
   * Builds the setter for a column of type `dataType`. The type is matched once here,
   * so writing a row only invokes the prepared setters.
   */
  private def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = {
    dataType match {
      case IntegerType => (ps, r, idx) => ps.setInt(idx + 1, r.getInt(idx))
      case LongType => (ps, r, idx) => ps.setLong(idx + 1, r.getLong(idx))
      case DoubleType => (ps, r, idx) => ps.setDouble(idx + 1, r.getDouble(idx))
      case FloatType => (ps, r, idx) => ps.setFloat(idx + 1, r.getFloat(idx))
      // Short and Byte are widened and written through setInt.
      case ShortType => (ps, r, idx) => ps.setInt(idx + 1, r.getShort(idx))
      case ByteType => (ps, r, idx) => ps.setInt(idx + 1, r.getByte(idx))
      case BooleanType => (ps, r, idx) => ps.setBoolean(idx + 1, r.getBoolean(idx))
      case StringType => (ps, r, idx) => ps.setString(idx + 1, r.getString(idx))
      case BinaryType => (ps, r, idx) => ps.setBytes(idx + 1, r.getAs[Array[Byte]](idx))
      case TimestampType =>
        (ps, r, idx) => ps.setTimestamp(idx + 1, r.getAs[java.sql.Timestamp](idx))
      case DateType => (ps, r, idx) => ps.setDate(idx + 1, r.getAs[java.sql.Date](idx))
      case _: DecimalType => (ps, r, idx) => ps.setBigDecimal(idx + 1, r.getDecimal(idx))
      case ArrayType(elementType, _) =>
        // Element type name with any trailing "(n)" length parameters stripped.
        val elementTypeName = getJdbcType(elementType, dialect).databaseTypeDefinition
          .toLowerCase.split("\\(")(0)
        (ps, r, idx) => {
          val jdbcArray = conn.createArrayOf(elementTypeName, r.getSeq[AnyRef](idx).toArray)
          ps.setArray(idx + 1, jdbcArray)
        }
      case _ =>
        (_, _, idx) =>
          throw new IllegalArgumentException(s"Can't translate non-null value for field $idx")
    }
  }

这个虽然已经解决了大多数问题,但如果使用DUPLICATE还是有问题的:

  1. 非DUPLICATE的sql : insert into table_name (name,age,id) values (?,?,?)
  2. DUPLICATE的sql : insert into table_name (name,age,id) values (?,?,?) on duplicate key update name =? ,age=?,id=?

所以 PreparedStatement 中的占位符数量应该是 row 字段数的两倍,参数绑定的逻辑类似这样:

row = [1,2,3]
setter(0,1) // (index of setter, index of row)
setter(1,2)
setter(2,3)
setter(3,1)
setter(4,2)
setter(5,3)

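我们能发现:当 setterIndex 达到 setter.length 的一半及以上时,对应的 row index(按上例从 1 开始计)应该是 setterIndex - setter.length/2 + 1。换成从 0 开始的下标,就是 setterIndex - setter.length/2,也就是下面代码里的 pos - offset(offset 等于字段数)。例如 setterIndex = 3、setter.length = 6 时,row index 为 3 - 3 + 1 = 1。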

所以新的一个实现是这样的:

// A `JDBCValueSetter` copies one non-null value from a `Row` into a `PreparedStatement`.
// The third argument `pos` is the 0-based statement parameter position (JDBC index pos + 1);
// the fourth argument `offset` is subtracted from `pos` to obtain the `Row` field index.
// With ON DUPLICATE KEY UPDATE every field is bound twice, so parameters in the second
// half are set with offset = number of fields; everything else uses offset = 0.
private type JDBCValueSetter = (PreparedStatement, Row, Int, Int) ⇒ Unit

/**
 * Builds the setter for a column of type `dataType`. The type is matched once here,
 * not once per row.
 *
 * The returned setter throws IllegalArgumentException for unsupported types.
 */
private def makeSetter(
    conn: Connection,
    dialect: JdbcDialect,
    dataType: DataType): JDBCValueSetter = dataType match {
  case IntegerType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setInt(pos + 1, row.getInt(pos - offset))
  case LongType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setLong(pos + 1, row.getLong(pos - offset))
  case DoubleType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setDouble(pos + 1, row.getDouble(pos - offset))
  case FloatType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setFloat(pos + 1, row.getFloat(pos - offset))
  case ShortType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setInt(pos + 1, row.getShort(pos - offset))
  case ByteType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setInt(pos + 1, row.getByte(pos - offset))
  case BooleanType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setBoolean(pos + 1, row.getBoolean(pos - offset))
  case StringType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setString(pos + 1, row.getString(pos - offset))
  case BinaryType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos - offset))
  case TimestampType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos - offset))
  case DateType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos - offset))
  case _: DecimalType ⇒
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      stmt.setBigDecimal(pos + 1, row.getDecimal(pos - offset))
  case ArrayType(et, _) ⇒
    // remove type length parameters from end of type name
    val typeName = getJdbcType(et, dialect).databaseTypeDefinition
      .toLowerCase.split("\\(")(0)
    (stmt: PreparedStatement, row: Row, pos: Int, offset: Int) ⇒
      val array = conn.createArrayOf(
        typeName,
        row.getSeq[AnyRef](pos - offset).toArray)
      stmt.setArray(pos + 1, array)
  case _ ⇒
    (_: PreparedStatement, _: Row, pos: Int, offset: Int) ⇒
      // Report the Row field index: in update mode `pos` alone can exceed the field count.
      throw new IllegalArgumentException(
        s"Can't translate non-null value for field ${pos - offset}")
}

/**
 * Returns one setter per statement parameter: one per field, or, in update mode, the
 * per-field setters repeated twice to match the INSERT ... ON DUPLICATE KEY UPDATE layout.
 */
private def getSetter(fields: Array[StructField], connection: Connection, dialect: JdbcDialect, isUpdateMode: Boolean): Array[JDBCValueSetter] = {
  val setters = fields.map(field ⇒ makeSetter(connection, dialect, field.dataType))
  if (isUpdateMode) setters ++ setters else setters
}

在使用过程中的改变主要是:

源码:

  /**
   * Saves one partition of rows to `table` using batched INSERT statements
   * (the original JdbcUtils implementation, before the row-level save-mode changes).
   *
   * If `isolationLevel` is not TRANSACTION_NONE and the database supports transactions,
   * the whole partition is written in one transaction at the requested isolation level,
   * or at the database default when the requested level is unsupported.
   * The connection is always closed; on failure any open transaction is rolled back.
   *
   * @return an empty iterator
   * @throws IllegalArgumentException if `batchSize` is less than 1
   */
  def savePartition(
      getConnection: () => Connection,
      table: String,
      iterator: Iterator[Row],
      rddSchema: StructType,
      nullTypes: Array[Int],
      batchSize: Int,
      dialect: JdbcDialect,
      isolationLevel: Int): Iterator[Byte] = {
    require(batchSize >= 1,
      s"Invalid value `${batchSize.toString}` for parameter " +
      s"`${JdbcUtils.JDBC_BATCH_INSERT_SIZE}`. The minimum value is 1.")

    val conn = getConnection()
    var committed = false

    var finalIsolationLevel = Connection.TRANSACTION_NONE
    if (isolationLevel != Connection.TRANSACTION_NONE) {
      try {
        val metadata = conn.getMetaData
        if (metadata.supportsTransactions()) {
          // Update to at least use the default isolation, if any transaction level
          // has been chosen and transactions are supported
          val defaultIsolation = metadata.getDefaultTransactionIsolation
          finalIsolationLevel = defaultIsolation
          if (metadata.supportsTransactionIsolationLevel(isolationLevel))  {
            // Finally update to actually requested level if possible
            finalIsolationLevel = isolationLevel
          } else {
            logWarning(s"Requested isolation level $isolationLevel is not supported; " +
                s"falling back to default isolation level $defaultIsolation")
          }
        } else {
          logWarning(s"Requested isolation level $isolationLevel, but transactions are unsupported")
        }
      } catch {
        case NonFatal(e) => logWarning("Exception while detecting transaction support", e)
      }
    }
    val supportsTransactions = finalIsolationLevel != Connection.TRANSACTION_NONE

    try {
      if (supportsTransactions) {
        conn.setAutoCommit(false) // Everything in the same db transaction.
        conn.setTransactionIsolation(finalIsolationLevel)
      }
      val stmt = insertStatement(conn, table, rddSchema, dialect)
      val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
        .map(makeSetter(conn, dialect, _)).toArray
      // Same for every row, so computed once.
      val numFields = rddSchema.fields.length
      try {
        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < numFields) {
            if (row.isNullAt(i)) {
              stmt.setNull(i + 1, nullTypes(i))
            } else {
              setters(i).apply(stmt, row, i)
            }
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            rowCount = 0
          }
        }
        if (rowCount > 0) {
          stmt.executeBatch()
        }
      } finally {
        stmt.close()
      }
      if (supportsTransactions) {
        conn.commit()
      }
      committed = true
    } catch {
      case e: SQLException =>
        // Surface the driver's chained "next exception" (often the real cause of a batch
        // failure). It is null when nothing is chained; addSuppressed(null) would throw
        // a NullPointerException and hide `e`.
        val cause = e.getNextException
        if (cause != null && e.getCause != cause) {
          if (e.getCause == null) {
            try {
              e.initCause(cause)
            } catch {
              // initCause refuses if a cause (even null) was already set by the constructor.
              case _: IllegalStateException => e.addSuppressed(cause)
            }
          } else {
            e.addSuppressed(cause)
          }
        }
        throw e
    } finally {
      if (!committed) {
        // The stage must fail.  We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        try {
          if (supportsTransactions) {
            conn.rollback()
          }
        } finally {
          // Close even if rollback() fails, so the connection is not leaked.
          conn.close()
        }
      } else {
        // The stage must succeed.  We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
        }
      }
    }
    Array[Byte]().iterator
  }

改动点:

def savePartition(                     getConnection: () => Connection,                     table: String,                     iterator: Iterator[Row],                     rddSchema: StructType,                     nullTypes: Array[Int],                     batchSize: Int,                     dialect: JdbcDialect,                     isolationLevel: Int,                     saveMode: SaveMode) = {    require(batchSize >= 1,      s"Invalid value `${batchSize.toString}` for parameter " +        s"`$JDBC_BATCH_INSERT_SIZE`. The minimum value is 1.")    val isUpdateMode = saveMode == Update //check is UpdateMode    val conn = getConnection()    var committed = false    val length = rddSchema.fields.length    val numFields = if (isUpdateMode) length * 2 else length // real num Field length       val stmt = insertStatement(conn, table, rddSchema, dialect, saveMode)      val setters: Array[JDBCValueSetter] = getSetter(rddSchema.fields, conn, dialect, isUpdateMode) //call method getSetter        var rowCount = 0        while (iterator.hasNext) {          val row = iterator.next()          var i = 0          val midField = numFields / 2          while (i < numFields) {            //if duplicate ,'?' 
size = 2 * row.field.length            if (isUpdateMode) {              i < midField match { // check midField > i ,if midFiled >i ,rowIndex is setterIndex - (setterIndex/2) + 1                case trueif (row.isNullAt(i)) {                    stmt.setNull(i + 1, nullTypes(i))                  } else {                    setters(i).apply(stmt, row, i, 0)                  }                case falseif (row.isNullAt(i - midField)) {                    stmt.setNull(i + 1, nullTypes(i - midField))                  } else {                    setters(i).apply(stmt, row, i, midField)                  }              }            } else {              if (row.isNullAt(i)) {                stmt.setNull(i + 1, nullTypes(i))              } else {                setters(i).apply(stmt, row, i, 0)              }            }            i = i + 1          }

封装的bean对象:

/**
 * Settings for one JDBC write: the connection URL, the target table name,
 * the row-level save mode, and the properties used to open the connection.
 */
case class JdbcSaveExplain(url: String, tableName: String, saveMode: SaveMode, jdbcParam: Properties)

封装的DataFrameWriter对象

package com.ximalaya.spark.xql.exec.jdbc

import java.util.Properties

import com.ximalaya.spark.common.log.CommonLoggerTrait

import language._

import com.ximalaya.spark.xql.interpreter.jdbc.JdbcSaveExplain
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.ximautil.JdbcSaveMode.SaveMode
import org.apache.spark.sql.ximautil.JdbcSaveMode._
import org.apache.spark.sql.ximautil.XQLJdbcUtil

/**
 * Writes `dataFrame` to a JDBC table according to a [[JdbcSaveExplain]].
 * Usage: `writer.writeJdbc(explain).save()`.
 *
 * @author todd.chen at 8/26/16 11:33 PM.
 *         email : todd.chen@ximalaya.com
 */
class JdbcDataFrameWriter(dataFrame: DataFrame) extends Serializable with CommonLoggerTrait {

  /** Stores the write settings and returns this writer so `save()` can be chained. */
  def writeJdbc(jdbcSaveExplain: JdbcSaveExplain): JdbcDataFrameWriter = {
    this.jdbcSaveExplain = jdbcSaveExplain
    this
  }

  /** Prepares the target table for the save mode, then writes the rows unless the mode says to skip. */
  def save(): Unit = {
    assert(jdbcSaveExplain != null)
    val saveMode = jdbcSaveExplain.saveMode
    val url = jdbcSaveExplain.url
    val table = jdbcSaveExplain.tableName
    val props = jdbcSaveExplain.jdbcParam
    if (checkTable(url, table, props, saveMode))
      XQLJdbcUtil.saveTable(dataFrame, url, table, props, saveMode)
  }

  /**
   * Applies the table-level part of `saveMode` and reports whether rows should be written.
   *
   * - IgnoreTable and table exists: logs and returns false.
   * - ErrorIfExists and table exists: fails via `sys.error`.
   * - Overwrite and table exists: drops the table.
   * If the table is missing (or was just dropped), it is created with an auto-increment
   * `id` primary key plus the DataFrame's columns.
   */
  private def checkTable(url: String, table: String, connectionProperties: Properties, saveMode: SaveMode): Boolean = {
    val props = new Properties()
    extraOptions.foreach { case (key, value) =>
      props.put(key, value)
    }
    // connectionProperties should override settings in extraOptions
    props.putAll(connectionProperties)
    val conn = JdbcUtils.createConnectionFactory(url, props)()
    try {
      val tableExists = JdbcUtils.tableExists(conn, url, table)
      if (saveMode == IgnoreTable && tableExists) {
        // table ignore, exit
        logger.info(" table {} exists ,mode is ignoreTable,save nothing to it", table)
        false
      } else {
        // error if table exists
        if (saveMode == ErrorIfExists && tableExists) {
          sys.error(s"Table $table already exists.")
        }
        // overwrite table: drop it so it is recreated below
        if (saveMode == Overwrite && tableExists) {
          JdbcUtils.dropTable(conn, table)
        }
        // Create the table if it didn't exist or has just been dropped.
        if (!tableExists || saveMode == Overwrite) {
          checkField(dataFrame)
          val schema = JdbcUtils.schemaString(dataFrame, url)
          val sql = s"CREATE TABLE $table (id int not null primary key auto_increment , $schema)"
          val createStmt = conn.prepareStatement(sql)
          // Close the DDL statement instead of leaking it until the connection closes.
          try {
            createStmt.executeUpdate()
          } finally {
            createStmt.close()
          }
        }
        true
      }
    } finally {
      conn.close()
    }
  }

  // because table in mysql need id as primary key auto increment, illegal if dataFrame contains id field
  private def checkField(dataFrame: DataFrame): Unit = {
    if (dataFrame.schema.exists(_.name == "id")) {
      throw new IllegalArgumentException("dataFrame exists id columns,but id is primary key auto increment in mysql ")
    }
  }

  private var jdbcSaveExplain: JdbcSaveExplain = _
  private val extraOptions = new scala.collection.mutable.HashMap[String, String]
}

object JdbcDataFrameWriter {
  /** Lets `df.writeJdbc(...)` be called directly on a DataFrame once imported. */
  implicit def dataFrame2JdbcWriter(dataFrame: DataFrame): JdbcDataFrameWriter = JdbcDataFrameWriter(dataFrame)

  def apply(dataFrame: DataFrame): JdbcDataFrameWriter = new JdbcDataFrameWriter(dataFrame)
}

测试用例:

 // Converts a Scala Map of connection settings into java.util.Properties.
 implicit def map2Prop(map: Map[String, String]): Properties = map.foldLeft(new Properties) {
   case (prop, (key, value)) ⇒ prop.put(key, value); prop
 }
 val sparkContext = new SparkContext(sparkConf)
 val sqlContext = new SQLContext(sparkContext)
 //    val hiveContext = new HiveContext(sparkContext)
 //    import hiveContext.implicits._
 import sqlContext.implicits._
 // NOTE(review): JdbcDataFrameWriter creates a missing table with its own auto-increment
 // `id` primary key and rejects a DataFrame that has an `id` column in that case, so this
 // example assumes the `mytest` table already exists.
 val df = sparkContext.parallelize(Seq(
   (1, 1, "2", "ctccct", "286"),
   (2, 2, "2", "ccc", "11"),
   (4, 10, "2", "ccct", "12")
 )).toDF("id", "iid", "uid", "name", "age")
 // JdbcSaveExplain(url, tableName, saveMode, jdbcParam); the Map is converted by map2Prop.
 val jdbcSaveExplain = JdbcSaveExplain(
   "jdbc:mysql://localhost:3306/test",
   "mytest",
   JdbcSaveMode.Update,
   Map("user" -> "user", "password" -> "password")
 )
 import JdbcDataFrameWriter.dataFrame2JdbcWriter
 df.writeJdbc(jdbcSaveExplain).save()

mygithub

0 0
原创粉丝点击