机器学习之 Logistic 回归算法的 Java 实现

来源:互联网 发布:js设置p标签的值 编辑:程序博客网 时间:2024/05/17 04:29
package logistc;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;

/**
 * Command-line entry point: loads a training set and a test set from comma-separated text
 * files, fits a {@link Logistic} model and prints its predictions for the test set.
 *
 * <p>Every line of either file is a comma-separated list of numbers. Training rows get a
 * constant {@code 1.0} prepended before being handed to the model; test rows are passed
 * through unchanged.
 */
public class Mian {

    /** Training file used when no path is given on the command line. */
    private static final String DEFAULT_TRAIN_PATH = "C:\\Users\\zfw\\Desktop\\java项目\\datas.txt";

    /** Test file used when fewer than two paths are given on the command line. */
    private static final String DEFAULT_TEST_PATH = "C:\\Users\\zfw\\Desktop\\java项目\\test.txt";

    /**
     * Reads both data files, trains the model and prints the predictions.
     *
     * @param args optional: {@code args[0]} is the training file, {@code args[1]} the test
     *     file; missing entries fall back to the built-in default paths
     * @throws IOException kept for signature compatibility; read failures are reported on
     *     standard output and end the run instead of propagating
     * @throws NumberFormatException if a field in either file is not a valid number
     */
    public static void main(String[] args) throws IOException {
        String trainPath = args.length > 0 ? args[0] : DEFAULT_TRAIN_PATH;
        String testPath = args.length > 1 ? args[1] : DEFAULT_TEST_PATH;

        ArrayList<ArrayList<Double>> datas;
        ArrayList<ArrayList<Double>> test;
        try {
            // Training rows carry a leading 1.0 so the model can learn an intercept.
            datas = readRows(trainPath, true);
            test = readRows(testPath, false);
        } catch (IOException ioe) {
            System.out.println("错误!" + ioe);
            // Without data there is nothing to train on; stop instead of crashing later.
            return;
        }

        if (datas.isEmpty()) {
            System.out.println("错误!" + "no training rows found in " + trainPath);
            return;
        }

        Logistic l = new Logistic(datas, test);
        l.print();
        l.predect(test);
    }

    /**
     * Parses a UTF-8 text file of comma-separated numbers into rows of doubles.
     *
     * @param path file to read
     * @param prependBias when {@code true}, each row starts with a constant {@code 1.0}
     * @return one list per non-blank line, in file order
     * @throws IOException if the file cannot be opened or read
     */
    private static ArrayList<ArrayList<Double>> readRows(String path, boolean prependBias)
            throws IOException {
        ArrayList<ArrayList<Double>> rows = new ArrayList<ArrayList<Double>>();
        // try-with-resources closes the reader (and the underlying stream) on every exit path.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(path), StandardCharsets.UTF_8))) {
            boolean firstLine = true;
            String line;
            while ((line = reader.readLine()) != null) {
                if (firstLine) {
                    firstLine = false;
                    // Some Windows editors write a byte-order mark that would break parsing.
                    if (line.startsWith("\uFEFF")) {
                        line = line.substring(1);
                    }
                }
                if (line.trim().isEmpty()) {
                    continue; // tolerate blank lines, e.g. a trailing newline at end of file
                }
                String[] fields = line.split(",");
                ArrayList<Double> row = new ArrayList<Double>(fields.length + 1);
                if (prependBias) {
                    row.add(1.0);
                }
                for (String field : fields) {
                    row.add(Double.parseDouble(field));
                }
                rows.add(row);
            }
        }
        return rows;
    }
}


