Modifying the source of Spark's JdbcRDD -- creating a JdbcRDD that can query without bound conditions


When we use JdbcRDD, its primary constructor takes the following parameters by default:

sc: SparkContext,
getConnection: () => Connection,
sql: String,
lowerBound: Long,
upperBound: Long,
numPartitions: Int,
mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _

Its accompanying scaladoc explains the parameters as follows:

select title, author from books where ? <= id and id <= ?

* @param lowerBound the minimum value of the first placeholder
* @param upperBound the maximum value of the second placeholder
*   The lower and upper bounds are inclusive.
* @param numPartitions the number of partitions.
*   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
*   the query would be executed twice, once with (1, 10) and once with (11, 20)

From the above it is clear that these parameters of JdbcRDD's primary constructor are mandatory, and there is no auxiliary constructor to fall back on. Every query therefore has to supply a lower and an upper bound, i.e. the SQL must carry the placeholder conditions, whose values are passed as arguments into JdbcRDD's primary constructor (see the sketch below). When, in real use or simply while testing, we want to run a query with no conditions at all, the stock JdbcRDD leaves us stuck. So what can we do?
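For comparison, this is roughly what creating a stock JdbcRDD looks like (written as a script). The connection details mirror the test program at the end of this article; the id column name, the 1..100 id range and the two partitions are made-up values for illustration only:

import java.sql.{DriverManager, ResultSet}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.JdbcRDD

// Stock JdbcRDD: the two "?" placeholders, the bounds and the partition count
// are all mandatory, even if we only want to read the whole table.
val sc = new SparkContext(new SparkConf().setAppName("StockJdbcRDD").setMaster("local[2]"))
val rdd = new JdbcRDD(
  sc,
  () => {
    Class.forName("com.mysql.jdbc.Driver").newInstance()
    DriverManager.getConnection("jdbc:mysql://192.168.0.4:3306/spark", "root", "root")
  },
  "SELECT * FROM result WHERE ? <= id AND id <= ?",  // placeholders are required
  1,    // lowerBound (assumed smallest id)
  100,  // upperBound (assumed largest id)
  2,    // numPartitions
  (rs: ResultSet) => (rs.getInt(1), rs.getString(2))
)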

There are surely many ways around this; for us the two simplest are to implement a JdbcRDD of our own, or to extend JdbcRDD and define our own constructor. This article only implements the first option, rewriting JdbcRDD ourselves, which also keeps the coupling low (a rough sketch of the second option follows for comparison).
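For reference, the second option (not pursued in this article) can be approximated roughly as follows: a subclass forwards to JdbcRDD's primary constructor with dummy bounds and wraps the caller's SQL so that the two mandatory placeholders always hold. The class name UnboundedJdbcRDD and the query-wrapping trick are illustrative assumptions, not part of the original article or of the Spark API, and they assume your Spark version allows extending JdbcRDD:

import java.sql.{Connection, ResultSet}

import scala.reflect.ClassTag
import org.apache.spark.SparkContext
import org.apache.spark.rdd.JdbcRDD

// Sketch of the "extend JdbcRDD" idea: dummy bounds (1, 1) on a single partition,
// and the user's SQL wrapped so that "? <= 1 AND 1 <= ?" is always true.
class UnboundedJdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    mapRow: (ResultSet) => T)
  extends JdbcRDD[T](
    sc,
    getConnection,
    s"SELECT * FROM ($sql) t WHERE ? <= 1 AND 1 <= ?",  // "t" is the alias MySQL requires
    1, 1, 1,
    mapRow)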

A look at the JdbcRDD source shows that, in fact,

lowerBound defines the lower bound of the query
upperBound defines the upper bound of the query
numPartitions defines the number of partitions the query is split into
In a real production environment these three parameters can be very useful: together they determine the range of data each partition reads, which is presumably why the Spark developers made them part of the constructor in the first place (a sketch of how the range is split follows).
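The quoted scaladoc example, where (1, 20, 2) expands to (1, 10) and (11, 20), corresponds to range-splitting logic along the following lines. This is a paraphrased sketch for illustration, not the verbatim Spark source:

// Turn (lowerBound, upperBound, numPartitions) into inclusive per-partition ranges.
def partitionRanges(lowerBound: Long, upperBound: Long, numPartitions: Int): Seq[(Long, Long)] = {
  val length = BigInt(1) + upperBound - lowerBound  // bounds are inclusive
  (0 until numPartitions).map { i =>
    val start = lowerBound + ((i * length) / numPartitions)
    val end   = lowerBound + (((i + 1) * length) / numPartitions) - 1
    (start.toLong, end.toLong)
  }
}

// partitionRanges(1, 20, 2)  ->  (1, 10), (11, 20)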
Notes:
This example simply drops those three parameters; keep in mind this is not the only way to do it. Because the partition parameter is removed, the code below defaults to a single partition. You can hard-code more partitions in getPartitions, but with an unbounded statement every partition would run the same query and return the same rows, so one partition is the sensible default here.
Besides modifying JdbcRDD itself, you also need NextIterator.scala. It only has to be moved, not changed: the original is private to the spark package and therefore cannot be referenced from other packages, so copy it verbatim into your own package. That file is not reproduced here, but a minimal compatible sketch is shown after these notes. (JDBCRDD also mixes in org.apache.spark.internal.Logging, which is likewise private to the spark package in recent Spark versions, so the same consideration may apply.)
The modified JdbcRDD.scala is renamed JDBCRDD.scala, and NextIterator.scala is placed in the same package as JDBCRDD.scala.
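For readers who would rather not pull the file from the Spark source tree right away: the JDBCRDD code below only relies on NextIterator exposing a protected finished flag plus getNext(), close() and closeIfNeeded(). A minimal compatible sketch, modeled on Spark's internal org.apache.spark.util.NextIterator, could look like the following; copying the original file remains the safer choice:

// Minimal iterator base class compatible with the JDBCRDD code below.
abstract class NextIterator[U] extends Iterator[U] {
  private var gotNext = false
  private var nextValue: U = _
  private var closed = false
  protected var finished = false

  /** Produce the next value, or set finished = true when the input is exhausted. */
  protected def getNext(): U

  /** Release resources (result set, statement, connection). Called at most once. */
  protected def close(): Unit

  def closeIfNeeded(): Unit = {
    if (!closed) {
      closed = true
      close()
    }
  }

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) closeIfNeeded()
      gotNext = true
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}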
The JDBCRDD.scala source follows:
import java.sql.{Connection, ResultSet}

import scala.reflect.ClassTag

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD

/**
  * Created by Administrator on 2017/9/8.
  */
class JDBCPartition(idx: Int) extends Partition {
  override def index: Int = idx
}

class JDBCRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    mapRow: (ResultSet) => T = JDBCRDD.resultSetToObjectArray _)
  extends RDD[T](sc, Nil) with Logging {

  // A single partition, matching the note above: with an unbounded statement,
  // several partitions would all execute the same query and return duplicate rows.
  override def getPartitions: Array[Partition] = {
    Array(new JDBCPartition(0))
  }

  override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] {
    context.addTaskCompletionListener { context => closeIfNeeded() }
    // No placeholders to bind, so the partition carries no bounds.
    val part = thePart.asInstanceOf[JDBCPartition]
    val conn = getConnection()
    val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

    val url = conn.getMetaData.getURL
    if (url.startsWith("jdbc:mysql:")) {
      // Ask the MySQL driver to stream results row by row instead of buffering them all.
      stmt.setFetchSize(Integer.MIN_VALUE)
    } else {
      stmt.setFetchSize(100)
    }
    logInfo(s"statement fetch size set to: ${stmt.getFetchSize}")

    val rs = stmt.executeQuery()

    override def getNext(): T = {
      if (rs.next()) {
        mapRow(rs)
      } else {
        finished = true
        null.asInstanceOf[T]
      }
    }

    override def close() {
      try {
        if (null != rs) {
          rs.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
      try {
        if (null != stmt) {
          stmt.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing statement", e)
      }
      try {
        if (null != conn) {
          conn.close()
        }
        logInfo("closed connection")
      } catch {
        case e: Exception => logWarning("Exception closing connection", e)
      }
    }
  }
}

object JDBCRDD {

  def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
    Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
  }

  trait ConnectionFactory extends Serializable {
    @throws[Exception]
    def getConnection: Connection
  }

  def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]

  def create[T](
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String,
      mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
    val jdbcRDD = new JDBCRDD[T](
      sc.sc,
      () => connectionFactory.getConnection,
      sql,
      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
    new JavaRDD[T](jdbcRDD)(fakeClassTag)
  }

  def create(
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String): JavaRDD[Array[Object]] = {
    val mapRow = new JFunction[ResultSet, Array[Object]] {
      override def call(resultSet: ResultSet): Array[Object] = {
        resultSetToObjectArray(resultSet)
      }
    }
    create(sc, connectionFactory, sql, mapRow)
  }
}

Below is a small program that exercises the JDBCRDD.scala above:

import java.sql.DriverManager

import org.apache.spark.{SparkConf, SparkContext}

/**
  * Created by Administrator on 2017/9/8.
  */
object TestJDBC {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("TestJDBC").setMaster("local[2]")
    val sc = new SparkContext(conf)
    try {
      val connection = () => {
        Class.forName("com.mysql.jdbc.Driver").newInstance()
        DriverManager.getConnection("jdbc:mysql://192.168.0.4:3306/spark", "root", "root")
      }
      // No bounds and no placeholders: the single partition reads the whole table.
      val jdbcRDD = new JDBCRDD(
        sc,
        connection,
        "SELECT * FROM result",
        r => {
          val id = r.getInt(1)
          val code = r.getString(2)
          (id, code)
        }
      )
      println(jdbcRDD.collect().toBuffer)
      sc.stop()
    } catch {
      case e: Exception => e.printStackTrace()
    }
  }
}

That completes this simple modification of the JdbcRDD source. I hope you find it useful.