数据挖掘--Cart算法的实现

来源:互联网 发布:联想网络控制器驱动 编辑:程序博客网 时间:2024/05/28 20:18
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * A small CART-style classification tree for categorical data, using the
 * Gini index as the impurity measure.
 *
 * <p>Every record is a single string whose fields are separated by exactly one
 * space. The last field of a record is its class label; all preceding fields
 * are attribute values. {@link #Var} holds the header line (the attribute
 * names followed by the name of the target variable) in the same format.
 *
 * <p>Leaves are reported on standard output, one rule per line.
 */
public class Cart {

    /** Header line: space separated column names, the target variable last. */
    String Var = "";

    /**
     * Computes the reduction in Gini impurity obtained by a binary split.
     *
     * @param Target rows of the form {@code "<attributeValue> <classLabel>"}
     * @param Split  space separated attribute values that go to the left
     *               branch; every other value goes to the right branch
     * @return Gini(Target) minus the size-weighted Gini of both branches,
     *         or 0 when {@code Target} is empty
     */
    public float Gini_compute(List<String> Target, String Split) {
        int total = Target.size();
        if (total == 0) {
            // Nothing to split; the original expression evaluated 0/0 here and returned NaN.
            return 0f;
        }

        Set<String> leftValues = new HashSet<String>();
        String[] splitTokens = Split.split(" ");
        for (int k = 0; k < splitTokens.length; k++) {
            leftValues.add(splitTokens[k]);
        }

        List<String> left = new ArrayList<String>();
        List<String> right = new ArrayList<String>();
        for (String row : Target) {
            String value = row.split(" ")[0];
            if (leftValues.contains(value)) {
                left.add(row);
            } else {
                right.add(row);
            }
        }

        // Same operand order as the original so that float rounding, and hence
        // tie-breaking between candidate splits, is unchanged.
        float weighted = Gini_index(left) * ((float) left.size()) / total;
        weighted += Gini_index(right) * ((float) right.size()) / total;
        return Gini_index(Target) - weighted;
    }

    /**
     * Computes the Gini index {@code 1 - sum(p_k^2)} of a set of rows.
     *
     * @param Target rows whose second space separated field is the class label
     * @return the Gini index; 1 for an empty list (no class contributes)
     */
    public float Gini_index(List<String> Target) {
        // One counting pass instead of rescanning all rows for every distinct label.
        Map<String, Integer> counts = new HashMap<String, Integer>();
        for (String row : Target) {
            String label = row.split(" ")[1];
            Integer seen = counts.get(label);
            counts.put(label, seen == null ? 1 : seen + 1);
        }

        int size = Target.size();
        float sum = 0;
        for (Integer count : counts.values()) {
            float p = ((float) count) / size;
            sum = sum + p * p;
        }
        return 1 - sum;
    }

    /**
     * Finds the best binary partition of the values in column {@code i}.
     *
     * <p>All subsets of the distinct values are tried, so the cost grows
     * exponentially with the number of distinct values in the column.
     *
     * @param DataSet full records (attributes followed by the class label)
     * @param i       zero-based index of the attribute column to evaluate
     * @return a two element list: the space separated values of the left
     *         branch, and the achieved Gini reduction as a string
     */
    public List<String> Gini_select(List<String> DataSet, int i) {
        List<String> column = new ArrayList<String>();
        Set<String> distinct = new HashSet<String>();
        for (String row : DataSet) {
            String[] fields = row.split(" ");
            column.add(fields[i] + " " + fields[fields.length - 1]);
            distinct.add(fields[i]);
        }

        StringBuilder joined = new StringBuilder();
        for (String value : distinct) {
            joined.append(" ").append(value);
        }
        String set_i = joined.toString().trim();

        ArrayList<String> candidates = new ArrayList<String>();
        doGetSubSequences(set_i, "", candidates);

        String max_set = candidates.get(0);
        float max = Gini_compute(column, max_set);
        for (int j = 1; j < candidates.size(); j++) {
            // Evaluate each candidate once; strict '>' keeps the first of equal candidates.
            float gain = Gini_compute(column, candidates.get(j));
            if (gain > max) {
                max = gain;
                max_set = candidates.get(j);
            }
        }

        List<String> return_list = new ArrayList<String>();
        return_list.add(max_set);
        return_list.add(String.valueOf(max));
        return return_list;
    }

