Mathematical Derivation of Logistic Regression and a Java Implementation


Logistic Regression

1.1 The concept of logistic regression

[Figures from the original post, introducing the concept of logistic regression, are not reproduced here.]

1.2 The mathematical formulation of logistic regression

[The formula images from the original post are not reproduced here.]
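
As a sketch of what those figures presumably contained, the standard formulation that the Java code in section 1.3 implements is the following (notation assumed here: weight vector w, feature vector x with a constant bias component, label y in {0, 1}):

    \sigma(z) = \frac{1}{1 + e^{-z}}, \qquad h_w(x) = \sigma(w^{T}x) = \frac{1}{1 + e^{-w^{T}x}}

    P(y = 1 \mid x; w) = h_w(x), \qquad P(y = 0 \mid x; w) = 1 - h_w(x)

    \ell(w) = \sum_{i=1}^{m} \left[ y^{(i)} \log h_w(x^{(i)}) + \left(1 - y^{(i)}\right) \log\left(1 - h_w(x^{(i)})\right) \right]

    \frac{\partial \ell(w)}{\partial w_k} = \sum_{i=1}^{m} \left( y^{(i)} - h_w(x^{(i)}) \right) x_k^{(i)}

    w_k \leftarrow w_k + \alpha \left( y^{(i)} - h_w(x^{(i)}) \right) x_k^{(i)}

The last line is the per-sample stochastic gradient ascent update that gradAscent1 below performs, with the step size alpha decayed towards (but never reaching) zero over the passes.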

1.3 Java implementation of logistic regression

// Matrix.java: a thin wrapper around a 2-D list of strings that holds the feature rows.
import java.util.ArrayList;

public class Matrix {
    public ArrayList<ArrayList<String>> data;

    public Matrix() {
        data = new ArrayList<ArrayList<String>>();
    }
}

// CreateDataSet.java: a Matrix plus the class label of each row.
import java.util.ArrayList;

public class CreateDataSet extends Matrix {
    public ArrayList<String> lables;

    public CreateDataSet() {
        super();
        lables = new ArrayList<String>();
    }

    public void initTest() {
    }
}

// Logistic.java: sigmoid, classification, improved stochastic gradient ascent, file loading.
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;

public class Logistic {

    public static void main(String[] args) {
        colicTest();
    }

    /**
     * @author haolidong
     * @Description: [simple test of logistic regression]
     */
    public static void LogisticTest() {
        CreateDataSet dataSet = readFile("I:\\machinelearninginaction\\Ch05\\testSet.txt");
        ArrayList<Double> weights = gradAscent1(dataSet, dataSet.lables, 150);
        for (int i = 0; i < 3; i++) {
            System.out.println(weights.get(i));
        }
        System.out.println();
    }

    /**
     * @param inX     one data row (as strings)
     * @param weights trained weight vector
     * @return "1" or "0"
     * @author haolidong
     * @Description: [sigmoid classification]
     */
    public static String classifyVector(ArrayList<String> inX, ArrayList<Double> weights) {
        ArrayList<Double> sum = new ArrayList<Double>();
        sum.add(0.0);
        // weighted sum w . x
        for (int i = 0; i < inX.size(); i++) {
            sum.set(0, sum.get(0) + Double.parseDouble(inX.get(i)) * weights.get(i));
        }
        // threshold the sigmoid output at 0.5
        if (sigmoid(sum).get(0) > 0.5) {
            return "1";
        } else {
            return "0";
        }
    }

    /**
     * @author haolidong
     * @Description: [predict mortality from horse colic]
     */
    public static void colicTest() {
        CreateDataSet trainingSet = readFile("I:\\machinelearninginaction\\Ch05\\horseColicTraining.txt");
        CreateDataSet testSet = readFile("I:\\machinelearninginaction\\Ch05\\horseColicTest.txt");
        ArrayList<Double> weights = gradAscent1(trainingSet, trainingSet.lables, 500);
        int errorCount = 0;
        for (int i = 0; i < testSet.data.size(); i++) {
            if (!classifyVector(testSet.data.get(i), weights).equals(testSet.lables.get(i))) {
                errorCount++;
            }
            System.out.println(classifyVector(testSet.data.get(i), weights) + "," + testSet.lables.get(i));
        }
        // error rate on the test set
        System.out.println(1.0 * errorCount / testSet.data.size());
    }

    /**
     * @param inX input values
     * @return element-wise sigmoid of the input
     * @author haolidong
     * @Description: [sigmoid function]
     */
    public static ArrayList<Double> sigmoid(ArrayList<Double> inX) {
        ArrayList<Double> inXExp = new ArrayList<Double>();
        for (int i = 0; i < inX.size(); i++) {
            inXExp.add(1.0 / (1 + Math.exp(-inX.get(i))));
        }
        return inXExp;
    }

    /**
     * @param dataSet     training rows
     * @param classLabels label of each row ("0" or "1")
     * @param numberIter  number of passes over the data
     * @return trained weight vector
     * @author haolidong
     * @Description: [improved stochastic gradient ascent]
     */
    public static ArrayList<Double> gradAscent1(Matrix dataSet, ArrayList<String> classLabels, int numberIter) {
        int m = dataSet.data.size();            // number of samples
        int n = dataSet.data.get(0).size();     // number of features (including the bias column)
        double alpha = 0.0;
        int randIndex = 0;
        ArrayList<Double> weights = new ArrayList<Double>();
        ArrayList<Double> h = new ArrayList<Double>();
        ArrayList<Integer> dataIndex = new ArrayList<Integer>();
        ArrayList<Double> dataMatrixMulweights = new ArrayList<Double>();
        for (int i = 0; i < n; i++) {
            weights.add(1.0);                   // initialize all weights to 1
        }
        dataMatrixMulweights.add(0.0);
        double error = 0.0;
        for (int j = 0; j < numberIter; j++) {
            // refill the pool of sample indices 0..m-1 for this pass
            for (int p = 0; p < m; p++) {
                dataIndex.add(p);
            }
            // one pass: pick samples at random without replacement
            for (int i = 0; i < m; i++) {
                // step size decreases with each update but never reaches 0
                alpha = 4 / (1.0 + i + j) + 0.0001;
                randIndex = (int) (Math.random() * dataIndex.size());
                dataIndex.remove(randIndex);
                // note: as in the book's code, randIndex is used to index the data
                // directly rather than going through dataIndex.get(randIndex)
                double temp = 0.0;
                for (int k = 0; k < n; k++) {
                    temp = temp + Double.parseDouble(dataSet.data.get(randIndex).get(k)) * weights.get(k);
                }
                dataMatrixMulweights.set(0, temp);
                h = sigmoid(dataMatrixMulweights);
                error = Double.parseDouble(classLabels.get(randIndex)) - h.get(0);
                // gradient ascent update: w := w + alpha * (y - h) * x
                for (int p = 0; p < n; p++) {
                    weights.set(p, weights.get(p) + alpha * Double.parseDouble(dataSet.data.get(randIndex).get(p)) * error);
                }
            }
        }
        return weights;
    }

    public static CreateDataSet readFile(String fileName) {
        File file = new File(fileName);
        BufferedReader reader = null;
        CreateDataSet dataSet = new CreateDataSet();
        try {
            reader = new BufferedReader(new FileReader(file));
            String tempString = null;
            // read one line at a time until the end of the file
            while ((tempString = reader.readLine()) != null) {
                // split on tabs: the last field is the label, the rest are features,
                // and a constant "1" is prepended as the bias term
                String[] strArr = tempString.split("\t");
                ArrayList<String> as = new ArrayList<String>();
                as.add("1");
                for (int i = 0; i < strArr.length - 1; i++) {
                    as.add(strArr[i]);
                }
                dataSet.data.add(as);
                dataSet.lables.add(strArr[strArr.length - 1]);
            }
            reader.close();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException e1) {
                }
            }
        }
        return dataSet;
    }
}
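
colicTest above depends on the horse colic data files from Machine Learning in Action at hard-coded paths. As a minimal, self-contained usage sketch that is not part of the original post, the hypothetical LogisticDemo class below builds a tiny two-feature dataset in memory and calls gradAscent1 and classifyVector directly; the class name, the addRow helper, and the toy numbers are all illustrative assumptions.

import java.util.ArrayList;

// Hypothetical demo class: trains on a tiny in-memory dataset instead of the book's files.
public class LogisticDemo {

    public static void main(String[] args) {
        CreateDataSet train = new CreateDataSet();
        // Each row is [bias, x1, x2]; the label is "0" or "1".
        addRow(train, "1.0", "2.1", "0");
        addRow(train, "1.5", "1.8", "0");
        addRow(train, "3.2", "0.5", "1");
        addRow(train, "3.8", "0.9", "1");

        // 200 passes of the improved stochastic gradient ascent
        ArrayList<Double> weights = Logistic.gradAscent1(train, train.lables, 200);

        // classify a new point (constant "1" prepended as the bias term, as readFile would do)
        ArrayList<String> query = new ArrayList<String>();
        query.add("1");
        query.add("3.5");
        query.add("0.7");
        System.out.println("predicted class: " + Logistic.classifyVector(query, weights));
    }

    // illustrative helper: append one sample and its label to the dataset
    private static void addRow(CreateDataSet ds, String x1, String x2, String label) {
        ArrayList<String> row = new ArrayList<String>();
        row.add("1");   // bias term, matching what readFile prepends
        row.add(x1);
        row.add(x2);
        ds.data.add(row);
        ds.lables.add(label);
    }
}

The point of the sketch is only the call pattern: build a CreateDataSet, train with gradAscent1, then pass each new row together with the learned weights to classifyVector.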

1.4 Test results

Some of the input test data is shown below:
0 83 0 -0.7 0
0 77.39996 0 -6.3 0
1 83 0 -0.7 0
0 82.29999 0 -1.4 0
1 66.89996 0 -16.8 0
0 81 0 -2.7 0
0 87.39996 1 3.699999 0
0 82.79999 0 -0.9 0
0 84.29999 0 0.6 0
1 80.69995 0 -3 0
0 88.5 0 4.799999 0
0 80.09998 0 -3.6 0
0 83.19995 0 -0.5 0
0 88.5 0 4.799999 0
0 79.39996 0 -4.3 0
0 82.29999 0 -1.4 0
0 78.59998 0 -5.1 0
0 82.09998 0 -1.6 0
0 84.59998 0 0.9 0
0 78.19995 0 -5.5 0
0 83.69995 1 0 0
0 73.89996 0 -9.8 0
0 89.5 1 5.799999 0
0 81.29999 0 -2.4 0
0 83.09998 0 -0.6 0
The first column is the label; columns 2, 3, 4, and 5 are attributes.
The output is one line per test sample in the form predicted,actual, followed by the overall error rate on the test set (shown as a screenshot in the original post, not reproduced here).

Regression vs. Classification

    Regression is usually used to predict a value, such as a house price or the future weather. For example, if a product's actual price is 500 yuan and the regression analysis predicts 499 yuan, we would consider that a fairly good regression. A common regression algorithm is linear regression (LR). Also, when regression is used at the top of a neural network, no softmax is applied to the final layer; its output is simply the accumulated (linear) value from the previous layer. Regression approximates a real-valued target.
    Classification is used to assign a label to an object, and the result is usually a discrete value, for example deciding whether the animal in a picture is a cat or a dog. Classification is usually built on top of regression: the last layer of a classifier typically applies a softmax function to decide which class the input belongs to. Classification has no notion of approximation; there is exactly one correct answer, and a wrong answer is simply wrong, with no credit for being close. The most common classification method is logistic regression, also called logistic classification.
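
To make the contrast concrete, here is a small illustrative Java sketch (not from the original post; the class, weights, and inputs are made up): the same linear score w.x is reported directly as the regression output, while the classifier squashes it through a sigmoid and thresholds it at 0.5 to produce a discrete label.

// Illustrative only: regression returns the raw linear score, classification thresholds it.
public class RegressionVsClassification {

    // regression: predict the value itself
    static double linearScore(double[] w, double[] x) {
        double s = 0.0;
        for (int i = 0; i < w.length; i++) {
            s += w[i] * x[i];
        }
        return s;
    }

    // classification: squash the score through the sigmoid and return a discrete label
    static int classify(double[] w, double[] x) {
        double p = 1.0 / (1.0 + Math.exp(-linearScore(w, x)));
        return p > 0.5 ? 1 : 0;
    }

    public static void main(String[] args) {
        double[] w = {0.5, -1.2, 2.0};    // hypothetical weights
        double[] x = {1.0, 0.3, 0.8};     // bias term plus two features
        System.out.println("regression output:     " + linearScore(w, x));
        System.out.println("classification output: " + classify(w, x));
    }
}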