Weka 交叉验证与 ROC 曲线

来源:互联网 发布:淘宝订单能删除吗 编辑:程序博客网 时间:2024/06/02 04:15

调用 Weka 对朴素贝叶斯分类器进行 10 折交叉验证,输出 ROC 曲线下面积及各阈值下的真正例率(TP)和假正例率(FP),并通过图形界面显示 ROC 曲线。

import weka.classifiers.Classifier;import weka.classifiers.Evaluation;import weka.classifiers.bayes.NaiveBayes;import weka.classifiers.evaluation.ThresholdCurve;import weka.core.Instances;import weka.core.Utils;import weka.gui.visualize.Plot2D;import weka.gui.visualize.PlotData2D;import weka.gui.visualize.ThresholdVisualizePanel;import javax.swing.*;import java.awt.*;import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.util.Arrays;import java.util.Random;public class ROC {    public static void main(String[] args) throws Exception {        String filePath = "d:/data/segment-challenge.arff";        BufferedReader buf = new BufferedReader(new FileReader(new File(filePath)));        Instances instance = new Instances(buf);        //获取属性        instance.setClassIndex(instance.numAttributes() - 1);        //获取分类器        Classifier classifier = new NaiveBayes();        //评价分类模型        Evaluation evaluation = new Evaluation(instance);        //交叉验证  分类器,数据实例,交叉数目,        evaluation.crossValidateModel(classifier, instance, 10, new Random(1));        ThresholdCurve tc = new ThresholdCurve();        //classIndex是类作为positive的索引        int classIndex = 0;        Instances instances = tc.getCurve(evaluation.predictions(), classIndex);        System.out.println("Roc curve" + evaluation.areaUnderROC(classIndex));        //获取TP,Fp        int tpIndex = instances.attribute(ThresholdCurve.TP_RATE_NAME).index();        int fpIndex = instances.attribute(ThresholdCurve.TP_RATE_NAME).index();        double[] tpRate = instances.attributeToDoubleArray(tpIndex);        double[] fpRate = instances.attributeToDoubleArray(fpIndex);//      System.out.println(Arrays.toString(tpRate)+Arrays.toString(fpRate));        for (double tp : tpRate) {            System.out.println(tp);        }        for (double fp : fpRate) {            System.out.println(fp);        }        //使用instances对象显示ROC曲面        ThresholdVisualizePanel tvp = new ThresholdVisualizePanel();     
   tvp.setROCString("(Area under ROC=" +                Utils.doubleToString(tc.getROCArea(instances), 4) + ")");        tvp.setName(instances.relationName());        PlotData2D pd = new PlotData2D(instances);        tvp.addPlot(pd);        String plotName = tvp.getName();        final javax.swing.JFrame jf = new javax.swing.JFrame(                "WeKa classifier visualize:" + plotName);        jf.setSize(500, 400);        jf.getContentPane().setLayout(new BorderLayout());        jf.getContentPane().add(tvp, BorderLayout.CENTER);        jf.addWindowFocusListener(new java.awt.event.WindowAdapter() {            public void windowClosing(java.awt.event.WindowEvent e) {                jf.dispose();            }        });        jf.setVisible(true);    }}

运行结果
(此处原为运行结果截图,图片已丢失)

0 0
原创粉丝点击