weka[1] - ID3算法

来源:互联网 发布:淘宝新手如何提高销量 编辑:程序博客网 时间:2024/05/16 05:28

我们知道ID3是一个最基本的决策树算法。他主要是每次根据InfoGain来选取特征进行分裂,并且没有进行剪枝。

buildClassifier:

  public void buildClassifier(Instances data) throws Exception {    // can classifier handle the data?    getCapabilities().testWithFail(data);    // remove instances with missing class    data = new Instances(data);    data.deleteWithMissingClass();        //递归构造决策树    makeTree(data);  }

这里没有什么好写的,只需看最后一行,makeTree这个函数。

makeTree:

    // Check if no instances have reached this node.    // 如果该树为空,则返回    if (data.numInstances() == 0) {      m_Attribute = null;      m_ClassValue = Instance.missingValue();      m_Distribution = new double[data.numClasses()];      return;    }

检验此时的树是否已经为空,如果为空,说明递归完成。上一个分裂节点,就是叶子。

 // Compute attribute with maximum information gain.    double[] infoGains = new double[data.numAttributes()];    Enumeration attEnum = data.enumerateAttributes();    while (attEnum.hasMoreElements()) {      Attribute att = (Attribute) attEnum.nextElement();      infoGains[att.index()] = computeInfoGain(data, att);    }    m_Attribute = data.attribute(Utils.maxIndex(infoGains));

这块就是计算每个属性的InfoGain,选出对应最大那个作为分裂属性。简单易懂的代码!(待会看看computeInfoGain这个函数)

// Make leaf if information gain is zero.     // Otherwise create successors.    if (Utils.eq(infoGains[m_Attribute.index()], 0)) {      m_Attribute = null;      m_Distribution = new double[data.numClasses()];      Enumeration instEnum = data.enumerateInstances();      while (instEnum.hasMoreElements()) {        Instance inst = (Instance) instEnum.nextElement();        m_Distribution[(int) inst.classValue()]++;      }      Utils.normalize(m_Distribution);      m_ClassValue = Utils.maxIndex(m_Distribution);      m_ClassAttribute = data.classAttribute();    } else {      Instances[] splitData = splitData(data, m_Attribute);      m_Successors = new Id3[m_Attribute.numValues()];      for (int j = 0; j < m_Attribute.numValues(); j++) {        m_Successors[j] = new Id3();        m_Successors[j].makeTree(splitData[j]);      }    }

第一个判断就是问此时InfoGain是否为0,如果InfoGain=0,那么意味着这个时候,该节点已经是叶子(因为全部样本属于同一个class了!)。

那么,开始计算m_Distribution,其实这个m_Distribution没啥太大用处,因为这个子树的样本肯定属于同一类,其他类全是0.

如果InfoGain!=0,意味着还需要继续分类。那么,我们已经知道要分类的属性了,接下来只要根据该属性,将原来的数据分成几个部分(该属性有几种取值,就分成几个),然后再递归地调用makeTree即可。用m_Successors存储所有子树。 

computeInfoGain:

private double computeInfoGain(Instances data, Attribute att)     throws Exception {    double infoGain = computeEntropy(data);    //若att有k种取值,则分成k个部分    Instances[] splitData = splitData(data, att);    for (int j = 0; j < att.numValues(); j++) {      if (splitData[j].numInstances() > 0) {        infoGain -= ((double) splitData[j].numInstances() /                     (double) data.numInstances()) *          computeEntropy(splitData[j]);      }    }    return infoGain;  }

这个也很容易看懂,只要知道infoGain的计算公式:

H(D)就是该属性的熵(这个就不说了)

computeEntropy:

<span style="font-size:14px;">private double computeEntropy(Instances data) throws Exception {    //统计每个类各有多少样本    double [] classCounts = new double[data.numClasses()];    Enumeration instEnum = data.enumerateInstances();    while (instEnum.hasMoreElements()) {      Instance inst = (Instance) instEnum.nextElement();      classCounts[(int) inst.classValue()]++;    }    double entropy = 0;    for (int j = 0; j < data.numClasses(); j++) {      //classCounts等于0,那么这部分pi*log(pi)=0      if (classCounts[j] > 0) {        entropy -= classCounts[j] * Utils.log2(classCounts[j]);      }    }    //之前if里没有包含分母,这里除以原来公式中的分母    entropy /= (double) data.numInstances();    return entropy + Utils.log2(data.numInstances());  }</span>

都在注释里了。

splitData:

private Instances[] splitData(Instances data, Attribute att) {    Instances[] splitData = new Instances[att.numValues()];    for (int j = 0; j < att.numValues(); j++) {      //初始化,把数据信息给子树,这里不是复制data给splitData!      splitData[j] = new Instances(data, data.numInstances());    }    Enumeration instEnum = data.enumerateInstances();    while (instEnum.hasMoreElements()) {      Instance inst = (Instance) instEnum.nextElement();      //inst.value(att)返回的是inst对应该属性的值      splitData[(int) inst.value(att)].add(inst);    }    for (int i = 0; i < splitData.length; i++) {      splitData[i].compactify();    }    return splitData;  }

这里基本也是挺直观的,建立一个Instances数组,然后每个坑存放一个子集。这里这个inst.value(att)有点不理解的地方,也就是说,他已经把每个属性的值转换到0-k了。

那个compactify()就是把信息改下,使得数据和info对应起来。

weka的ID3基本就是这些函数了,然后我有个最大的感觉就是他处理的数据形式有限。目前还没找到,如何处理numeric的代码~~~奇怪奇怪!!!


0 0
原创粉丝点击