C4.5算法建立决策树JAVA实现

来源:互联网 发布:gulpfile.babel.js 编辑:程序博客网 时间:2024/05/07 01:14

转载链接:http://www.cnblogs.com/lixusign/archive/2012/06/13/2548124.html

当前的属性为:age income student credit_rating

当前的数据集为(最后一列是TARGET_VALUE):

---------------------------------

youth     high   no   fair      no 
youth     high   no   excellent   no 
middle_aged   high   no   fair     yes 
senior     low    yes  fair     yes 
senior     low    yes  excellent   no 
middle_aged   low    yes  excellent   yes 
youth     medium  no   fair     no 
youth     low     yes  fair     yes 
senior     medium  yes    fair     yes 
youth     medium  yes    excellent   yes 
middle_aged   high   yes  fair        yes 
senior     medium  no     excellent   no 
---------------------------------

C4.5建立树类

复制代码
package C45Test;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Recursively builds a decision tree, choosing at every level the attribute
 * with the largest gain ratio as reported by {@link InfoGain}.
 *
 * <p>Every data row holds one value per attribute followed by the target
 * value in the last position.
 */
public class DecisionTree {

    /** Node name assigned to nodes that carry a final class label. */
    private static final String LEAF_NODE_NAME = "leafNode";

    /**
     * Builds the (sub)tree for the given rows and the attributes still available.
     *
     * <p>The current rows and attributes are printed to standard output on
     * every call. The rows passed in are not modified: partitions are taken
     * from the copies held by {@link InfoGain}.
     *
     * @param data          rows to learn from; the last element of each row is the target value
     * @param attributeList names of the attributes, in the same order as the row columns
     * @return a leaf when the rows are pure, when no attribute is left, or when
     *         no attribute has a positive gain ratio (the latter two use the
     *         most frequent target value); otherwise an inner node with one
     *         child per value of the selected attribute
     */
    public TreeNode createDT(List<ArrayList<String>> data, List<String> attributeList) {
        printState(data, attributeList);

        List<String> targets = InfoGain.getTarget(data);
        String result = InfoGain.IsPure(targets);
        if (result != null) {
            return newLeaf(result);
        }

        // The rows are mixed: remember the most frequent class so that every
        // fallback leaf below gets a real label instead of null.
        String majority = majorityTarget(targets);
        if (attributeList.size() == 0) {
            return newLeaf(majority);
        }

        InfoGain gain = new InfoGain(data, attributeList);
        double maxGain = 0.0;
        int attrIndex = -1;
        for (int i = 0; i < attributeList.size(); i++) {
            double tempGain = gain.getGainRatio(i);
            if (maxGain < tempGain) {
                maxGain = tempGain;
                attrIndex = i;
            }
        }
        if (attrIndex < 0) {
            // No attribute improves on the current partition; splitting would
            // previously fail on attributeList.get(-1).
            return newLeaf(majority);
        }

        String attributeName = attributeList.get(attrIndex);
        System.out.println("选择出的最大增益率属性为: " + attributeName);
        TreeNode node = new TreeNode();
        node.setAttributeValue(attributeName);

        Map<String, Long> attrvalueMap = gain.getAttributeValue(attrIndex);
        for (Map.Entry<String, Long> entry : attrvalueMap.entrySet()) {
            String branchValue = entry.getKey();
            List<ArrayList<String>> resultData = gain.getData4Value(branchValue, attrIndex);
            System.out.println("当前为" + attributeName + "的" + branchValue + "分支。");

            TreeNode child;
            if (resultData.size() == 0) {
                // Empty partition: label it with the majority class of the parent rows.
                child = new TreeNode();
                child.setNodeName(attributeName);
                child.setTargetFunValue(majority);
                child.setAttributeValue(branchValue);
            } else {
                // Drop the column that was just used before recursing.
                for (ArrayList<String> row : resultData) {
                    row.remove(attrIndex);
                }
                ArrayList<String> resultAttr = new ArrayList<String>(attributeList);
                resultAttr.remove(attrIndex);
                child = createDT(resultData, resultAttr);
            }
            node.getChildTreeNode().add(child);
            node.getPathName().add(branchValue);
        }
        return node;
    }

