Weka开发 -OneR源代码介绍

来源:互联网 发布:加泰罗尼亚理工知乎 编辑:程序博客网 时间:2024/06/04 19:49

        OneR是一个很简单的算法,出自论文《Very simple classification rules perform well on most commonly used datasets》,由于论文的风格过于奔放,并且很长,所以我也就没怎么看。基本思想就是对每一个属性都建一个单层的分类器(即只根据这一个属性的取值来预测类别的规则),然后对这些分类器进行比较,谁在训练集上分类正确的样本数最多(代码中的m_correct),谁就作为最终的分类器。


       public void buildClassifier(Instances instances) throws Exception


        boolean noRule = true;


        // only class? -> build ZeroR model

        if (data.numAttributes() == 1)


            m_ZeroR = new weka.classifiers.rules.ZeroR();



        } else


            m_ZeroR = null;



        // for each attribute ...

        Enumeration enu = instances.enumerateAttributes();

        while (enu.hasMoreElements())




                OneRRule r = newRule((Attribute) enu.nextElement(), data);


                // if this attribute is the best so far, replace the rule

                if (noRule || r.m_correct > m_rule.m_correct)


                    m_rule = r;


                noRule = false;

            } catch (Exception ex)





    下面看一下刚才的newRule函数。它先初始化一个missingValueCounts数组,数组大小为类别集合的大小。如果当前这个属性是离散的就调用newNominalRule,如果是连续的就调用newNumericRule。下面的几行代码现在可能还有点难理解(理解不了的话,看完下面的再转回来看):missingValueCounts保存的是该属性值缺失的样本在各个类别上的计数,而maxIndex函数返回的就是这个属性缺失时最常出现的类别的Index。再下来的if判断训练集中是否存在这个属性值缺失的样本:如果没有,那么r.m_missingValueClass = -1;如果有,r.m_correct就加上这个属性缺失的情况下出现次数最多的那个类别的出现次数(没办法,就是这么难表达)。

public OneRRule newRule(Attribute attr, Instances data) throws Exception



        OneRRule r;


        // ... create array to hold the missing value counts

        int[] missingValueCounts = new int[data.classAttribute().numValues()];


        if (attr.isNominal())


            r = newNominalRule(attr, data, missingValueCounts);

        } else


            r = newNumericRule(attr, data, missingValueCounts);


        r.m_missingValueClass = Utils.maxIndex(missingValueCounts);

        if (missingValueCounts[r.m_missingValueClass] == 0)


            r.m_missingValueClass = -1; // signal for no missing value class

        } else


            r.m_correct += missingValueCounts[r.m_missingValueClass];


        return r;



public OneRRule newNominalRule(Attribute attr, Instances data,

            int[] missingValueCounts) throws Exception



        // ... create arrays to hold the counts

        int[][] counts = new int[attr.numValues()][data.classAttribute().numValues()];


        // ... calculate the counts

        Enumeration enu = data.enumerateInstances();

        while (enu.hasMoreElements())


            Instance i = (Instance) enu.nextElement();

            if (i.isMissing(attr))


                missingValueCounts[(int) i.classValue()]++;

            } else


                counts[(int) i.value(attr)][(int) i.classValue()]++;




        OneRRule r = new OneRRule(data, attr); // create a new rule

        for (int value = 0; value < attr.numValues(); value++)


            int best = Utils.maxIndex(counts[value]);

            r.m_classifications[value] = best;

            r.m_correct += counts[value][best];


        return r;






public OneRRule newNumericRule(Attribute attr, Instances data,

            int[] missingValueCounts) throws Exception



        // ... can't be more than numInstances buckets

        int[] classifications = new int[data.numInstances()];

        double[] breakpoints = new double[data.numInstances()];


        // create array to hold the counts

        int[] counts = new int[data.classAttribute().numValues()];

        int correct = 0;

        int lastInstance = data.numInstances();


        // missing values get sorted to the end of the instances


        while (lastInstance > 0 && data.instance(lastInstance - 1).isMissing(attr))



            missingValueCounts[(int) data.instance(lastInstance).classValue()]++;


        int i = 0;

        int cl = 0; // index of next bucket to create

        int it;

        while (i < lastInstance)

        { // start a new bucket

            for (int j = 0; j < counts.length; j++)

                counts[j] = 0;


            { // fill it until it has enough of the majority class

                it = (int) data.instance(i++).classValue();


            } while (counts[it] < m_minBucketSize && i < lastInstance);


            // while class remains the same, keep on filling

            while (i < lastInstance && (int) data.instance(i).classValue() == it)





            while (i < lastInstance && // keep on while attr value is the same

(data.instance(i - 1).value(attr) == data.instance(i).value(attr)))


                counts[(int) data.instance(i++).classValue()]++;


            for (int j = 0; j < counts.length; j++)


                if (counts[j] > counts[it])


                    it = j;



            if (cl > 0)

            { // can we coalesce with previous class?

                if (counts[classifications[cl - 1]] == counts[it])


                    it = classifications[cl - 1];


                if (it == classifications[cl - 1])


                    cl--; // yes!



            correct += counts[it];

            classifications[cl] = it;

            if (i < lastInstance)


breakpoints[cl] = (data.instance(i - 1).value(attr) + data.instance(i).value(attr)) / 2;




        if (cl == 0)


            throw new Exception("Only missing values in the training data!");


        OneRRule r = new OneRRule(data, attr, cl); // new rule with cl branches

        r.m_correct = correct;

        for (int v = 0; v < cl; v++)


            r.m_classifications[v] = classifications[v];

            if (v < cl - 1)


                r.m_breakpoints[v] = breakpoints[v];




        return r;






public double classifyInstance(Instance inst) throws Exception



        // default model?

        if (m_ZeroR != null)


            return m_ZeroR.classifyInstance(inst);



        int v = 0;

        if (inst.isMissing(m_rule.m_attr))


            if (m_rule.m_missingValueClass != -1)


                return m_rule.m_missingValueClass;

            } else


                return 0; // missing values occur in test but not training set   



        if (m_rule.m_attr.isNominal())


            v = (int) inst.value(m_rule.m_attr);

        } else


            while (v < m_rule.m_breakpoints.length

                    && inst.value(m_rule.m_attr) >= m_rule.m_breakpoints[v])





        return m_rule.m_classifications[v];

