Spark External Data Source Demo


1. Creating the Relation

package com.spark.datasource.demo

import org.apache.spark.sql.sources._
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.rdd.RDD
import java.sql.{ DriverManager, ResultSet }
import org.apache.spark.sql.{ Row, SQLContext }
import scala.collection.mutable.ArrayBuffer
import org.slf4j.LoggerFactory
import java.io._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import scala.collection.JavaConversions._

/**
 * Steps to implement a user-defined data source:
 * 1.1 Create a DefaultSource that extends RelationProvider.
 *     The class name must be DefaultSource.
 * 1.2 Implement the user-defined Relation.
 *     A Relation supports four scanning strategies:
 *     <1> full table scan: extend TableScan
 *     <2> column pruning: extend PrunedScan
 *     <3> column pruning + row filtering: extend PrunedFilteredScan
 *     <4> CatalystScan
 * 1.3 Implement the user-defined RDD.
 * 1.4 Implement the user-defined RDD Partition.
 * 1.5 Implement the user-defined RDD Iterator.
 */
class DefaultSource extends RelationProvider
    with SchemaRelationProvider with CreatableRelationProvider {

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    createRelation(sqlContext, parameters, null)
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = {
    MyRelation(parameters, schema)(sqlContext)
  }

  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    createRelation(sqlContext, parameters, data.schema)
  }
}

case class MyRelation(@transient val parameters: Map[String, String],
    @transient userSchema: StructType)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan with PrunedScan with PrunedFilteredScan with Serializable {

  private val logger = LoggerFactory.getLogger(getClass)
  private val sparkContext = sqlContext.sparkContext

  def printStackTraceStr(e: Exception, data: String) = {
    val sw: StringWriter = new StringWriter()
    val pw: PrintWriter = new PrintWriter(sw)
    e.printStackTrace(pw)
    println("======>>printStackTraceStr Exception: " + e.getClass() + "\n==>" + sw.toString() + "\n==>data=" + data)
  }

  // Fall back to a single-column schema when the user does not supply one.
  override def schema: StructType = {
    if (this.userSchema != null) {
      this.userSchema
    } else {
      StructType(Seq(StructField("data", IntegerType)))
    }
  }

  // Filters for which unhandled() returns true are re-applied by Spark;
  // the rest are assumed to be handled by the data source itself.
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
    logger.info("unhandledFilters with filters " + filters.toList)
    def unhandled(filter: Filter): Boolean = {
      filter match {
        case EqualTo(col, v) =>
          println("EqualTo col is: " + col + " value is: " + v)
          true
        case _ => true
      }
    }
    filters.filter(unhandled)
  }

  // TableScan: full table scan
  override def buildScan(): RDD[Row] = {
    logger.info("TableScan buildScan")
    new MyRDD[Row](sparkContext)
  }

  // PrunedScan: column pruning
  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    logger.info("PrunedScan buildScan for columns " + requiredColumns.toList)
    new MyRDD[Row](sparkContext)
  }

  // PrunedFilteredScan: column pruning + row filtering
  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    logger.info("PrunedFilteredScan buildScan for columns " + requiredColumns.toList + " with filters " + filters.toList)
    new MyRDD[Row](sparkContext)
  }
}
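With the provider and relation in place, the source can also be exercised directly from Scala through the DataFrame reader rather than SQL. A minimal sketch, assuming a running spark-shell with an existing SQLContext named sqlContext:

    // Load the custom source via the DataFrame reader; the format string is the
    // package name, which Spark resolves to com.spark.datasource.demo.DefaultSource.
    val df = sqlContext.read
      .format("com.spark.datasource.demo")
      .load()

    df.printSchema()  // falls back to the single IntegerType column "data"
    df.show()         // one row, produced by MyIterator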

2. Creating the RDD and Partition

package com.spark.datasource.demo

import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.FileSplit
import org.apache.hadoop.mapred.Reporter
import org.apache.spark._
import org.apache.spark.rdd._
import org.apache.spark.util.NextIterator
import scala.reflect.ClassTag
import org.apache.spark.sql.{ Row, SQLContext }
import org.slf4j.LoggerFactory
import org.apache.spark.sql.types._
import scala.collection.JavaConversions._

case class MyPartition(index: Int) extends Partition {}

class MyRDD[T: ClassTag](
    @transient private val _sc: SparkContext) extends RDD[T](_sc, Nil) {

  private val logger = LoggerFactory.getLogger(getClass)

  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    logger.warn("call MyRDD compute function")
    val currSplit = split.asInstanceOf[MyPartition]
    new MyIterator(currSplit, context)
  }

  // A single partition; note that partition indices must start at 0.
  override protected def getPartitions: Array[Partition] = {
    logger.warn("call MyRDD getPartitions function")
    val partitions = new Array[Partition](1)
    partitions(0) = MyPartition(0)
    partitions
  }

  override protected def getPreferredLocations(split: Partition): Seq[String] = {
    logger.warn("call MyRDD getPreferredLocations function")
    Seq("localhost")
  }
}
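MyRDD can also be checked on its own, independent of the SQL layer. A rough sketch, assuming a spark-shell with a SparkContext named sc and the packaged jar on the classpath:

    import com.spark.datasource.demo._
    import org.apache.spark.sql.Row

    val rdd = new MyRDD[Row](sc)
    println(rdd.partitions.length)   // 1
    rdd.collect().foreach(println)   // [100000], emitted by MyIterator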

3. Creating the Iterator

package com.spark.datasource.demo

import org.apache.spark._
import org.apache.spark.rdd._
import org.apache.spark.util.NextIterator
import scala.reflect.ClassTag
import org.apache.spark.sql.{ Row, SQLContext }
import org.slf4j.LoggerFactory
import org.apache.spark.sql.types._
import java.io._

// A minimal iterator that produces exactly one Row per partition.
class MyIterator[T: ClassTag](
    split: MyPartition,
    context: TaskContext) extends Iterator[T] {

  private val logger = LoggerFactory.getLogger(getClass)
  private var index = 0

  override def hasNext: Boolean = {
    if (index == 1) {
      return false
    }
    index = index + 1
    true
  }

  override def next(): T = {
    val r = Row(100000)
    r.asInstanceOf[T]
  }
}
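The iterator contract is easy to verify in isolation: hasNext returns true exactly once, so a single Row(100000) is produced. A sketch (passing null for the TaskContext is only tolerable here because MyIterator never touches it):

    import com.spark.datasource.demo._
    import org.apache.spark.sql.Row

    val it = new MyIterator[Row](MyPartition(0), null)
    while (it.hasNext) println(it.next())   // prints [100000] once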

4. Eclipse Screenshot

(screenshot not available)

5. SBT Directory Structure

(screenshot not available; a typical layout is sketched below)
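In place of the screenshot, a typical single-module sbt layout for this project would look roughly like the following; the individual .scala file names are illustrative:

    SparkDataSourceDemo/
    ├── build.sbt
    └── src/
        └── main/
            └── scala/
                └── com/spark/datasource/demo/
                    ├── DefaultSource.scala   (DefaultSource and MyRelation)
                    ├── MyRDD.scala           (MyRDD and MyPartition)
                    └── MyIterator.scala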

build.sbt:

name := "SparkDataSourceDemo"

version := "0.1"

organization := "com.spark.datasource.demo"

scalaVersion := "2.10.4"

libraryDependencies += "org.apache.spark" %% "spark-sql" % "1.6.0" % "provided"

resolvers += "Spark Staging Repository" at "https://repository.apache.org/content/repositories/orgapachespark-1038/"

publishMavenStyle := true

publishTo := {
  val nexus = "https://oss.sonatype.org/"
  if (version.value.endsWith("SNAPSHOT"))
    Some("snapshots" at nexus + "content/repositories/snapshots")
  else
    Some("releases" at nexus + "service/local/staging/deploy/maven2")
}

6. SBT Packaging Command

  1. Run the following in the same directory as build.sbt:
    /usr/local/sbt/sbt package
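With the name set to SparkDataSourceDemo in build.sbt, sbt package should produce a jar at roughly target/scala-2.10/sparkdatasourcedemo_2.10-0.1.jar, which is the jar referenced in the next step.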

7. Test Run

1. In the same directory as build.sbt, run:
/usr/local/spark/bin/spark-sql --jars target/scala-2.10/sparkdatasourcedemo_2.10-0.1.jar

2. Create a temporary table backed by the data source, then query it:
CREATE TEMPORARY TABLE test USING com.spark.datasource.demo OPTIONS ();

select * from test;
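Since MyIterator emits a single Row(100000) and the schema falls back to one integer column named data, the query should return one row, along the lines of:

    100000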