    /** Prints the rows and attribute names handled by the current call. */
    private static void printState(List<ArrayList<String>> data, List<String> attributeList) {
        System.out.println("当前的DATA为");
        for (ArrayList<String> row : data) {
            for (String cell : row) {
                System.out.print(cell + " ");
            }
            System.out.println();
        }
        System.out.println("---------------------------------");
        System.out.println("当前的ATTR为");
        for (String attribute : attributeList) {
            System.out.print(attribute + " ");
        }
        System.out.println();
        System.out.println("---------------------------------");
    }

    /** Creates a leaf node carrying the given class label. */
    private TreeNode newLeaf(String targetValue) {
        TreeNode leaf = new TreeNode();
        leaf.setNodeName(LEAF_NODE_NAME);
        leaf.setTargetFunValue(targetValue);
        return leaf;
    }

    /**
     * Returns the most frequent value in {@code targets}; ties go to the value
     * that appears first. Returns {@code null} for an empty list.
     */
    private static String majorityTarget(List<String> targets) {
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (String target : targets) {
            Integer seen = counts.get(target);
            counts.put(target, seen == null ? 1 : seen + 1);
        }
        String best = null;
        int bestCount = 0;
        for (String target : targets) {
            int count = counts.get(target);
            if (count > bestCount) {
                best = target;
                bestCount = count;
            }
        }
        return best;
    }

    /**
     * A node of the decision tree. Inner nodes hold the name of the attribute
     * they split on plus, in parallel lists, their children and the attribute
     * value leading to each child; leaves hold the predicted target value.
     */
    class TreeNode {

        private String attributeValue;
        private List<TreeNode> childTreeNode;
        private List<String> pathName;
        private String targetFunValue;
        private String nodeName;

        /** Creates a node with the given name and no children. */
        public TreeNode(String nodeName) {
            this.nodeName = nodeName;
            this.childTreeNode = new ArrayList<TreeNode>();
            this.pathName = new ArrayList<String>();
        }

        /** Creates an unnamed node with no children. */
        public TreeNode() {
            this.childTreeNode = new ArrayList<TreeNode>();
            this.pathName = new ArrayList<String>();
        }

        public String getAttributeValue() {
            return attributeValue;
        }

        public void setAttributeValue(String attributeValue) {
            this.attributeValue = attributeValue;
        }

        public List<TreeNode> getChildTreeNode() {
            return childTreeNode;
        }

        public void setChildTreeNode(List<TreeNode> childTreeNode) {
            this.childTreeNode = childTreeNode;
        }

        public String getTargetFunValue() {
            return targetFunValue;
        }

        public void setTargetFunValue(String targetFunValue) {
            this.targetFunValue = targetFunValue;
        }

        public String getNodeName() {
            return nodeName;
        }

        public void setNodeName(String nodeName) {
            this.nodeName = nodeName;
        }

        public List<String> getPathName() {
            return pathName;
        }

        public void setPathName(List<String> pathName) {
            this.pathName = pathName;
        }
    }
}
复制代码

 

 

增益率计算类(取 log 时底数用的是 e,而不是 2)

