/*
 * Hive UDAF 函数 (Hive UDAF function example)
 *
 * 来源:互联网 发布:linux抓包命令 编辑:程序博客网 时间:2024/06/11 01:57
 */
package com.ymdd;  
  
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;


/**
 * 场景是对一个字段b列数据,超过阈值时就计数
 * 
 * */


//AbstractGenericUDAFResolverz主要类型检查
public class HiveUdaf extends AbstractGenericUDAFResolver {


public GenericUDAFEvaluator getEvaluator(TypeInfo[] info)
throws SemanticException {
// TODO Auto-generated method stub
//判断输入参数长度
if (info.length != 2)
{
throw new UDFArgumentTypeException(info.length-1, "please input two paramters");
}
return new genericEvaluate();
}


//内部类做逻辑运算  init->getNewAggregationBuffer->iterate->terminate->terminatePartial->merge

public static class genericEvaluate extends GenericUDAFEvaluator
{
private LongWritable result ;

//基本的描述对象的类型
        private PrimitiveObjectInspector input01;
        private PrimitiveObjectInspector input02;
        
// Mode类经历的四个过程partial1,partial2,final,complete,4个部分对应下面的方法
        //init在map和redurce只初始化1次
public ObjectInspector init(Mode m, ObjectInspector[] parameters)
throws HiveException {
// TODO Auto-generated method stub
super.init(m, parameters);

//最后返回一个计数器,LongWritable是一个序列化的类
result = new LongWritable(0);
input01 = (PrimitiveObjectInspector) parameters[0];

//在redurce阶段时只有一个,所有不做if会报边界溢出
if (parameters.length > 1)
{
input02=(PrimitiveObjectInspector) parameters[1];
}

//告诉返回类型一个为Long类型
return PrimitiveObjectInspectorFactory.writableLongObjectInspector;

}
          
//缓存保存数据,可能是一个计数值器,每个map执行1次
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
// TODO Auto-generated method stub
//map端初始化一个agg类
CountAgg agg = new CountAgg();
//重置agg里的count为0
reset(agg);
return agg;
}
        
//读hive里原始数据,map阶段操作,redurce没有使用此方法
public void iterate(AggregationBuffer agg, Object[] par)
throws HiveException {
// TODO Auto-generated method stub
assert(par.length==2);

if (par == null || par[0] == null || par[1] == null)
{
return;
}

double base = PrimitiveObjectInspectorUtils.getDouble(par[0], input01);
double tmp =  PrimitiveObjectInspectorUtils.getDouble(par[1], input02);

//假设 35 > 30
if (base > tmp)
{
((CountAgg)agg).count++;
}

}


//判断
private void assert(boolean b) {
// TODO Auto-generated method stub
if(b)
{
System.out.println("OK");
}else
{
System.out.println("ERR");
}

}


//合并数据,如果partial不为空,则是有多个partial需要合并数据
public void merge(AggregationBuffer agg, Object partial)
throws HiveException {
// TODO Auto-generated method stub
if(partial != null)
{
long p = PrimitiveObjectInspectorUtils.getLong(partial, input01);
((CountAgg) agg).count += p ;
}
}
        
//重置数据空间和缓存
public void reset(AggregationBuffer countAgg) throws HiveException {
// TODO Auto-generated method stub
CountAgg agg = (CountAgg) countAgg;
agg.count=0;
}


public Object terminate(AggregationBuffer agg) throws HiveException {
// TODO Auto-generated method stub
result.set(((CountAgg) agg).count);
return result;
}
      
//返回数据,map端的部分数据
public Object terminatePartial(AggregationBuffer agg)
throws HiveException {
// TODO Auto-generated method stub
result.set(((CountAgg) agg).count);
return result;
}

// 自定义一个CountAgg类型
public static class CountAgg implements AggregationBuffer
{
long count;
}

}
  
}  
// 原创粉丝点击 (presumably leftover text from the web page this code was copied from; not part of the program)