marchine learning 之交叉验证

来源:互联网 发布:淘宝客佣金在哪里设置 编辑:程序博客网 时间:2024/06/06 07:28

评价指标
这里写图片描述

public class PerformanceMeasure {    /**     * 1、 FN:False Negative,被判定为负样本,但事实上是正样本。     * 2、 FP:False Positive,被判定为正样本,但事实上是负样本。     * 3、TN:True Negative,被判定为负样本,事实上也是负样本。     * 4、TP:True Positive,被判定为正样本,事实上也是正样本。     * 5、precesion:查准率     * 即在检索后返回的结果中,真正正确的个数占整个结果的比例。     * precesion = TP/(TP+FP) 。     * 6、 recall:查全率     * 即在检索结果中真正正确的个数 占整个数据集(检索到的和未检索到的)中真正正确个数的比例。     * recall = TP/(TP+FN)即,检索结果中,你判断为正的样本也确实为正的,     * 以及那些没在检索结果中被你判断为负但是事实上是正的(FN)。     * 7、F-Measure     * 是Precision和Recall加权调和平均     * P和R指标有时候会出现的矛盾的情况,这样就需要综合考虑他们,最常见的方法就是F-Measure(又称为F-Score)。     * MCC马修斯相关系数     * 衡量非平衡数据集的指标     * MCC = (TP*TN - FP*FN)/((TP+FP)*(Tp+FN)*(TN+FP)*(TN+FN))^0.5     */    public double tp;    public double fp;    public double tn;    public double fn;    /**     * 计算马修斯相关系数     *     * @return     */    public double getCorrelationCoefficient() {        return (tp * fn - fp * fn) / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));    }    public PerformanceMeasure(double tp, double tn, double fp, double fn) {        this.tp = tp;        this.tn = tn;        this.fp = fp;        this.tn = tn;    }    public PerformanceMeasure() {        this(0, 0, 0, 0);    }    public double getTPRate() {        return this.tp / (this.tp + this.fn);    }    public double getTNRate() {        return this.tn / (this.tn + this.fp);    }    public double getFNRate() {        return this.fn / (this.tp + this.fn);    }    public double getFPRate() {        return this.fp / (this.fp + this.tn);    }    public double getErrorRate() {        return (this.fp + this.fn) / this.getTotal();    }    public double getAccuracy() {        return (this.tp + this.tn) / this.getTotal();    }    public double getRecall() {        return this.tp / (this.tp + this.fn);    }    public double getPrecision() {        return this.tp / (this.tp + this.fp);    }    public double getCost() {        return fp / tp;    }    public double getTotal() {        return fp + fn + tp + tn;    }    /**     * 返回F-Measure     *     * @return     */    public double getFMeasure(){        double fMeasure = this.getRecall() * this.getPrecision()                / 2*(this.getRecall() + this.getPrecision());        if (Double.isNaN(fMeasure))            return 0;        else            return fMeasure;    }    @Override    public String toString() {        return "[TP=" + this.tp + ", FP=" + this.fp + ", TN=" + this.tn + ", FN=" + this.fn + "]";    }}

交叉验证

import com.javaPractice.MachineLearning.classification.Classifier;import com.javaPractice.MachineLearning.core.DataSet;import com.javaPractice.MachineLearning.core.DefaultDataset;import com.javaPractice.MachineLearning.core.Instance;import java.util.HashMap;import java.util.Map;import java.util.Random;/** * 交叉验证 */public class CrossValidation {    public Classifier classifier;    public CrossValidation(Classifier classifier) {        this.classifier = classifier;    }    /**     * 交叉验证     *     * @param dataSet     * @param numFolds     * @param random     * @return     */    public Map<Object, PerformanceMeasure> crossValidation(DataSet dataSet, int numFolds, Random random) {        DataSet[] folds = dataSet.folds(numFolds, random);        Map<Object, PerformanceMeasure> out = new HashMap<Object, PerformanceMeasure>();        for (Object obj : folds) {            //分类器的表现性            out.put(obj, new PerformanceMeasure());        }        for (int i = 0; i < numFolds; i++) {            DataSet validation = folds[i];            DataSet trainData = new DefaultDataset();            for (int j = 0; j < numFolds; j++) {                if (i != j) {                    trainData.addAll(folds[i]);                }            }            //建立分类器            classifier.buildClassifier(trainData);            /**             * 1.预测类别 == 实例类别(预测正确)             *  预测为正 实际为正   TP             *  预测为负 实际为负   TN             * 2.预测列表 != 实例类别(预测失败)             *  预测为正 实际为负   FP             *  预测为负 实际为正   FN             */            for (Instance instance : validation) {                //对分类实例进行预测 类别名称                Object prediction = classifier.classify(instance);                //实例的类别==预测类别                if (instance.classValue().equals(prediction)) {//prediction == class                    for (Object o : out.keySet()) {    //类别                        //如果预测数据的类别和该实例的类别相等                        if (o.equals(instance.classValue())) {                            //tp++正类预测正确                            out.get(o).tp++;                        } else {//负类预测正确                            out.get(o).tn++;                        }                    }                } else {//prediction != class                    for (Object o : out.keySet()) {                        //prediction is positive class                        //预测为正,实际为负                        if (prediction.equals(o)) {                            out.get(0).fp++;                        } else if (o.equals(instance.classValue())) {                            out.get(o).fn++;                        } else {//没有正类                            out.get(o).tn++;                        }                    }                }            }        }        return out;    }}
0 0