Spark编写UDAF自定义函数(JAVA)

来源:互联网 发布:libreoffice知乎 编辑:程序博客网 时间:2024/06/06 03:49

maven:

<!-- spark --><dependency>    <groupId>org.apache.spark</groupId>    <artifactId>spark-core_2.10</artifactId>    <version>1.6.0</version></dependency><dependency>    <groupId>org.apache.spark</groupId>    <artifactId>spark-sql_2.10</artifactId>    <version>1.6.0</version></dependency><dependency>    <groupId>org.apache.spark</groupId>    <artifactId>spark-hive_2.10</artifactId>    <version>1.6.0</version></dependency><!-- google工具类 --><dependency>    <groupId>com.google.guava</groupId>    <artifactId>guava</artifactId>    <version>18.0</version></dependency>

public class StringCount extends UserDefinedAggregateFunction {    /**     * inputSchema指的是输入的数据类型     * @return     */    @Override    public StructType inputSchema() {        List<StructField> fields = Lists.newArrayList();        fields.add(DataTypes.createStructField("str", DataTypes.StringType,true));        return DataTypes.createStructType(fields);    }    /**     * bufferSchema指的是  中间进行聚合时  所处理的数据类型     * @return     */    @Override    public StructType bufferSchema() {        List<StructField> fields = Lists.newArrayList();        fields.add(DataTypes.createStructField("count", DataTypes.IntegerType,true));        return DataTypes.createStructType(fields);    }    /**     * dataType指的是函数返回值的类型     * @return     */    @Override    public DataType dataType() {        return DataTypes.IntegerType;    }    /**     * 一致性检验,如果为true,那么输入不变的情况下计算的结果也是不变的。     * @return     */    @Override    public boolean deterministic() {        return true;    }    /**     * 设置聚合中间buffer的初始值,但需要保证这个语义:两个初始buffer调用下面实现的merge方法后也应该为初始buffer     * 即如果你初始值是1,然后你merge是执行一个相加的动作,两个初始buffer合并之后等于2     * 不会等于初始buffer了。这样的初始值就是有问题的,所以初始值也叫"zero value"     * @param buffer     */    @Override    public void initialize(MutableAggregationBuffer buffer) {        buffer.update(0,0);    }    /**     * 用输入数据input更新buffer,类似于combineByKey     * @param buffer     * @param input     */    @Override    public void update(MutableAggregationBuffer buffer, Row input) {        buffer.update(0,Integer.valueOf(buffer.getAs(0).toString())+1);    }    /**     * 合并两个buffer,buffer2合并到buffer1.在合并两个分区聚合结果的时候会被用到,类似于reduceByKey     * 这里要注意该方法没有返回值,在实现的时候是把buffer2合并到buffer1中去,你需要实现这个合并细节     * @param buffer1     * @param buffer2     */    @Override    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {        buffer1.update(0,Integer.valueOf(buffer1.getAs(0).toString())+Integer.valueOf(buffer2.getAs(0).toString()));    }    /**     * 计算并返回最终的聚合结果     * @param buffer     * @return     */    @Override    public Object evaluate(Row buffer) {        return buffer.getInt(0);    }}
public class UDAF {    public static void main(String[] args) {        SparkConf conf = new SparkConf().setAppName("UDAF").setMaster("local");        JavaSparkContext sc = new JavaSparkContext(conf);        SQLContext sqlContext = new SQLContext(sc);        List<String> nameList = Arrays.asList("xiaoming","xiaoming", "feifei","feifei","feifei", "katong");        //转换为javaRDD        JavaRDD<String> nameRDD = sc.parallelize(nameList, 3);        //转换为JavaRDD<Row>        JavaRDD<Row> nameRowRDD = nameRDD.map(new Function<String, Row>() {            public Row call(String name) throws Exception {                return RowFactory.create(name);            }        });        List<StructField> fields = Lists.newArrayList();        fields.add(DataTypes.createStructField("name", DataTypes.StringType,true));        StructType structType = DataTypes.createStructType(fields);        DataFrame namesDF = sqlContext.createDataFrame(nameRowRDD, structType);        //注册names        namesDF.registerTempTable("names");        sqlContext.udf().register("countString",new StringCount());        List<Row> rows = sqlContext.sql("select name,countString(name) from names group by name").javaRDD().collect();        for (Row row : rows) {            System.out.println(row);        }        sc.close();    }}
执行结果:

[feifei,3]
[xiaoming,2]
[katong,1]

1 0
原创粉丝点击