    /**
     * Appends every subset of the space separated tokens in {@code word} to
     * {@code list}, including the empty subset and the full set.
     *
     * @param word remaining tokens, space separated
     * @param s    subset built so far
     * @param list receives one space separated string per subset
     */
    private static void doGetSubSequences(String word, String s, ArrayList<String> list) {
        if (word.length() == 0) {
            list.add(s.trim());
            return;
        }
        String[] headAndTail = word.split(" ", 2);
        String tail = headAndTail.length >= 2 ? headAndTail[1] : "";
        doGetSubSequences(tail, s, list);                        // subsets without the head token
        doGetSubSequences(tail, s + " " + headAndTail[0], list); // subsets with the head token
    }

    /**
     * Recursively grows the tree and prints one rule per leaf.
     *
     * <p>A node becomes a leaf when the depth limit is reached, when it holds
     * at most two records, or when the best split reduces the Gini index by
     * no more than 0.01. An empty data set produces no output.
     *
     * @param DataSet   records reaching this node
     * @param path      textual description of the splits leading to this node
     * @param alpha     current depth (0 for the root)
     * @param alpha_max depth at which growth stops
     */
    public void Cart_tree(List<String> DataSet, String path, int alpha, int alpha_max) {
        if (DataSet == null || DataSet.isEmpty()) {
            // No records, hence no rule; previously this printed a rule with NaN accuracy.
            return;
        }
        if (alpha == alpha_max || DataSet.size() <= 2) { // stop condition 1
            write_result(DataSet, path);
            return;
        }

        int count_var = DataSet.get(0).split(" ").length - 1;
        String max_split_L = "";
        float max_Gini = -1;
        int max_index = -1;
        for (int i = 0; i < count_var; i++) {
            // The subset search is exponential, so run it once per attribute.
            List<String> best = Gini_select(DataSet, i);
            float gain = Float.parseFloat(best.get(1));
            if (gain > max_Gini) {
                max_Gini = gain;
                max_split_L = best.get(0);
                max_index = i;
            }
        }
        if (max_Gini <= 0.01) { // stop condition 2
            write_result(DataSet, path);
            return;
        }

        List<String> DataSet_L = new ArrayList<String>();
        List<String> DataSet_R = new ArrayList<String>();
        DataSet_split(DataSet, max_index, max_split_L, DataSet_L, DataSet_R);
        String max_split_R = Compute_split_R(DataSet, max_index, max_split_L);

        String attribute = this.Var.split(" ")[max_index];
        Cart_tree(DataSet_L, path + "|" + attribute + ":" + max_split_L, alpha + 1, alpha_max);
        Cart_tree(DataSet_R, path + "|" + attribute + ":" + max_split_R, alpha + 1, alpha_max);
    }

    /**
     * Prints the rule of a leaf: its path, record count, majority class and
     * the fraction of records that belong to the majority class.
     *
     * @param DataSet records in the leaf
     * @param path    textual description of the splits leading to the leaf
     */
    private void write_result(List<String> DataSet, String path) {
        Map<String, Integer> tally = new HashMap<String, Integer>();
        for (String row : DataSet) {
            String[] fields = row.trim().split(" ");
            String label = fields[fields.length - 1];
            Integer seen = tally.get(label);
            tally.put(label, seen == null ? 1 : seen + 1);
        }

        int sum_count = 0;
        int max_count = 0;
        String max_Category = "";
        Iterator<Map.Entry<String, Integer>> entries = tally.entrySet().iterator();
        while (entries.hasNext()) {
            Map.Entry<String, Integer> entry = entries.next();
            int occurrences = entry.getValue();
            if (occurrences >= max_count) { // '>=': the last of equally frequent classes wins
                max_count = occurrences;
                max_Category = entry.getKey();
            }
            sum_count = sum_count + occurrences;
        }

        int count = DataSet.size();
        String forcast = max_Category;
        float accuracy_rate = ((float) max_count) / sum_count;
        String[] names = this.Var.split(" ");
        System.out.println("Rule:" + path + ".   Count:" + count + ".   "
                + names[names.length - 1] + ":" + forcast
                + ".   Accuracy_rate:" + accuracy_rate);
    }

