数据挖掘笔记-分类-决策树-MapReduce实现-2

来源:互联网 发布:淘宝开书店是什么类目 编辑:程序博客网 时间:2024/05/17 16:54

上篇文章虽然结合Hadoop用MapReduce计算了属性的增益率,但整个程序似乎并没有真正做到并行化处理。后来又查阅了一些网上的资料,并结合自己的思考,重新实现了一遍决策树,大体思路如下:

1、将一个大数据集文件拆分成N个小数据集文件,对数据做好预处理工作,上传到HDFS

2、计算HDFS上小数据集文件的最佳分割属性与分割点

3、汇总N个小数据集文件的最佳划分,投票选出最佳划分

4、N个小数据集的节点根据最终的最佳划分,分割自己节点上的数据,上传到HDFS,跳转到第二步

 

下面是具体的实现代码:其中用到了JobControl来控制多个Job的执行,还涉及到几个MR程序。代码未经整理,望见谅。

public class DecisionTreeSprintBJob extends AbstractJob {private Map<String, Map<Object, Integer>> attributeValueStatistics = null;private Map<String, Set<String>> attributeNameToValues = null;private Set<String> allAttributes = null;/** 数据拆分,大数据文件拆分为小数据文件,便于分配到各个节点开启Job*/private List<String> split(String input, String splitNum) {String output = HDFSUtils.HDFS_TEMP_INPUT_URL + IdentityUtils.generateUUID();String[] args = new String[]{input, output, splitNum};DataFileSplitMR.main(args);List<String> inputs = new ArrayList<String>();Path outputPath = new Path(output);try {FileSystem fs = outputPath.getFileSystem(conf);Path[] paths = HDFSUtils.getPathFiles(fs, outputPath);for(Path path : paths) {System.out.println("split input path: " + path);InputStream in = fs.open(path);BufferedReader reader = new BufferedReader(new InputStreamReader(in));String line = reader.readLine();while (null != line && !"".equals(line)) {inputs.add(line);line = reader.readLine();}IOUtils.closeQuietly(in);IOUtils.closeQuietly(reader);}} catch (IOException e) {e.printStackTrace();}System.out.println("inputs size: " + inputs.size());return inputs;}/** 初始化工作,主要是获取特征属性集以及属性值的统计,主要是为了填充默认值*/private void initialize(String input) {System.out.println("initialize start.");allAttributes = new HashSet<String>();attributeNameToValues = new HashMap<String, Set<String>>();attributeValueStatistics = new HashMap<String, Map<Object, Integer>>();String output = HDFSUtils.HDFS_TEMP_INPUT_URL + IdentityUtils.generateUUID();String[] args = new String[]{input, output};AttributeStatisticsMR.main(args);Path outputPath = new Path(output);SequenceFile.Reader reader = null;try {FileSystem fs = outputPath.getFileSystem(conf);Path[] paths = HDFSUtils.getPathFiles(fs, outputPath);for(Path path : paths) {reader = new SequenceFile.Reader(fs, path, conf);AttributeKVWritable key = (AttributeKVWritable) ReflectionUtils.newInstance(reader.getKeyClass(), conf);IntWritable value = new IntWritable();while (reader.next(key, value)) 
{String attributeName = key.getAttributeName();allAttributes.add(attributeName);Set<String> values = attributeNameToValues.get(attributeName);if (null == values) {values = new HashSet<String>();attributeNameToValues.put(attributeName, values);}String attributeValue = key.getAttributeValue();values.add(attributeValue);Map<Object, Integer> valueStatistics = attributeValueStatistics.get(attributeName);if (null == valueStatistics) {valueStatistics = new HashMap<Object, Integer>();attributeValueStatistics.put(attributeName, valueStatistics);}valueStatistics.put(attributeValue, value.get());value = new IntWritable();}}} catch (IOException e) {e.printStackTrace();} finally {IOUtils.closeQuietly(reader);}System.out.println("initialize end.");}/** 预处理,主要是将分割后的小文件填充好默认值后在上传到HDFS上面*/private List<String> preHandle(List<String> inputs) throws IOException {List<String> fillInputs = new ArrayList<String>();for (String input : inputs) {Data data =null;try {Path inputPath = new Path(input);FileSystem fs = inputPath.getFileSystem(conf);FSDataInputStream fsInputStream = fs.open(inputPath);data = DataLoader.load(fsInputStream, true);} catch (IOException e) {e.printStackTrace();}DataHandler.computeFill(data.getInstances(), allAttributes.toArray(new String[0]), attributeValueStatistics, 1.0);OutputStream out = null;BufferedWriter writer = null;String outputDir = HDFSUtils.HDFS_TEMP_INPUT_URL + IdentityUtils.generateUUID();fillInputs.add(outputDir);String output = outputDir + File.separator + IdentityUtils.generateUUID();try {Path outputPath = new Path(output);FileSystem fs = outputPath.getFileSystem(conf);out = fs.create(outputPath);writer = new BufferedWriter(new OutputStreamWriter(out));StringBuilder sb = null;for (Instance instance : data.getInstances()) {sb = new StringBuilder();sb.append(instance.getId()).append("\t");sb.append(instance.getCategory()).append("\t");Map<String, Object> attrs = instance.getAttributes();for (Map.Entry<String, Object> entry : attrs.entrySet()) 
{sb.append(entry.getKey()).append(":");sb.append(entry.getValue()).append("\t");}writer.write(sb.toString());writer.newLine();}writer.flush();} catch (Exception e) {e.printStackTrace();} finally {IOUtils.closeQuietly(out);IOUtils.closeQuietly(writer);}}return fillInputs;}/** 创建JOB*/private Job createJob(String jobName, String input, String output) {Configuration conf = new Configuration();conf.set("mapred.job.queue.name", "q_hudong");Job job = null;try {job = new Job(conf, jobName);FileInputFormat.addInputPath(job, new Path(input));FileOutputFormat.setOutputPath(job, new Path(output));job.setJarByClass(DecisionTreeSprintBJob.class);job.setMapperClass(CalculateGiniMapper.class);job.setMapOutputKeyClass(Text.class);job.setMapOutputValueClass(AttributeWritable.class);job.setReducerClass(CalculateGiniReducer.class);job.setOutputKeyClass(Text.class);job.setOutputValueClass(AttributeGiniWritable.class);job.setInputFormatClass(TextInputFormat.class);job.setOutputFormatClass(SequenceFileOutputFormat.class);} catch (IOException e) {e.printStackTrace();}return job;}/** 根据HDFS上的输出路径选择最佳属性*/private AttributeGiniWritable chooseBestAttribute(String... 
outputs) {AttributeGiniWritable minSplitAttribute = null;double minSplitPointGini = 1.0;try {for (String output : outputs) {System.out.println("choose output: " + output);Path outputPath = new Path(output);FileSystem fs = outputPath.getFileSystem(conf);Path[] paths = HDFSUtils.getPathFiles(fs, outputPath);ShowUtils.print(paths);SequenceFile.Reader reader = null;for (Path path : paths) {reader = new SequenceFile.Reader(fs, path, conf);Text key = (Text) ReflectionUtils.newInstance(reader.getKeyClass(), conf);AttributeGiniWritable value = new AttributeGiniWritable();while (reader.next(key, value)) {double gini = value.getGini();System.out.println(value.getAttribute() + " : " + gini);if (gini <= minSplitPointGini) {minSplitPointGini = gini;minSplitAttribute = value;}value = new AttributeGiniWritable();}IOUtils.closeQuietly(reader);}System.out.println("delete hdfs file start: " + outputPath.toString());HDFSUtils.delete(conf, outputPath);System.out.println("delete hdfs file end: " + outputPath.toString());}} catch (IOException e) {e.printStackTrace();}if (null == minSplitAttribute) {System.out.println("minSplitAttribute is null");}return minSplitAttribute;}private Data obtainData(String input) {Data data = null;Path inputPath = new Path(input);try {FileSystem fs = inputPath.getFileSystem(conf);Path[] hdfsPaths = HDFSUtils.getPathFiles(fs, inputPath);FSDataInputStream fsInputStream = fs.open(hdfsPaths[0]);data = DataLoader.load(fsInputStream, true);} catch (IOException e) {e.printStackTrace();}return data;}/** 构建决策树*/private Object build(List<String> inputs) throws IOException {List<String> outputs = new ArrayList<String>();JobControl jobControl = new JobControl("CalculateGini");for (String input : inputs) {System.out.println("split path: " + input);String output = HDFSUtils.HDFS_TEMP_OUTPUT_URL + IdentityUtils.generateUUID();outputs.add(output);Configuration conf = new Configuration();ControlledJob controlledJob = new 
ControlledJob(conf);controlledJob.setJob(createJob(input, input, output));jobControl.addJob(controlledJob);}Thread jcThread = new Thread(jobControl);          jcThread.start();          while(true){              if(jobControl.allFinished()){  //                System.out.println(jobControl.getSuccessfulJobList());                  jobControl.stop();                  AttributeGiniWritable bestAttr = chooseBestAttribute(                outputs.toArray(new String[0]));                String attribute = bestAttr.getAttribute();        System.out.println("best attribute: " + attribute);        System.out.println("isCategory: " + bestAttr.isCategory());        if (bestAttr.isCategory()) {        return attribute;        }        TreeNode treeNode = new TreeNode(attribute);        Map<String, List<String>> splitToInputs =         new HashMap<String, List<String>>();        for (String input : inputs) {        Data data = obtainData(input);        String splitPoint = bestAttr.getSplitPoint();//        Map<String, Set<String>> attrName2Values = //        DataHandler.attributeValueStatistics(data.getInstances());        Set<String> attributeValues = attributeNameToValues.get(attribute);        System.out.println("attributeValues:");        ShowUtils.print(attributeValues);        if (attributeNameToValues.size() == 0 || null == attributeValues) {        continue;        }        attributeValues.remove(splitPoint);        StringBuilder sb = new StringBuilder();        for (String attributeValue : attributeValues) {        sb.append(attributeValue).append(",");        }        if (sb.length() > 0) sb.deleteCharAt(sb.length() - 1);        String[] names = new String[]{splitPoint, sb.toString()};        DataSplit dataSplit = DataHandler.split(new Data(        data.getInstances(), attribute, names));        for (DataSplitItem item : dataSplit.getItems()) {        if (item.getInstances().size() == 0) continue;        String path = item.getPath();        String name = 
path.substring(path.lastIndexOf(File.separator) + 1);        String hdfsPath = HDFSUtils.HDFS_TEMP_INPUT_URL + name;        HDFSUtils.copyFromLocalFile(conf, path, hdfsPath);        String split = item.getSplitPoint();        List<String> nextInputs = splitToInputs.get(split);        if (null == nextInputs) {        nextInputs = new ArrayList<String>();        splitToInputs.put(split, nextInputs);        }        nextInputs.add(hdfsPath);        }        }        for (Map.Entry<String, List<String>> entry :         splitToInputs.entrySet()) {        treeNode.setChild(entry.getKey(), build(entry.getValue()));        }        return treeNode;            }              if(jobControl.getFailedJobList().size() > 0){  //                System.out.println(jobControl.getFailedJobList());                  jobControl.stop();              }          }  }/** 分类样本集*/private void classify(TreeNode treeNode, String testSet, String output) {OutputStream out = null;BufferedWriter writer = null;try {Path testSetPath = new Path(testSet);FileSystem testFS = testSetPath.getFileSystem(conf);Path[] testHdfsPaths = HDFSUtils.getPathFiles(testFS, testSetPath);FSDataInputStream fsInputStream = testFS.open(testHdfsPaths[0]);Data testData = DataLoader.load(fsInputStream, true);DataHandler.computeFill(testData.getInstances(), allAttributes.toArray(new String[0]), attributeValueStatistics, 1.0);Object[] results = (Object[]) treeNode.classifySprint(testData);ShowUtils.print(results);DataError dataError = new DataError(testData.getCategories(), results);dataError.report();String path = FileUtils.obtainRandomTxtPath();out = new FileOutputStream(new File(path));writer = new BufferedWriter(new OutputStreamWriter(out));StringBuilder sb = null;for (int i = 0, len = results.length; i < len; i++) {sb = new StringBuilder();sb.append(i+1).append("\t").append(results[i]);writer.write(sb.toString());writer.newLine();}writer.flush();Path outputPath = new Path(output);FileSystem fs = 
outputPath.getFileSystem(conf);if (!fs.exists(outputPath)) {fs.mkdirs(outputPath);}String name = path.substring(path.lastIndexOf(File.separator) + 1);HDFSUtils.copyFromLocalFile(conf, path, output + File.separator + name);} catch (IOException e) {e.printStackTrace();} finally {IOUtils.closeQuietly(out);IOUtils.closeQuietly(writer);}}public void run(String[] args) {try {if (null == conf) conf = new Configuration();String[] inputArgs = new GenericOptionsParser(conf, args).getRemainingArgs();if (inputArgs.length != 4) {System.out.println("error, please input three path.");System.out.println("1. trainset path.");System.out.println("2. testset path.");System.out.println("3. result output path.");System.out.println("4. data split number.");System.exit(2);}List<String> splitInputs = split(inputArgs[0], inputArgs[3]);initialize(inputArgs[0]);List<String> inputs = preHandle(splitInputs);TreeNode treeNode = (TreeNode) build(inputs);TreeNodeHelper.print(treeNode, 0, null);classify(treeNode, inputArgs[1], inputArgs[2]);} catch (Exception e) {e.printStackTrace();}}public static void main(String[] args) {DecisionTreeSprintBJob job = new DecisionTreeSprintBJob();long startTime = System.currentTimeMillis();job.run(args);long endTime = System.currentTimeMillis();System.out.println("spend time: " + (endTime - startTime));}}


代码托管:https://github.com/fighting-one-piece/repository-datamining.git
 

 

 

1 0