应用 libsvm 对训练集进行训练，并在测试集上计算准确率（precision）和召回率（recall）

来源:互联网 发布:遗传算法应用领域 编辑:程序博客网 时间:2024/04/29 05:22

package org.lw.fenlei;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

public class Test {

 
 @SuppressWarnings("resource")
 public static void main(String[] args) throws Exception{
  // 定义训练集点a{10.0, 10.0} 和 点b{-10.0, -10.0},对应label为{1.0, -1.0}
 // svm_node pa0 = new svm_node();
 // pa0.index = 0;
 // pa0.value = 10.0;
 // svm_node pa1 = new svm_node();
 // pa1.index = -1;
 // pa1.value = 10.0;
 // svm_node pb0 = new svm_node();
 // pb0.index = 0;
 // pb0.value = -10.0;
 // svm_node pb1 = new svm_node();
 // pb1.index = 0;
 // pb1.value = -10.0;
 // svm_node[] pa = {pa0,pa1};//点a
 // svm_node[] pb = {pb0,pb1};//点b
 // svm_node[][] datas = {pa,pb};//训练集的向量表
 // double[] lables = {1.0,-1.0};//a,b对应的lable
  
  
  //找到字典的长度
  
  int ZDlength=0;
  File ZD = new File ("c:" + File.separator + "dm" + File.separator + "ZD.txt");
  BufferedReader ZDbuf = null;
  InputStream ZDinput = new FileInputStream(ZD);
  ZDbuf = new BufferedReader(new InputStreamReader(ZDinput));
  while(ZDbuf.readLine()!=null){
   ZDlength++;
  }
  System.out.println("字典长度为:"+ZDlength);
  
  
  
  
  //找到xunlian文件夹下有几个文件,有几个文件就是分几个类
  File file = new File ("c:" + File.separator + "dm" + File.separator + "xunlian1");
  String path[] = file.list();
  int lengths = path.length ;
  svm_node[][] datas = new svm_node[10000][ZDlength];
  for(int j=0; j<10000; j ++)
  {
    for(int i = 0; i < ZDlength; i ++) {  
     
     datas[j][i] =new svm_node();
     datas[j][i].value = 0.0;
     if(i!=(datas[j].length-1))
      datas[j][i].index = i+1;
     else
      datas[j][ZDlength-1].index=-1;
    }
  }
  double[] lables = new double[1000];
  int vectors=0;
  int lb = 0;
  double bq = 1.0;
  
 //分别读入每个文件中的向量
 for (int i = 0;i<lengths;i++){
  
  System.out.println(path[i]+" "+bq);
  
  
  BufferedReader buf = null;
  File f = new File ("c:" + File.separator + "dm" + File.separator + "xunlian1"+ File.separator + path[i]);
  InputStream input = new FileInputStream(f);
  buf = new BufferedReader(new InputStreamReader(input));
  String b;
  while((b = buf.readLine())!=null){
   
    vectors++;
    lables[lb] = bq;
    
    //System.out.println(b);
    int a;
    String[] temp = b.split(" ");
    if(temp.length==0)
     a =1;
    double[] result = new double[temp.length];
    int k;
    for(k=0; k<temp.length;k++){
     try{
     
     String vector[] = temp[k].split(",");
     
     int index = Integer.parseInt(vector[0]);
     Double value = Double.parseDouble(vector[1]);

     result[k] = value;
     
     
     if(index!=-1){
      datas[lb][index-1].value = result[k];
     // System.out.println("第"+lb+"篇文档第"+index+"个值为:"+datas[lb][index-1].value);
     }
        else{
         datas[lb][ZDlength-1].value = result[k];
         //System.out.println("第"+lb+"篇文档第"+index+"个值为:"+datas[lb][ZDlength-1].value);
        }
     }
     catch(Exception e){
     
     }
    }
    lb++;
   
  }
  bq++;
  buf.close();
  
 }
  
  //定义svm_problem对象
  svm_problem problem = new svm_problem();
  problem.l = vectors;//向量个数
  problem.x = datas;//训练集向量表
  problem.y = lables;//对应的lable数组
  
  //定义svm_parameter对象
  svm_parameter param = new svm_parameter();
  param.svm_type = svm_parameter.C_SVC;
  param.kernel_type = svm_parameter.LINEAR;
  param.cache_size = 10000;
  param.eps = 0.00001;
  param.C = 1;
  
  //训练SVM分类模型
  System.out.println(svm.svm_check_parameter(problem, param));//如果参数
  //没有问题,则该函数返回null,否则返回error描述。
  svm_model model = svm.svm_train(problem,param);//svm.svm_train()训练出SVM分类模型;
  
  //定义测试数据点c
  //svm_node pc0 = new svm_node();
  //pc0.index = 0;
  //pc0.value = -0.1;
  //svm_node pc1 = new svm_node();
  //pc1.index = 1;
  //pc1.value = -0.0;
  //svm_node pc2 = new svm_node();
  //pc2.index = 2;
  //pc2.value = -6.0;
  //svm_node pc3 = new svm_node();
  //pc3.index = 3;
  //pc3.value = -4.0;
  //svm_node pc4 = new svm_node();
  //pc4.index = 4;
  //pc4.value = -3.0;
  //svm_node pc5 = new svm_node();
  //pc1.index = -1;
  //pc1.value = 0.0;
  //svm_node[] pc = {pc0,pc1,pc2,pc3,pc4,pc5};
  
  
  //输入要分类的文本数据
  //File file1 = new File ("c:" + File.separator + "dm" + File.separator + "input.txt");
  //InputStream input1 = new FileInputStream(file1);
  //svm_node[] test = new svm_node[1000];
  //for(int n = 0;n<1000;n++)
   //test[n] = new svm_node();
  //BufferedReader buf1 = null;
  //buf1 = new BufferedReader(new InputStreamReader(input1));
  //String b1;
  //while((b1 = buf1.readLine())!=null){
   //String[] temp1 = b1.split(" ");
   //@SuppressWarnings("unused")
   //double[] result1 = new double[temp1.length];
   //for(int k1=0; k1<temp1.length;k1++){
    //result1[k1] = Double.parseDouble(temp1[k1]);
    
    //if(k1!=(temp1.length-1))
     //test[k1].index = k1;
       //else
     //test[k1].index = -1;
    
    //test[k1].value = result1[k1];
    
   //}
  //}
  
  
  //预测测试数据的lable
  //System.out.println(svm.svm_predict(model, test));
  
  //输入测试集,并计算准确率和召回率
  
  
  
  File file1 = new File ("c:" + File.separator + "dm" + File.separator + "ceshi");
  String path1[] = file1.list();
  int lengths1 = path1.length ;
  svm_node[][] test = new svm_node[10000][ZDlength];
  for(int j=0; j<1000; j ++)
  {
    for(int i = 0; i < ZDlength; i ++) {  
     
     test[j][i] =new svm_node();
     test[j][i].value = 0.0;
     if(i!=(ZDlength-1))
      test[j][i].index = i+1;
     else
      test[j][ZDlength-1].index = -1;
    }
  }
  
  
  int lb1=0;
  double cc=0.0,bq1=1.0,sumzql=0.0,sumzhl=0.0,zzql=0.0, zzhl=0.0;
  double[] zq = new double[lengths1];  //正确分得该类的篇数
  double[] zp = new double[lengths1];  //总分得该类的篇数
  double[] bp = new double[lengths1];  //测试集中该类本来有多少篇
  double[] zql = new double[lengths1];  //每个类的准确率
  double[] zhl = new double[lengths1];   //每个类的召回率
  for(int i=0;i<lengths1;i++){
   zp[i]=0.0;
  }
  
  for (int i = 0;i<lengths1;i++){
   
   bp[i]=0.0;
   zq[i]=0.0;
   
   System.out.println(path1[i]);
   
   
   BufferedReader buf1 = null;
   File f1 = new File ("c:" + File.separator + "dm" + File.separator + "ceshi"+ File.separator + path1[i]);
   InputStream input1 = new FileInputStream(f1);
   buf1 = new BufferedReader(new InputStreamReader(input1));
   
   String b1;
   while((b1 = buf1.readLine())!=null){
    String[] temp1 = b1.split(" ");
    @SuppressWarnings("unused")
    double[] result1 = new double[temp1.length];
    for(int k1=0; k1<temp1.length;k1++){
     
     String vector[] = temp1[k1].split(",");
     int index = Integer.parseInt(vector[0]);
     Double value = Double.parseDouble(vector[1]);

     result1[k1] = value;
     
     if(index!=-1)
      test[lb1][index-1].value = result1[k1];
        else
      test[lb1][ZDlength-1].value = result1[k1];
     
    }
    cc= svm.svm_predict(model, test[lb1]);
    System.out.println("第"+(lb1+1)+"篇文档所分得的的类别是"+cc);
    lb1++;
    
    if(cc==bq1){
     zq[i]++;
     zp[i]++;
    }
    else{
     zp[((int) cc)-1]++;
    }
    bp[i]++;
   }
   
   
   bq1++;
   
  }
  
  double j =1.0;
  for(int i=0;i<lengths1;i++){
   zql[i] = zq[i]/zp[i];
   System.out.println("第"+j+"类文档的准确率是"+zql[i]);
   zhl[i] = zq[i]/bp[i];
   System.out.println("第"+j+"类文档的召回率是"+zhl[i]);
   sumzql+= zql[i];
   sumzhl+=zhl[i];
   j=j+1.0;
  }
  zzql=sumzql/lengths1;
  System.out.println("总准确率是"+zzql);
  zzhl=sumzhl/lengths1;
  System.out.println("总召回率是"+zzhl);
 }

}