    /**
     * Returns the values of column {@code index} that are not in
     * {@code split_L}, i.e. the values of the right branch.
     *
     * @param DataSet records at the node being split
     * @param index   zero-based attribute column
     * @param split_L space separated values of the left branch
     * @return space separated values of the right branch
     */
    private String Compute_split_R(List<String> DataSet, int index, String split_L) {
        Set<String> remaining = new HashSet<String>();
        for (String row : DataSet) {
            remaining.add(row.split(" ")[index]);
        }
        String[] leftTokens = split_L.trim().split(" ");
        for (int k = 0; k < leftTokens.length; k++) {
            remaining.remove(leftTokens[k]);
        }

        StringBuilder split_R = new StringBuilder();
        for (String value : remaining) {
            split_R.append(" ").append(value);
        }
        return split_R.toString().trim();
    }

    /**
     * Distributes the records over the two branches of a split.
     *
     * @param DataSet     records at the node being split
     * @param max_index   zero-based attribute column used for the split
     * @param max_split_L space separated values that go to the left branch
     * @param DataSet_L   receives the records of the left branch
     * @param DataSet_R   receives the records of the right branch
     */
    private void DataSet_split(List<String> DataSet, int max_index, String max_split_L,
            List<String> DataSet_L, List<String> DataSet_R) {
        Set<String> leftValues = new HashSet<String>();
        String[] leftTokens = max_split_L.trim().split(" ");
        for (int k = 0; k < leftTokens.length; k++) {
            leftValues.add(leftTokens[k]);
        }
        for (String row : DataSet) {
            if (leftValues.contains(row.split(" ")[max_index])) {
                DataSet_L.add(row);
            } else {
                DataSet_R.add(row);
            }
        }
    }

    /**
     * Reads a training file and prints the rules of a tree built from it.
     *
     * <p>The first non-blank line of the file is the header; each following
     * non-blank line is one record.
     *
     * @param args optional: {@code args[0]} is the input file (defaults to the
     *             path the original program used), {@code args[1]} is the
     *             depth limit (defaults to 2)
     * @throws IOException if the input file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String inputPath = "F:/数据挖掘--算法实现/cart算法/input.txt";
        int maxDepth = 2;
        if (args != null && args.length > 0) {
            inputPath = args[0];
        }
        if (args != null && args.length > 1) {
            maxDepth = Integer.parseInt(args[1]);
        }

        List<String> DataSet = new ArrayList<String>();
        String Var = "";
        BufferedReader br = new BufferedReader(new FileReader(inputPath));
        try {
            boolean headerSeen = false;
            String line;
            while ((line = br.readLine()) != null) {
                String row = line.trim();
                if (row.length() == 0) {
                    continue; // a blank line would otherwise become a record without fields
                }
                if (!headerSeen) {
                    headerSeen = true;
                    Var = row;
                    continue;
                }
                DataSet.add(row);
            }
        } finally {
            br.close(); // the reader was never closed before
        }

        Cart a = new Cart();
        a.Var = Var;
        a.Cart_tree(DataSet, "", 0, maxDepth);
    }
}

输入:

age income student credit_rating buys_computer
youth high no fair no
youth high no excellent no
middle_aged high no fair yes
senior medium no fair yes
senior low yes fair yes
senior low yes excellent no
middle_aged low yes excellent yes
youth medium no fair no
youth low yes fair yes
senior medium yes fair yes
youth medium yes excellent yes
middle_aged medium no excellent yes
middle_aged high yes fair yes
senior medium no excellent no

数据格式说明:第一行为变量名(表头),其余各行为用户数据;每行的各个字段之间以单个空格分隔。目标变量(本例中为buys_computer)必须位于最后一列,程序始终把每行的最后一个字段当作类别标签。


输出结果:

Rule:|age:middle_aged.   Count:4.   buys_computer:yes.   Accuracy_rate:1.0
Rule:|age:senior youth|student:yes.   Count:5.   buys_computer:yes.   Accuracy_rate:0.8
Rule:|age:senior youth|student:no.   Count:5.   buys_computer:no.   Accuracy_rate:0.8

0 0