Weka 学习 ID3

来源:互联网 发布:味蕾之诗淘宝不卖了吗 编辑:程序博客网 时间:2024/05/17 04:21

ID3算法相对简单,weka的实现也容易理解。首先介绍一下大致算法。算法概述如下。

1.选择一种度量(ID3选择的是信息增益),计算每个属性对于该度量的值。

2.根据结果选择一个属性进行分支。

3.如果每个分支全部属于一个类或者已经没有候选属性。则停止,否则对每个分支进行1,2操作。

下面对weka的ID3 class 作介绍,主要涉及到makeTree(Instances data),computeInfoGain(data, att),splitData(Instances data, Attribute att)三个函数。其中makeTree是入口函数,computeInfoGain的作用是计算信息增益,splitData的作用是分支。首先看makeTree函数。

private void makeTree(Instances data) throws Exception {    // Check if no instances have reached this node.    if (data.numInstances() == 0) {      m_Attribute = null;      m_ClassValue = Instance.missingValue();      m_Distribution = new double[data.numClasses()];      return;    }    // Compute attribute with maximum information gain.    double[] infoGains = new double[data.numAttributes()];    Enumeration attEnum = data.enumerateAttributes();    /**     * 对每个属性计算信息增益     */    while (attEnum.hasMoreElements()) {      Attribute att = (Attribute) attEnum.nextElement();      infoGains[att.index()] = computeInfoGain(data, att);    }    m_Attribute = data.attribute(Utils.maxIndex(infoGains));        // Make leaf if information gain is zero.     // Otherwise create successors.    if (Utils.eq(infoGains[m_Attribute.index()], 0)) {      m_Attribute = null;      m_Distribution = new double[data.numClasses()];      Enumeration instEnum = data.enumerateInstances();      while (instEnum.hasMoreElements()) {        Instance inst = (Instance) instEnum.nextElement();        m_Distribution[(int) inst.classValue()]++;      }      Utils.normalize(m_Distribution);      m_ClassValue = Utils.maxIndex(m_Distribution);      m_ClassAttribute = data.classAttribute();    } else {      Instances[] splitData = splitData(data, m_Attribute);      m_Successors = new Id3[m_Attribute.numValues()];      /**       * 这里对每个分支继续调用id3.makeTree(instatnces)。       */      for (int j = 0; j < m_Attribute.numValues(); j++) {        m_Successors[j] = new Id3();        m_Successors[j].makeTree(splitData[j]);      }    }  }


通过注释,应该不难理解大致过程。这里需要注意的是 程序里经常会出现Enumeration,这其实就是现在的Ieratorer,当时jdk版本较低,所以用的Enumeration,忽视掉就好了。

下面是splitData。只是按照类的值进行分支,也很容易理解。

  private double computeInfoGain(Instances data, Attribute att)     throws Exception {    double infoGain = computeEntropy(data);    Instances[] splitData = splitData(data, att);    for (int j = 0; j < att.numValues(); j++) {      if (splitData[j].numInstances() > 0) {        infoGain -= ((double) splitData[j].numInstances() /                     (double) data.numInstances()) *          computeEntropy(splitData[j]);      }    }    return infoGain;  }


至于 computeInfoGain对照公式就很容易理解了。这里只贴出代码

  private double computeInfoGain(Instances data, Attribute att)     throws Exception {    double infoGain = computeEntropy(data);    Instances[] splitData = splitData(data, att);    for (int j = 0; j < att.numValues(); j++) {      if (splitData[j].numInstances() > 0) {        infoGain -= ((double) splitData[j].numInstances() /                     (double) data.numInstances()) *          computeEntropy(splitData[j]);      }    }    return infoGain;  }


原创粉丝点击