决策树之 ID3 算法的 Java 实现

来源:互联网 发布:途宝网络 编辑:程序博客网 时间:2024/05/22 14:59

package com.decisiontree;import java.util.ArrayList;import java.util.HashMap;import java.util.Iterator;public class ID3 {/** * @param Spring_LGF */public static void main(String[] args) {// TODO Auto-generated method stub//用于存储所有属性可能的取值ArrayList<String> attrOutlook = new ArrayList<String>();attrOutlook.add("sunny");attrOutlook.add("overcast");attrOutlook.add("rainy");ArrayList<String> attrTemperature = new ArrayList<String>();attrTemperature.add("hot");attrTemperature.add("mild");attrTemperature.add("cool");ArrayList<String> attrHumidity = new ArrayList<String>();attrHumidity.add("high");attrHumidity.add("normal");ArrayList<String> attrWindy = new ArrayList<String>();attrWindy.add("true");attrWindy.add("false");ArrayList<String> attrPlay = new ArrayList<String>();attrPlay.add("no");attrPlay.add("yes");//属性名与属性的取值进行对应 HashMap<String,ArrayList<String>> attr = new HashMap<String,ArrayList<String>>();attr.put("outlook", attrOutlook);attr.put("trmperature", attrTemperature);attr.put("humidity", attrHumidity);attr.put("windy", attrWindy);//attr.put("play",attrPlay);//存储属性的索引, 便于在对数据统计HashMap<String,Integer> attrIndex = new HashMap<String,Integer>();attrIndex.put("outlook", 0);attrIndex.put("trmperature", 1);attrIndex.put("humidity", 2);attrIndex.put("windy", 3);//attrIndex.put("play", 4);//样本存储String[][] data = {{"sunny","hot","high","false","no"},{"sunny","hot","high","true","no"},{"overcast","hot","high","false","yes"},{"rainy","mild","high","false","yes"},{"rainy","cool","normal","false","yes"},{"rainy","cool","normal","true","no"},{"overcast","cool","normal","true","yes"},{"sunny","mild","high","false","no"},{"sunny","cool","normal","false","yes"},{"rainy","mild","normal","false","yes"},{"sunny","mild","normal","true","yes"},{"overcast","mild","high","true","yes"},{"overcast","hot","normal","false","yes"},{"rainy","mild","high","true","no"}};ID3Tree root = new ID3Tree();buildID3Tree(root,data,attr,attrIndex);outputID3Tree(root);}//构造决策树public static ID3Tree 
buildID3Tree(ID3Tree root, String[][] data, HashMap<String,ArrayList<String>> attr, HashMap<String,Integer> attrIndex){Iterator<String>  attrIt = attr.keySet().iterator();String maxAttr = null;String attrName;//属性名称HashMap<String, Double> attrValueList = new HashMap<String, Double>();//用于记录每一个属性的取值在样本中出现的次数HashMap<String, Double> attrValueMap = new HashMap<String,Double>();while(attrIt.hasNext() && (!attr.isEmpty())){attrName = attrIt.next();//取得属性可能出现的取值列表ArrayList<String> attrList = attr.get(attrName);//取得属性的索引值int index = attrIndex.get(attrName);//用于扫描每一个属性的所有取值for(int i = 0; i < attrList.size(); i++){String attrValue = attrList.get(i);int isPlay = 0;int noPlay = 0;//扫描书样本中每一个属性的取值出现的次数for(int j = 0; j < data.length; j++){if(data[j][index] == null){break;}if(data[j][index].equals(attrValue) && data[j][4].equals("yes")){isPlay++;}if(data[j][index].equals(attrValue) && data[j][4].equals("no")){noPlay++;}}double num = (-1* log(((double)isPlay/(double)(isPlay+noPlay)),2.0) * ((double)isPlay/(double)(isPlay+noPlay))) - log(((double)noPlay/(double)(isPlay+noPlay)),2.0) * ((double)noPlay/(double)(isPlay+noPlay));//double num = ((-1)*(Math.log(isPlay/(isPlay+noPlay)) / Math.log(2.0) * isPlay / (isPlay+noPlay)) - (Math.log(noPlay/(isPlay+noPlay)) / Math.log(2.0) * noPlay / (isPlay+noPlay)));double sum = 0.0;if(Double.compare(num, Double.NaN) == 0){num = 0.0;}attrValueMap.put(attrValue, num);//计算每一个属性的熵值if(attrValueList.get(attrName) == null){attrValueList.put(attrName, num*(double)(isPlay+noPlay)/data.length);}else{ sum = attrValueList.get(attrName) + num*(double)(isPlay+noPlay)/data.length; attrValueList.put(attrName, sum);}}if(maxAttr == null){maxAttr = attrName;}else{if(attrValueList.get(attrName) - attrValueList.get(maxAttr) < 0.0){maxAttr = attrName;}}}if(maxAttr != null){int index = attrIndex.get(maxAttr);ArrayList<String> attrList = attr.get(maxAttr);root.attrName = maxAttr;root.treeList = new ArrayList<ID3Tree>();for(int i = 0; i < attrList.size(); i++){String 
valueName = attrList.get(i);double value = attrValueMap.get(valueName);ID3Tree node = new ID3Tree();int isPlay = 0;int isAttr = 0;for(int j = 0; j < data.length; j++){if(data[j][index] == null){break;}if(data[j][index].equals(valueName)){isAttr++;if(data[j][4].equals("yes")){isPlay++;}}}if(value == 0.0){node.isleaf = true;if(isPlay == isAttr){node.isPlay = true;}node.attrValue = valueName;root.treeList.add(node);}else{node.isleaf = false;node.attrValue = valueName;String [][]da= new String[14][4];for(int k = 0, n = 0; k < data.length; k++){if(data[k][index].equals(valueName)){da[n++] = data[k];}}HashMap<String,ArrayList<String>> attr2 = attr;attr2.remove(maxAttr);System.out.println(attr2);buildID3Tree(node,da,attr2,attrIndex);root.treeList.add(node);}}}return root;}//遍历决策树public static void outputID3Tree(ID3Tree root){System.out.println(root.attrName + "   " + root.attrValue + "  " + root.isPlay + "   " + root.isleaf);ArrayList<ID3Tree> treeList = root.treeList;if(root.treeList != null){for(int i = 0 ; i < treeList.size(); i++){outputID3Tree(treeList.get(i));}}}//对数的计算,第一个参数表示的对数,第二个参数表示的是底static public double log(double value, double base) {return Math.log(value) / Math.log(base);}static class ID3Tree{//是否是叶子节点private boolean isleaf;//是否出去玩,该值只有在叶子节点中出现private boolean isPlay;//上一个节点在该节点的取值private String attrValue;//孩子节点数组private ArrayList<ID3Tree> treeList;private String attrName;public String getAttrName() {return attrName;}public void setAttrName(String attrName) {this.attrName = attrName;}public boolean isPlay() {return isPlay;}public void setPlay(boolean isPlay) {this.isPlay = isPlay;}public boolean isIsleaf() {return isleaf;}public void setIsleaf(boolean isleaf) {this.isleaf = isleaf;}public ArrayList<ID3Tree> getTreeList() {return treeList;}public void setTreeList(ArrayList<ID3Tree> treeList) {this.treeList = treeList;}public String getAttrValue() {return attrValue;}public void setAttrValue(String attrValue) {this.attrValue = attrValue;}}}

0 0
原创粉丝点击