复制代码
package C45Test;import java.util.ArrayList;import java.util.HashMap;import java.util.HashSet;import java.util.Iterator;import java.util.List;import java.util.Map;import java.util.Set;//C 4.5 实现public class InfoGain {        private List<ArrayList<String>> data;    private List<String> attribute;        public InfoGain(List<ArrayList<String>> data,List<String> attribute){                this.data = new ArrayList<ArrayList<String>>();        for(int i=0;i<data.size();i++){            List<String> temp = data.get(i);            ArrayList<String> t = new ArrayList<String>();            for(int j=0;j<temp.size();j++){                t.add(temp.get(j));            }            this.data.add(t);        }                this.attribute = new ArrayList<String>();        for(int k=0;k<attribute.size();k++){            this.attribute.add(attribute.get(k));        }        /*this.data = data;        this.attribute = attribute;*/    }        //获得熵    public double getEntropy(){                Map<String,Long> targetValueMap = getTargetValue();        Set<String> targetkey = targetValueMap.keySet();        double entropy = 0.0;        for(String key : targetkey){            double p = MathUtils.div((double)targetValueMap.get(key), (double)data.size());            entropy += (-1) * p * Math.log(p);        }        return entropy;    }        //获得InfoA    public double getInfoAttribute(int attributeIndex){                Map<String,Long> attributeValueMap = getAttributeValue(attributeIndex);        double infoA = 0.0;        for(Map.Entry<String, Long> entry : attributeValueMap.entrySet()){            int size = data.size();            double attributeP = MathUtils.div((double)entry.getValue() , (double) size);            Map<String,Long> targetValueMap = getAttributeValueTargetValue(entry.getKey(),attributeIndex);            long totalCount = 0L;            for(Map.Entry<String, Long> entryValue :targetValueMap.entrySet()){                totalCount += entryValue.getValue();       
      }            double valueSum = 0.0;            for(Map.Entry<String, Long> entryTargetValue : targetValueMap.entrySet()){                 double p = MathUtils.div((double)entryTargetValue.getValue(), (double)totalCount);                 valueSum += Math.log(p) * p;            }            infoA += (-1) * attributeP * valueSum;        }        return infoA;            }        //得到属性值在决策空间的比例    public Map<String,Long> getAttributeValueTargetValue(String attributeName,int attributeIndex){                Map<String,Long> targetValueMap = new HashMap<String,Long>();        Iterator<ArrayList<String>> iterator = data.iterator();        while(iterator.hasNext()){            List<String> tempList = iterator.next();            if(attributeName.equalsIgnoreCase(tempList.get(attributeIndex))){                int size = tempList.size();                String key = tempList.get(size - 1);                Long value = targetValueMap.get(key);                targetValueMap.put(key, value != null ? ++value :1L);            }        }        return targetValueMap;    }        //得到属性在决策空间上的数量    public Map<String,Long> getAttributeValue(int attributeIndex){                Map<String,Long> attributeValueMap = new HashMap<String,Long>();        for(ArrayList<String> note : data){            String key = note.get(attributeIndex);            Long value = attributeValueMap.get(key);            attributeValueMap.put(key, value != null ? 
++value :1L);        }        return attributeValueMap;            }        public List<ArrayList<String>> getData4Value(String attrValue,int attrIndex){                List<ArrayList<String>> resultData = new ArrayList<ArrayList<String>>();        Iterator<ArrayList<String>> iterator = data.iterator();        for(;iterator.hasNext();){            ArrayList<String> templist = iterator.next();            if(templist.get(attrIndex).equalsIgnoreCase(attrValue)){                ArrayList<String> temp = (ArrayList<String>) templist.clone();                resultData.add(temp);            }        }        return resultData;    }        //获得增益率    public double getGainRatio(int attributeIndex){        return MathUtils.div(getGain(attributeIndex), getSplitInfo(attributeIndex));    }        //获得增益量    public double getGain(int attributeIndex){        return getEntropy() - getInfoAttribute(attributeIndex);    }        //得到惩罚因子    public double getSplitInfo(int attributeIndex){                Map<String,Long> attributeValueMap = getAttributeValue(attributeIndex);        double splitA = 0.0;        for(Map.Entry<String, Long> entry : attributeValueMap.entrySet()){            int size = data.size();            double attributeP = MathUtils.div((double)entry.getValue() , (double) size);            splitA += attributeP * Math.log(attributeP) * (-1);        }        return splitA;    }        //得到目标函数在当前集合范围内的离散的值    public Map<String,Long> getTargetValue(){                Map<String,Long> targetValueMap = new HashMap<String,Long>();        Iterator<ArrayList<String>> iterator = data.iterator();        while(iterator.hasNext()){            List<String> tempList = iterator.next();            String key = tempList.get(tempList.size() - 1);            Long value = targetValueMap.get(key);            targetValueMap.put(key, value != null ? 
++value : 1L);        }        return targetValueMap;    }        //获得TARGET值    public static List<String> getTarget(List<ArrayList<String>> data){                List<String> list = new ArrayList<String>();        for(ArrayList<String> temp : data){            int index = temp.size() -1;            String value = temp.get(index);            list.add(value);        }        return list;    }        //判断当前纯度是否100%    public static String IsPure(List<String> list){                Set<String> set = new HashSet<String>();        for(String name :list){            set.add(name);        }        if(set.size() > 1) return null;        Iterator<String> iterator = set.iterator();        return iterator.next();    }    }
复制代码

 