package logistc;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Binary logistic-regression classifier fitted by batch gradient ascent on the log-likelihood.
 *
 * <p>Training rows are laid out as {@code [1.0, x1, ..., xn, label]}: a constant bias term,
 * the feature values and a trailing 0/1 label. Rows passed to {@link #predect} and
 * {@link #classify} are laid out as {@code [x1, ..., xn, label]} (no leading bias term).
 * The weight vector has {@code n + 1} entries; entry 0 is the intercept.
 */
public class Logistic {

    /** Number of gradient steps performed by {@link #print()}. */
    private static final int PRINT_ITERATIONS = 1000000;

    /** Training set. */
    private ArrayList<ArrayList<Double>> datas = new ArrayList<ArrayList<Double>>();

    /** Learning rate applied to every gradient step. */
    private double alph = 0.001;

    /** Weight vector; index 0 is the intercept. */
    private double[] b;

    /**
     * Creates a model over the given training set with all weights initialised to 1.0.
     *
     * @param datas training rows, each {@code [1.0, x1, ..., xn, label]}
     * @param test accepted for signature compatibility; not used during construction
     * @throws IllegalArgumentException if {@code datas} is null, empty, or its first row is
     *     too short to hold a bias term and a label
     */
    public Logistic(ArrayList<ArrayList<Double>> datas, ArrayList<ArrayList<Double>> test) {
        this.datas = datas;
        resetWeights();
    }

    /**
     * Resets every weight to 1.0, sized from the training set held by this instance.
     *
     * @param datas retained for signature compatibility; the argument itself is not consulted
     * @throws IllegalArgumentException if the training set is null, empty, or its first row is
     *     too short to hold a bias term and a label
     */
    public void init(ArrayList<ArrayList<Double>> datas) {
        resetWeights();
    }

    /** Allocates the weight vector (one weight per non-label column) and fills it with 1.0. */
    private void resetWeights() {
        if (this.datas == null || this.datas.isEmpty()) {
            throw new IllegalArgumentException("training set must contain at least one row");
        }
        int width = this.datas.get(0).size();
        if (width < 2) {
            throw new IllegalArgumentException(
                    "training rows need at least a bias term and a label, got width " + width);
        }
        b = new double[width - 1];
        System.out.println(b.length);
        Arrays.fill(b, 1.0);
    }

    /** Logistic function {@code 1 / (1 + e^-z)}. */
    private static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * Evaluates the hypothesis for one training row.
     *
     * @param j index of the training row
     * @return the modelled probability that row {@code j} has label 1
     */
    public double h_theta_x_i(int j) {
        List<Double> row = this.datas.get(j);
        double z = 0.0;
        // Column 0 of a training row is the bias term, so b[0] acts as the intercept.
        for (int i = 0; i < b.length; i++) {
            z += b[i] * row.get(i);
        }
        return sigmoid(z);
    }

    /**
     * Computes the partial derivative of the log-likelihood with respect to one weight.
     *
     * @param j index of the weight (0 is the intercept)
     * @return sum over all training rows of {@code (label - hypothesis) * x_j}
     */
    public double compute_partial_derivative_for_theta(int j) {
        int labelIndex = b.length;
        double sum = 0.0;
        for (int i = 0; i < this.datas.size(); i++) {
            List<Double> row = this.datas.get(i);
            sum += (row.get(labelIndex) - h_theta_x_i(i)) * row.get(j);
        }
        return sum;
    }

    /**
     * Performs one gradient-ascent step on every weight, intercept included.
     *
     * <p>All partial derivatives are evaluated against the weights as they were before the
     * step, so the update is simultaneous rather than weight-by-weight.
     */
    public void compute_theta() {
        int labelIndex = b.length;
        double[] gradient = new double[b.length];
        for (int i = 0; i < this.datas.size(); i++) {
            List<Double> row = this.datas.get(i);
            // The residual is shared by every weight, so evaluate the hypothesis once per row.
            double residual = row.get(labelIndex) - h_theta_x_i(i);
            for (int j = 0; j < gradient.length; j++) {
                gradient[j] += residual * row.get(j);
            }
        }
        for (int j = 0; j < b.length; j++) {
            b[j] += this.alph * gradient[j];
        }
    }

    /**
     * Runs the given number of gradient steps without printing anything.
     *
     * @param iterations number of steps; must not be negative
     * @throws IllegalArgumentException if {@code iterations} is negative
     */
    public void train(int iterations) {
        if (iterations < 0) {
            throw new IllegalArgumentException("iterations must be >= 0, got " + iterations);
        }
        for (int step = 0; step < iterations; step++) {
            compute_theta();
        }
    }

    /**
     * Trains for one million gradient steps, printing the weight vector after each step.
     *
     * <p>Each output line is the number of remaining steps, the text {@code theta:}, and the
     * weights, each followed by a tab.
     */
    public void print() {
        for (int remaining = PRINT_ITERATIONS - 1; remaining >= 0; remaining--) {
            compute_theta();
            StringBuilder line = new StringBuilder();
            line.append(remaining).append("theta:");
            for (double weight : b) {
                line.append(weight).append("\t");
            }
            System.out.println(line);
        }
    }

    /**
     * Classifies one row laid out as {@code [x1, ..., xn, label]}; the label is not read.
     *
     * @param row feature values followed by a label slot
     * @return 1 if the modelled probability exceeds 0.5, otherwise 0
     * @throws IllegalArgumentException if the row length does not match the model
     */
    public int classify(List<Double> row) {
        if (row.size() != b.length) {
            throw new IllegalArgumentException(
                    "expected " + (b.length - 1) + " features plus a label, got row of size "
                            + row.size());
        }
        // Test rows carry no bias column, so the intercept is added explicitly and the
        // feature weights are shifted by one.
        double z = b[0];
        for (int k = 1; k < b.length; k++) {
            z += b[k] * row.get(k - 1);
        }
        return sigmoid(z) > 0.5 ? 1 : 0;
    }

    /**
     * Prints the predicted class (0 or 1) for every test row, followed by the percentage of
     * rows whose last element equals the prediction.
     *
     * <p>An empty test set prints an accuracy of {@code NaN}.
     *
     * @param test rows laid out as {@code [x1, ..., xn, label]}
     * @throws IllegalArgumentException if a row length does not match the model
     */
    public void predect(ArrayList<ArrayList<Double>> test) {
        int count = 0;
        for (ArrayList<Double> row : test) {
            // Every row is scored from scratch; nothing carries over between samples.
            int predicted = classify(row);
            System.out.print(predicted);
            double actual = row.get(row.size() - 1);
            if (actual == predicted) {
                count++;
            }
        }
        System.out.println("正确率为:" + (double) count / test.size() * 100 + "%");
    }

    /**
     * Returns a copy of the current weight vector.
     *
     * @return the weights, intercept first
     */
    public double[] getTheta() {
        return Arrays.copyOf(b, b.length);
    }
}


原创粉丝点击