Classifying Symbolic Data with a Naive Bayes Classifier in Hadoop/MapReduce and Spark
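The three MapReduce jobs below implement a Naive Bayes classifier over symbolic (categorical) attributes: Step 1 counts how often each class and each attribute-value/class pair occurs, Step 2 turns those counts into probabilities, and Step 3 classifies new records by picking the class with the largest posterior. The decision rule that every version in this post computes is the standard Naive Bayes one:

\[
\hat{C} \;=\; \arg\max_{C}\; P(C)\prod_{i=1}^{n} P(x_i \mid C)
\]

where P(C) is estimated as the fraction of training records labeled C, and P(x_i | C) as the fraction of class-C records whose attribute takes the value x_i.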

package cjbayesclassfier;

import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.MultipleOutputs;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

import edu.umd.cloud9.io.pair.PairOfStrings;

/***
 * Step 1: split each line of the input file, producing the class counts
 * and the condition/class co-occurrence counts.
 * @author chenjie
 */
public class CJBayesClassfier_Step1 extends Configured implements Tool {

    /***
     * Mapper:
     * Input: weather.txt
     * A sample line: Sunny,Hot,High,Weak,No
     * Output:
     * key               value
     * (Sunny,No)        1
     * (Hot,No)          1
     * (High,No)         1
     * (Weak,No)         1
     * (CLASS,No)        1
     * @author chenjie
     */
    public static class CJBayesClassfierMapper extends Mapper<LongWritable, Text, PairOfStrings, LongWritable> {
        PairOfStrings outputKey = new PairOfStrings();
        LongWritable outputValue = new LongWritable(1);

        @Override
        protected void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            String tokens[] = value.toString().split(",");
            if (tokens == null || tokens.length < 2)
                return;
            String classfier = tokens[tokens.length - 1];
            for (int i = 0; i < tokens.length; i++) {
                if (i < tokens.length - 1)
                    outputKey.set(tokens[i], classfier);
                else
                    outputKey.set("CLASS", classfier);
                context.write(outputKey, outputValue);
            }
        }
    }

    @Deprecated
    public static class CJBayesClassfierReducer extends Reducer<PairOfStrings, LongWritable, PairOfStrings, LongWritable> {
        @Override
        protected void reduce(PairOfStrings key, Iterable<LongWritable> values, Context context)
                throws IOException, InterruptedException {
            Long sum = 0L;
            for (LongWritable time : values) {
                sum += time.get();
            }
            context.write(key, new LongWritable(sum));
        }
    }

    public static class CJBayesClassfierReducer2 extends Reducer<PairOfStrings, LongWritable, PairOfStrings, Text> {
        /** Writer that routes results to multiple named output files. */
        private MultipleOutputs<PairOfStrings, Text> mos;

        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            mos = new MultipleOutputs<PairOfStrings, Text>(context); // initialize mos
        }

        /***
         * Sum the values that share the same key.
         */
        @Override
        protected void reduce(PairOfStrings key, Iterable<LongWritable> values, Context context)
                throws IOException, InterruptedException {
            System.out.println("key =" + key);
            Long sum = 0L;
            for (LongWritable time : values) {
                sum += time.get();
            }
            String result = key.getLeftElement() + "," + key.getRightElement() + "," + sum;
            // Class counts go to the CLASS file, co-occurrence counts to OTHERS
            if (key.getLeftElement().equals("CLASS"))
                mos.write("CLASS", NullWritable.get(), new Text(result));
            else
                mos.write("OTHERS", NullWritable.get(), new Text(result));
        }

        /***
         * Be sure to release resources here, otherwise nothing is written.
         */
        @Override
        protected void cleanup(Context context) throws IOException, InterruptedException {
            mos.close(); // release resources
        }
    }

    public static void main(String[] args) throws Exception {
        args = new String[2];
        args[0] = "/media/chenjie/0009418200012FF3/ubuntu/weather.txt";
        args[1] = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes";
        int jobStatus = submitJob(args);
        System.exit(jobStatus);
    }

    public static int submitJob(String[] args) throws Exception {
        int jobStatus = ToolRunner.run(new CJBayesClassfier_Step1(), args);
        return jobStatus;
    }

    @SuppressWarnings("deprecation")
    @Override
    public int run(String[] args) throws Exception {
        Configuration conf = getConf();
        Job job = new Job(conf);
        job.setJobName("Bayes");
        MultipleOutputs.addNamedOutput(job, "CLASS", TextOutputFormat.class, Text.class, Text.class);
        MultipleOutputs.addNamedOutput(job, "OTHERS", TextOutputFormat.class, Text.class, Text.class);

        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setOutputKeyClass(PairOfStrings.class);
        job.setOutputValueClass(LongWritable.class);

        job.setMapperClass(CJBayesClassfierMapper.class);
        job.setReducerClass(CJBayesClassfierReducer2.class);
        FileInputFormat.setInputPaths(job, new Path(args[0]));
        FileOutputFormat.setOutputPath(job, new Path(args[1]));

        FileSystem fs = FileSystem.get(conf);
        Path outPath = new Path(args[1]);
        if (fs.exists(outPath)) {
            fs.delete(outPath, true);
        }

        boolean status = job.waitForCompletion(true);
        return status ? 0 : 1;
    }
}


package cjbayesclassfier;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

import edu.umd.cloud9.io.pair.PairOfStrings;

/***
 * Step 2: compute the probabilities.
 * @author chenjie
 */
public class CJBayesClassfier_Step2 extends Configured implements Tool {

    public static class CJBayesClassfierMapper2 extends Mapper<LongWritable, Text, PairOfStrings, DoubleWritable> {
        PairOfStrings outputKey = new PairOfStrings();
        DoubleWritable outputValue = new DoubleWritable(1);
        private Map<String, Integer> classMap = new HashMap<String, Integer>();

        @Override
        protected void setup(Context context) throws IOException, InterruptedException {
            // Read the per-class counts from the cached CLASS file (lines like "CLASS,No,5")
            FileReader fr = new FileReader("CLASS");
            BufferedReader br = new BufferedReader(fr);
            String line = null;
            while ((line = br.readLine()) != null) {
                String tokens[] = line.split(",");
                String classfier = tokens[1];
                String count = tokens[2];
                classMap.put(classfier, Integer.parseInt(count));
            }
            br.close();
            fr.close();
            // Emit the class priors: count(C) / total record count
            int sum = 0;
            for (Map.Entry<String, Integer> entry : classMap.entrySet()) {
                sum += entry.getValue();
            }
            for (Map.Entry<String, Integer> entry : classMap.entrySet()) {
                double poss = entry.getValue() * 1.0 / sum;
                context.write(new PairOfStrings("CLASS", entry.getKey()), new DoubleWritable(poss));
            }
        }

        @Override
        protected void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            String tokens[] = value.toString().split(",");
            if (tokens == null || tokens.length < 3)
                return;
            String X = tokens[0];
            String classfier = tokens[1];
            Integer count = Integer.valueOf(tokens[2]);
            outputKey.set(X, classfier);
            // Conditional probability: count(X, C) / count(C)
            Integer classCount = classMap.get(classfier);
            outputValue.set(count * 1.0 / classCount);
            context.write(outputKey, outputValue);
        }
    }

    public static class CJBayesClassfierReducer2 extends Reducer<PairOfStrings, DoubleWritable, NullWritable, Text> {
        @Override
        protected void reduce(PairOfStrings key, Iterable<DoubleWritable> values, Context context)
                throws IOException, InterruptedException {
            for (DoubleWritable dw : values)
                context.write(NullWritable.get(),
                        new Text(key.getLeftElement() + "," + key.getRightElement() + "," + dw));
        }
    }

    public static void main(String[] args) throws Exception {
        args = new String[2];
        args[0] = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes/OTHERS-r-00000";
        args[1] = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes2";
        int jobStatus = submitJob(args);
        System.exit(jobStatus);
    }

    public static int submitJob(String[] args) throws Exception {
        int jobStatus = ToolRunner.run(new CJBayesClassfier_Step2(), args);
        return jobStatus;
    }

    @SuppressWarnings("deprecation")
    @Override
    public int run(String[] args) throws Exception {
        Configuration conf = getConf();
        Job job = new Job(conf);
        job.setJobName("Bayes");
        // Distribute the Step 1 class counts to mappers under the symlink name "CLASS"
        job.addCacheArchive(new URI("/media/chenjie/0009418200012FF3/ubuntu/CJBayes/CLASS-r-00000" + "#CLASS"));

        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setOutputKeyClass(PairOfStrings.class);
        job.setOutputValueClass(DoubleWritable.class);

        job.setMapperClass(CJBayesClassfierMapper2.class);
        job.setReducerClass(CJBayesClassfierReducer2.class);
        FileInputFormat.setInputPaths(job, new Path(args[0]));
        FileOutputFormat.setOutputPath(job, new Path(args[1]));

        FileSystem fs = FileSystem.get(conf);
        Path outPath = new Path(args[1]);
        if (fs.exists(outPath)) {
            fs.delete(outPath, true);
        }

        boolean status = job.waitForCompletion(true);
        return status ? 0 : 1;
    }
}
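Step 2 divides each count from OTHERS-r-00000 by its class total from the cached CLASS file, i.e. the maximum-likelihood estimate (taking High/No from the sample run below as an example):

\[
P(x \mid C) = \frac{\mathrm{count}(x, C)}{\mathrm{count}(C)}, \qquad
P(\text{High} \mid \text{No}) = \frac{4}{5} = 0.8
\]

The class priors emitted from setup() are computed the same way against the total record count, e.g. P(No) = 5/14 ≈ 0.357.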


package cjbayesclassfier;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

import edu.umd.cloud9.io.pair.PairOfStrings;

/***
 * Step 3: run the Bayes inference using the probabilities computed in Step 2.
 * @author chenjie
 */
public class CJBayesClassfier_Step3 extends Configured implements Tool {

    public static class CJBayesClassfierMapper3 extends Mapper<LongWritable, Text, Text, LongWritable> {
        LongWritable outputValue = new LongWritable(1);

        @Override
        protected void map(LongWritable key, Text value, Context context)
                throws IOException, InterruptedException {
            context.write(value, outputValue);
        }
    }

    public static class CJBayesClassfierReducer3 extends Reducer<Text, LongWritable, Text, Text> {
        private List<String> classfications;

        @Override
        protected void setup(Reducer<Text, LongWritable, Text, Text>.Context context)
                throws IOException, InterruptedException {
            classfications = buildClassfications();
            for (String classfication : classfications) {
                System.out.println("Class: " + classfication);
            }
            buildCJGLTable();
            CJGLTable.show();
        }

        /*** Read the class labels from the cached CLASS file. */
        private List<String> buildClassfications() throws IOException {
            List<String> list = new ArrayList<String>();
            FileReader fr = new FileReader("CLASS");
            BufferedReader br = new BufferedReader(fr);
            String line = null;
            while ((line = br.readLine()) != null) {
                String tokens[] = line.split(",");
                String classfier = tokens[1];
                list.add(classfier);
            }
            br.close();
            fr.close();
            return list;
        }

        /*** Load the probability table from the cached GL file. */
        private void buildCJGLTable() throws IOException {
            FileReader fr = new FileReader("GL");
            BufferedReader br = new BufferedReader(fr);
            String line = null;
            while ((line = br.readLine()) != null) {
                String tokens[] = line.split(",");
                PairOfStrings key = new PairOfStrings(tokens[0], tokens[1]);
                CJGLTable.add(key, Double.valueOf(tokens[2]));
            }
            br.close();
            fr.close();
        }

        @Override
        protected void reduce(Text key, Iterable<LongWritable> values, Context context)
                throws IOException, InterruptedException {
            System.out.println("key=" + key);
            System.out.println("values:");
            for (LongWritable lw : values) {
                System.out.println(lw);
            }
            String[] attributes = key.toString().split(",");
            String selectedClass = null;
            double maxPosterior = 0.0;
            for (String aClass : classfications) {
                System.out.println("For class: " + aClass);
                // Start from the prior and multiply in each conditional probability
                double posterior = CJGLTable.getClassGL(aClass);
                System.out.println("Its probability: " + posterior);
                for (String attr : attributes) {
                    System.out.println("\tFor condition: " + attr);
                    double conGL = CJGLTable.getConditionalGL(attr, aClass);
                    System.out.println("\tIts probability: " + conGL);
                    posterior *= conGL;
                }
                // Keep the class with the largest posterior
                if (selectedClass == null || posterior > maxPosterior) {
                    selectedClass = aClass;
                    maxPosterior = posterior;
                }
                context.write(key, new Text("Bayes classification: " + selectedClass + ", with probability " + maxPosterior));
            }
            context.write(key, new Text("Final result: Bayes classification is " + selectedClass + ", with probability " + maxPosterior));
        }
    }

    public static void main(String[] args) throws Exception {
        args = new String[2];
        args[0] = "/media/chenjie/0009418200012FF3/ubuntu/weather_predict.txt";
        args[1] = "/media/chenjie/0009418200012FF3/ubuntu/CJBayes3";
        int jobStatus = submitJob(args);
        System.exit(jobStatus);
    }

    public static int submitJob(String[] args) throws Exception {
        int jobStatus = ToolRunner.run(new CJBayesClassfier_Step3(), args);
        return jobStatus;
    }

    @SuppressWarnings("deprecation")
    @Override
    public int run(String[] args) throws Exception {
        Configuration conf = getConf();
        Job job = new Job(conf);
        job.setJobName("Bayes");
        // Distribute the class counts and the probability table under the
        // symlink names "CLASS" and "GL"
        job.addCacheArchive(new URI("/media/chenjie/0009418200012FF3/ubuntu/CJBayes/CLASS-r-00000" + "#CLASS"));
        job.addCacheArchive(new URI("/media/chenjie/0009418200012FF3/ubuntu/CJBayes2/part-r-00000" + "#GL"));

        job.setInputFormatClass(TextInputFormat.class);
        job.setOutputFormatClass(TextOutputFormat.class);
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(LongWritable.class);

        job.setMapperClass(CJBayesClassfierMapper3.class);
        job.setReducerClass(CJBayesClassfierReducer3.class);
        FileInputFormat.setInputPaths(job, new Path(args[0]));
        FileOutputFormat.setOutputPath(job, new Path(args[1]));

        FileSystem fs = FileSystem.get(conf);
        Path outPath = new Path(args[1]);
        if (fs.exists(outPath)) {
            fs.delete(outPath, true);
        }

        boolean status = job.waitForCompletion(true);
        return status ? 0 : 1;
    }
}


package cjbayesclassfier;

import java.util.HashMap;
import java.util.Map;

import edu.umd.cloud9.io.pair.PairOfStrings;

/***
 * Holds the probability table: (condition, class) -> probability.
 * @author chenjie
 */
public class CJGLTable {
    private static Map<PairOfStrings, Double> map = new HashMap<PairOfStrings, Double>();

    public static void add(PairOfStrings key, Double gl) {
        map.put(key, gl);
    }

    /** Prior probability P(aClass), stored under the ("CLASS", aClass) key. */
    public static double getClassGL(String aClass) {
        PairOfStrings pos = new PairOfStrings("CLASS", aClass);
        return map.get(pos) == null ? 0 : map.get(pos);
    }

    /** Conditional probability P(conditional | aClass); 0 if the pair was never seen. */
    public static double getConditionalGL(String conditional, String aClass) {
        PairOfStrings pos = new PairOfStrings(conditional, aClass);
        return map.get(pos) == null ? 0 : map.get(pos);
    }

    public static void show() {
        for (Map.Entry<PairOfStrings, Double> entry : map.entrySet()) {
            System.out.println(entry);
        }
    }
}




Step 1:

Input: weather.txt
--------------------------
Sunny,Hot,High,Weak,No
Sunny,Hot,High,Strong,No
Overcast,Hot,High,Weak,Yes
Rain,Mild,High,Weak,Yes
Rain,Cool,Normal,Weak,Yes
Rain,Cool,Normal,Strong,No
Overcast,Cool,Normal,Strong,Yes
Sunny,Mild,High,Weak,No
Sunny,Cool,Normal,Weak,Yes
Rain,Mild,Normal,Weak,Yes
Sunny,Mild,Normal,Strong,Yes
Overcast,Mild,High,Strong,Yes
Overcast,Hot,Normal,Weak,Yes
Rain,Mild,High,Strong,No

Output: CLASS-r-00000
----------------------
CLASS,No,5
CLASS,Yes,9

Output: OTHERS-r-00000
--------------------------
Cool,No,1
Cool,Yes,3
High,No,4
High,Yes,3
Hot,No,2
Hot,Yes,2
Mild,No,2
Mild,Yes,4
Normal,No,1
Normal,Yes,6
Overcast,Yes,4
Rain,No,2
Rain,Yes,3
Strong,No,3
Strong,Yes,3
Sunny,No,3
Sunny,Yes,2
Weak,No,2
Weak,Yes,6

Step 2:

Cache: CLASS-r-00000
-----------------------
CLASS,No,5
CLASS,Yes,9

Input: OTHERS-r-00000
------------------------
Cool,No,1
Cool,Yes,3
High,No,4
High,Yes,3
Hot,No,2
Hot,Yes,2
Mild,No,2
Mild,Yes,4
Normal,No,1
Normal,Yes,6
Overcast,Yes,4
Rain,No,2
Rain,Yes,3
Strong,No,3
Strong,Yes,3
Sunny,No,3
Sunny,Yes,2
Weak,No,2
Weak,Yes,6

Output: part-r-00000
----------------------------------
CLASS,No,0.35714285714285715
CLASS,Yes,0.6428571428571429
Cool,No,0.2
Cool,Yes,0.3333333333333333
High,No,0.8
High,Yes,0.3333333333333333
Hot,No,0.4
Hot,Yes,0.2222222222222222
Mild,No,0.4
Mild,Yes,0.4444444444444444
Normal,No,0.2
Normal,Yes,0.6666666666666666
Overcast,Yes,0.4444444444444444
Rain,No,0.4
Rain,Yes,0.3333333333333333
Strong,No,0.6
Strong,Yes,0.3333333333333333
Sunny,No,0.6
Sunny,Yes,0.2222222222222222
Weak,No,0.4
Weak,Yes,0.6666666666666666

Step 3:

Cache: CLASS-r-00000
-------------------------------
CLASS,No,5
CLASS,Yes,9

Cache: part-r-00000
------------------------------------
CLASS,No,0.35714285714285715
CLASS,Yes,0.6428571428571429
Cool,No,0.2
Cool,Yes,0.3333333333333333
High,No,0.8
High,Yes,0.3333333333333333
Hot,No,0.4
Hot,Yes,0.2222222222222222
Mild,No,0.4
Mild,Yes,0.4444444444444444
Normal,No,0.2
Normal,Yes,0.6666666666666666
Overcast,Yes,0.4444444444444444
Rain,No,0.4
Rain,Yes,0.3333333333333333
Strong,No,0.6
Strong,Yes,0.3333333333333333
Sunny,No,0.6
Sunny,Yes,0.2222222222222222
Weak,No,0.4
Weak,Yes,0.6666666666666666

Input: weather_predict.txt
---------------------------------
Overcast,Hot,High,Strong

Trace:
---------------------------------------------
Class: No
Class: Yes
(High, No)=0.8
(Strong, No)=0.6
(Normal, No)=0.2
(Normal, Yes)=0.6666666666666666
(Strong, Yes)=0.3333333333333333
(CLASS, No)=0.35714285714285715
(CLASS, Yes)=0.6428571428571429
(Cool, No)=0.2
(High, Yes)=0.3333333333333333
(Hot, No)=0.4
(Sunny, No)=0.6
(Weak, No)=0.4
(Cool, Yes)=0.3333333333333333
(Mild, No)=0.4
(Overcast, Yes)=0.4444444444444444
(Rain, No)=0.4
(Rain, Yes)=0.3333333333333333
(Weak, Yes)=0.6666666666666666
(Hot, Yes)=0.2222222222222222
(Sunny, Yes)=0.2222222222222222
(Mild, Yes)=0.4444444444444444
key=Overcast,Hot,High,Strong
values:
1
For class: No
Its probability: 0.35714285714285715
   For condition: Overcast
   Its probability: 0.0
   For condition: Hot
   Its probability: 0.4
   For condition: High
   Its probability: 0.8
   For condition: Strong
   Its probability: 0.6
For class: Yes
Its probability: 0.6428571428571429
   For condition: Overcast
   Its probability: 0.4444444444444444
   For condition: Hot
   Its probability: 0.2222222222222222
   For condition: High
   Its probability: 0.3333333333333333
   For condition: Strong
   Its probability: 0.3333333333333333

Output:
Overcast,Hot,High,Strong   Bayes classification: No, with probability 0.0
Overcast,Hot,High,Strong   Bayes classification: Yes, with probability 0.007054673721340388
Overcast,Hot,High,Strong   Final result: Bayes classification is Yes, with probability 0.007054673721340388
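The two posteriors in the trace are just the products of the cached probabilities:

\[
P(\text{No}) \prod_i P(x_i \mid \text{No}) = \frac{5}{14}\cdot 0 \cdot \frac{2}{5}\cdot \frac{4}{5}\cdot \frac{3}{5} = 0
\]
\[
P(\text{Yes}) \prod_i P(x_i \mid \text{Yes}) = \frac{9}{14}\cdot \frac{4}{9}\cdot \frac{2}{9}\cdot \frac{3}{9}\cdot \frac{3}{9} \approx 0.007054673721
\]

so Yes is selected. Note that because Overcast never co-occurs with No in the training set, the posterior for No collapses to exactly 0; this is the zero-frequency problem that Laplace smoothing (the lambda parameter in the MLlib version below) exists to avoid.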
Using Spark (native API)

import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object CJBayes {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("cjbayes").setMaster("local")
    val sc = new SparkContext(sparkConf)
    val input = "file:///media/chenjie/0009418200012FF3/ubuntu/weather.txt"
    val predictFile = "file:///media/chenjie/0009418200012FF3/ubuntu/weather_predict.txt"
    val output = "file:///media/chenjie/0009418200012FF3/ubuntu/weather"
    val inputRDD = sc.textFile(input)
    val trainDataSize = inputRDD.count()
    // Step 1: emit ((condition, class), 1) pairs plus one ((CLASS, class), 1) marker per record
    val mapRDD = inputRDD.flatMap { line =>
      val result = ArrayBuffer[((String, String), Int)]()
      val tokens = line.split(",")
      val classfier = tokens(tokens.length - 1)
      for (i <- 0 until tokens.length - 1) {
        result += (((tokens(i), classfier), 1))
      }
      result += ((("CLASS", classfier), 1))
      result
    }
    val reduceRDD = mapRDD.reduceByKey(_ + _)
    val countsMap = reduceRDD.collectAsMap()
    // Step 2: turn the counts into prior and conditional probabilities
    val PT = new mutable.HashMap[(String, String), Double]()
    val CLASSFICATIONS = new mutable.ArrayBuffer[String]()
    countsMap.foreach { item =>
      val k = item._1
      val v = item._2
      val condition = k._1
      val classfication = k._2
      if (condition.equals("CLASS")) {
        PT.put(k, v.toDouble / trainDataSize.toDouble)
        CLASSFICATIONS += k._2
      } else {
        countsMap.get(("CLASS", classfication)) match {
          case None        => PT.put(k, 0.0)
          case Some(count) => PT.put(k, v.toDouble / count)
        }
      }
    }
    PT.foreach(println)
    // Step 3: score each record to classify against every class and keep the max posterior
    val predict = sc.textFile(predictFile)
    predict.map { line =>
      val attributes = line.split(",")
      var selectedClass: String = null
      var maxPosterior = 0.0
      for (aClass <- CLASSFICATIONS) {
        println("For class: " + aClass)
        var posterior: Double = PT.getOrElse(("CLASS", aClass), 0.0)
        println("Its probability: " + posterior)
        for (attr <- attributes) {
          println("\tFor condition: " + attr)
          val probability: Double = PT.getOrElse((attr, aClass), 0.0)
          println("\tIts probability: " + probability)
          posterior *= probability
        }
        // Select only after the full product has been accumulated
        if (selectedClass == null || posterior > maxPosterior) {
          selectedClass = aClass
          maxPosterior = posterior
        }
      }
      line + "," + selectedClass + ":" + maxPosterior
    }.foreach(println)
    sc.stop()
  }
}
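Run locally, the final map prints one line per record in weather_predict.txt; for the sample record above it should print Overcast,Hot,High,Strong,Yes:0.007054673721340388, matching the result of the MapReduce pipeline.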
Using Spark (MLlib machine learning library)
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}

object CJBayes {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("cjbayes").setMaster("local")
    val sc = new SparkContext(sparkConf)
    val input = "file:///media/chenjie/0009418200012FF3/ubuntu/weather1.txt"
    val predictFile = "file:///media/chenjie/0009418200012FF3/ubuntu/weather_predict.txt"
    val data = sc.textFile(input)
    // Each line is "f1 f2 f3 f4,label": space-separated numeric features, then the class label
    val parsedData = data.map { line =>
      val parts = line.split(',')
      LabeledPoint(parts(1).toDouble, Vectors.dense(parts(0).split(' ').map(_.toDouble)))
    }
    // Use 100% of the data for training and 0% for testing.
    val splits = parsedData.randomSplit(Array(1.0, 0.0), seed = 11L)
    val training = splits(0)
    val test = splits(1)
    // Train the model; the second parameter is the smoothing parameter, default 1.0, adjustable
    val model = NaiveBayes.train(training, lambda = 1.0)
    // Measure the model's accuracy on the test set (empty here, so this prints NaN)
    val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
    val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
    println("accuracy --> " + accuracy)
    println("Prediction of (2.0,1.0,1.0,2.0): " + model.predict(Vectors.dense(2.0, 1.0, 1.0, 2.0)))
    sc.stop()
  }
}
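The MLlib version reads weather1.txt, a numeric re-encoding of the symbolic data, but the post never shows how that file was produced. Below is a minimal sketch of one plausible encoding; the value maps and the 0/1 labels are assumptions chosen for illustration, not the post's actual encoding. It writes lines in the "f1 f2 f3 f4,label" shape that the parsing above expects.

import java.io.PrintWriter
import scala.io.Source

// Hypothetical helper: re-encode symbolic weather.txt as numeric weather1.txt.
// The per-column value maps below are assumed, not taken from the original post.
object EncodeWeather {
  val encodings: Array[Map[String, Double]] = Array(
    Map("Sunny" -> 1.0, "Overcast" -> 2.0, "Rain" -> 3.0), // Outlook
    Map("Hot" -> 1.0, "Mild" -> 2.0, "Cool" -> 3.0),       // Temperature
    Map("High" -> 1.0, "Normal" -> 2.0),                   // Humidity
    Map("Weak" -> 1.0, "Strong" -> 2.0)                    // Wind
  )
  val labels = Map("No" -> 0.0, "Yes" -> 1.0)

  def main(args: Array[String]): Unit = {
    val out = new PrintWriter("weather1.txt")
    for (line <- Source.fromFile("weather.txt").getLines()) {
      val tokens = line.split(",")
      // Features become "f1 f2 f3 f4"; the label goes after the comma,
      // matching the parts(0).split(' ') / parts(1) parsing above
      val features = tokens.init.zipWithIndex.map { case (t, i) => encodings(i)(t) }
      out.println(features.mkString(" ") + "," + labels(tokens.last))
    }
    out.close()
  }
}

MLlib's NaiveBayes requires non-negative feature values, which is why the categories are mapped to small positive numbers rather than, say, signed codes.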