测试类:将上面列出的属性和数据集分别读入 2 个 List 中。

复制代码
package C45Test;

import java.util.ArrayList;
import java.util.List;

import C45Test.DecisionTree.TreeNode;

/**
 * Entry point that builds a decision tree from the sample data set listed
 * with this example.
 */
public class MainC45 {

    private static final List<ArrayList<String>> dataList = new ArrayList<ArrayList<String>>();
    private static final List<String> attributeList = new ArrayList<String>();

    /** Attribute names, in column order. */
    private static final String[] ATTRIBUTES = {"age", "income", "student", "credit_rating"};

    /** Sample rows; the last column is the target value. */
    private static final String[][] ROWS = {
        {"youth", "high", "no", "fair", "no"},
        {"youth", "high", "no", "excellent", "no"},
        {"middle_aged", "high", "no", "fair", "yes"},
        {"senior", "low", "yes", "fair", "yes"},
        {"senior", "low", "yes", "excellent", "no"},
        {"middle_aged", "low", "yes", "excellent", "yes"},
        {"youth", "medium", "no", "fair", "no"},
        {"youth", "low", "yes", "fair", "yes"},
        {"senior", "medium", "yes", "fair", "yes"},
        {"youth", "medium", "yes", "excellent", "yes"},
        {"middle_aged", "high", "yes", "fair", "yes"},
        {"senior", "medium", "no", "excellent", "no"},
    };

    /** Builds the tree for the sample data; progress is printed by {@code createDT}. */
    public static void main(String args[]) {
        DecisionTree dt = new DecisionTree();
        // The tree is built for its console trace only; the root is not used further here.
        TreeNode node = dt.createDT(configData(), configAttribute());
        System.out.println();
    }

    /**
     * Fills {@code dataList} with the sample rows and returns it. The list is
     * cleared first, so repeated calls do not duplicate rows.
     */
    private static List<ArrayList<String>> configData() {
        dataList.clear();
        for (String[] row : ROWS) {
            ArrayList<String> record = new ArrayList<String>(row.length);
            for (String cell : row) {
                record.add(cell);
            }
            dataList.add(record);
        }
        return dataList;
    }

    /**
     * Fills {@code attributeList} with the attribute names and returns it.
     * The list is cleared first, so repeated calls do not duplicate names.
     */
    private static List<String> configAttribute() {
        attributeList.clear();
        for (String name : ATTRIBUTES) {
            attributeList.add(name);
        }
        return attributeList;
    }
}
复制代码

 

大数运算工具类

复制代码
package C45Test;import java.math.BigDecimal;public abstract class MathUtils {        //默认余数长度    private static final int DIV_SCALE = 10;        //受限于DOUBLE长度    public static double add(double value1,double value2){                BigDecimal big1 = new BigDecimal(String.valueOf(value1));        BigDecimal big2 = new BigDecimal(String.valueOf(value2));        return big1.add(big2).doubleValue();    }        //大数加法    public static double add(String value1,String value2){                BigDecimal big1 = new BigDecimal(value1);        BigDecimal big2 = new BigDecimal(value2);        return big1.add(big2).doubleValue();    }        public static double div(double value1,double value2){                BigDecimal big1 = new BigDecimal(String.valueOf(value1));        BigDecimal big2 = new BigDecimal(String.valueOf(value2));        return big1.divide(big2,DIV_SCALE,BigDecimal.ROUND_HALF_UP).doubleValue();    }        public static double mul(double value1,double value2){                BigDecimal big1 = new BigDecimal(String.valueOf(value1));        BigDecimal big2 = new BigDecimal(String.valueOf(value2));        return big1.multiply(big2).doubleValue();    }        public static double sub(double value1,double value2){                BigDecimal big1 = new BigDecimal(String.valueOf(value1));        BigDecimal big2 = new BigDecimal(String.valueOf(value2));        return big1.subtract(big2).doubleValue();    }        public static double returnMax(double value1, double value2) {                BigDecimal big1 = new BigDecimal(value1);        BigDecimal big2 = new BigDecimal(value2);        return big1.max(big2).doubleValue();    }}
复制代码
1 0