SparkSQL的UDF和UDAF

来源:互联网 发布:苏州聚合数据招聘 编辑:程序博客网 时间:2024/06/03 17:14

1.UDF

注:以下示例使用 HiveContext 初始化 SparkSQL,这不是最新的方式(Spark 2.x 起推荐使用 SparkSession),请参考上篇博客进行修改

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

/**
  * Demonstrates registering scalar user-defined functions (UDFs) with Spark SQL.
  *
  * Two UDFs are registered and then used from SQL against the `student` table:
  *   - `toUpper(name)`   upper-cases a string; a NULL input yields the string " ".
  *   - `strLength(name)` returns the length of a string; a NULL input yields 0.
  *
  * NOTE(review): the `student` table is not created here; it is assumed to
  * already exist in the Hive metastore -- confirm before running.
  */
object UDFTest {

  /**
    * Entry point: builds a local Spark context, registers the UDFs and prints
    * the result of one query per UDF.
    *
    * @param args command-line arguments (unused)
    */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("udf").setMaster("local")
    val sc = new SparkContext(conf)
    try {
      val hiveSQLContext = new HiveContext(sc)

      // The parameter type must be spelled out: `register` is overloaded and
      // generic in the argument type, so an untyped lambda (`name => ...`)
      // fails to compile with "missing parameter type".
      hiveSQLContext.udf.register("toUpper", (name: String) =>
        Option(name).map(_.toUpperCase).getOrElse(" "))

      hiveSQLContext.udf.register("strLength", (name: String) =>
        Option(name).map(_.length).getOrElse(0))

      // `sql` only builds a lazy DataFrame; an action such as `show()` is
      // required for the query (and therefore the UDF) to actually run.
      hiveSQLContext.sql("select toUpper(name) from student").show()
      hiveSQLContext.sql("select strLength(name) from student").show()
    } finally {
      // Release the local Spark resources even if a query fails.
      sc.stop()
    }
  }
}
2.UDAF

(1)1.6.0版本

package lesson02

import org.apache.spark.sql.{Row, types}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types._

/**
  * User-defined aggregate function (UDAF) computing the average salary.
  *
  * Requirement: compute the average salary of all employees of a company.
  * Approach:
  *   1) accumulate the sum of all salaries (`countSalary`);
  *   2) accumulate the number of employees (`count`);
  *   3) the average is `countSalary / count`.
  *
  * NULL salaries are skipped, and a group without any non-NULL salary
  * evaluates to NULL, mirroring the behaviour of SQL `AVG`.
  */
object UDAFTest extends UserDefinedAggregateFunction {

  /**
    * Entry point: registers this object as `avg_salary` and prints the
    * average salary of the `worker` table.
    *
    * NOTE(review): the `worker` table is assumed to already exist in the
    * Hive metastore -- it is not created here.
    *
    * @param args command-line arguments (unused)
    */
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("udf").setMaster("local")
    val sc = new SparkContext(conf)
    try {
      val hiveSQLContext = new HiveContext(sc)
      hiveSQLContext.udf.register("avg_salary", UDAFTest)
      // `sql` is lazy; `show()` triggers the aggregation and prints the result.
      hiveSQLContext.sql("select avg_salary(salary) from worker").show()
    } finally {
      sc.stop()
    }
  }

  /** Input schema: a single nullable `salary` column of type double. */
  override def inputSchema: StructType = StructType(
    StructField("salary", DoubleType, true) :: Nil
  )

  /** Type of the final aggregated value. */
  override def dataType: DataType = DoubleType

  /**
    * Schema of the intermediate aggregation buffer:
    *   - `countSalary`: running sum of all salaries seen so far;
    *   - `count`:       number of salaries seen so far.
    */
  override def bufferSchema: StructType = StructType(
    StructField("countSalary", DoubleType, true) ::
      StructField("count", IntegerType, true) :: Nil
  )

  /** Gives the intermediate values their initial state (sum 0.0, count 0). */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // `buffer(i) = v` desugars to `buffer.update(i, v)`; the two-argument
    // call form `buffer(i, v)` does not exist on MutableAggregationBuffer.
    buffer(0) = 0.0
    buffer(1) = 0
  }

  /**
    * Folds one input row into the buffer.
    *
    * @param buffer the state accumulated so far
    * @param input  the current input row (its only column is the salary)
    */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // The input column is declared nullable; `getDouble` on a NULL cell
    // throws, so NULL salaries are skipped.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      buffer(1) = buffer.getInt(1) + 1
    }
  }

  /**
    * Merges two partial aggregation buffers, storing the result in `buffer1`.
    *
    * @param buffer1 the buffer that receives the merged state
    * @param buffer2 the other partial state (read-only)
    */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
  }

  /**
    * Produces the final result from the fully merged buffer.
    *
    * @return the average salary, or `null` when no salary was aggregated
    */
  override def evaluate(buffer: Row): Any = {
    val countSalary = buffer.getDouble(0)
    val count = buffer.getInt(1)
    // Avoid 0.0 / 0 (NaN) for an empty group; return SQL NULL instead.
    if (count == 0) null else countSalary / count
  }

  /** The function always yields the same output for the same input. */
  override def deterministic: Boolean = true
}


(2)2.2.0版本

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession

/**
  * Untyped user-defined aggregate function that averages a column of longs.
  *
  * The aggregation buffer holds two longs: the running sum (index 0) and the
  * number of non-null inputs seen (index 1).
  */
object MyAverage extends UserDefinedAggregateFunction {

  /** Schema of the arguments accepted by this aggregate: one long column. */
  def inputSchema: StructType =
    StructType(Seq(StructField("inputColumn", LongType)))

  /** Schema of the intermediate buffer: running sum and element count. */
  def bufferSchema: StructType = {
    val sumField = StructField("sum", LongType)
    val countField = StructField("count", LongType)
    StructType(Seq(sumField, countField))
  }

  /** Type of the value returned by `evaluate`. */
  def dataType: DataType = DoubleType

  /** Identical input always produces identical output. */
  def deterministic: Boolean = true

  /** Resets the buffer to an empty aggregation: sum 0 and count 0. */
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0L)
    buffer.update(1, 0L)
  }

  /** Adds one input row to the buffer; rows whose value is null are ignored. */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val runningSum = buffer.getLong(0)
      val seenSoFar = buffer.getLong(1)
      buffer.update(0, runningSum + input.getLong(0))
      buffer.update(1, seenSoFar + 1)
    }
  }

  /** Combines two partial buffers, writing the merged state into `buffer1`. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val mergedSum = buffer1.getLong(0) + buffer2.getLong(0)
    val mergedCount = buffer1.getLong(1) + buffer2.getLong(1)
    buffer1.update(0, mergedSum)
    buffer1.update(1, mergedCount)
  }

  /** Final result: the accumulated sum divided by the accumulated count. */
  def evaluate(buffer: Row): Double = {
    val total = buffer.getLong(0)
    val count = buffer.getLong(1)
    total.toDouble / count
  }
}

// NOTE(review): `spark` is not defined in this snippet; it is assumed to be a
// SparkSession already in scope (e.g. provided by spark-shell) -- confirm.

// Make the aggregate callable from SQL under the name "myAverage".
spark.udf.register("myAverage", MyAverage)

val df = spark.read.json("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+

val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+


原创粉丝点击