逻辑回归的数学推导及java代码实现
来源:互联网 发布:2017流行语言网络词 编辑:程序博客网 时间:2024/05/22 04:00
逻辑回归
1.1逻辑回归的概念
1.2逻辑回归的数学表达式
1.3逻辑回归的java代码实现
importjava.util.ArrayList;public class Matrix { public ArrayList<ArrayList<String>> data; public Matrix(){ data = new ArrayList<ArrayList<String>>(); }}import java.util.ArrayList;public class CreateDataSet extends Matrix { public ArrayList<String>lables; public CreateDataSet(){ super();lables = new ArrayList<String>(); } public void initTest(){ }}import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.io.IOException;import java.util.ArrayList;public class Logistic { public static void main(String[] args) {colicTest(); } /** * @author haolidong * @Description: [逻辑回归的简单测试] */ public static void LogisticTest() { // TODO Auto-generated method stubCreateDataSetdataSet = new CreateDataSet();dataSet = readFile("I:\\machinelearninginaction\\Ch05\\testSet.txt");ArrayList<Double> weights = new ArrayList<Double>(); weights = gradAscent1(dataSet, dataSet.lables, 150); for (inti = 0; i< 3; i++) {System.out.println(weights.get(i)); }System.out.println(); } /** * @paraminX * @param weights * @return * @author haolidong * @Description: [sigmod分类] */ public static String classifyVector(ArrayList<String>inX, ArrayList<Double> weights) {ArrayList<Double> sum = new ArrayList<>();sum.clear();sum.add(0.0); for (inti = 0; i<inX.size(); i++) {sum.set(0, sum.get(0) + Double.parseDouble(inX.get(i)) * weights.get(i)); } if (sigmoid(sum).get(0) > 0.5) return "1"; else return "0"; } /** * @author haolidong * @Description: [预测马的疝气病的死亡率] */ public static void colicTest() {CreateDataSettrainingSet = new CreateDataSet();CreateDataSettestSet = new CreateDataSet();trainingSet = readFile("I:\\machinelearninginaction\\Ch05\\horseColicTraining.txt");testSet = readFile("I:\\machinelearninginaction\\Ch05\\horseColicTest.txt");ArrayList<Double> weights = new ArrayList<Double>(); weights = gradAscent1(trainingSet, trainingSet.lables, 500);interrorCount = 0; for (inti = 0; i<testSet.data.size(); i++) { if (!classifyVector(testSet.data.get(i), weights).equals(testSet.lables.get(i))) {errorCount++; }System.out.println(classifyVector(testSet.data.get(i), weights) + "," + testSet.lables.get(i)); }System.out.println(1.0 * errorCount / testSet.data.size()); } /** * @paraminX * @return * @author haolidong * @Description: [sigmod函数] */ public static ArrayList<Double> sigmoid(ArrayList<Double>inX) {ArrayList<Double>inXExp = new ArrayList<Double>(); for (inti = 0; i<inX.size(); i++) {inXExp.add(1.0 / (1 + Math.exp(-inX.get(i)))); } return inXExp; } /** * @paramdataSet * @paramclassLabels * @paramnumberIter * @return * @author haolidong * @Description: [改进的随机梯度上升算法] */ public static ArrayList<Double> gradAscent1(Matrix dataSet, ArrayList<String>classLabels, intnumberIter) {int m = dataSet.data.size();int n = dataSet.data.get(0).size(); double alpha = 0.0;intrandIndex = 0;ArrayList<Double> weights = new ArrayList<Double>();ArrayList<Double>weightstmp = new ArrayList<Double>();ArrayList<Double> h = new ArrayList<Double>();ArrayList<Integer>dataIndex = new ArrayList<Integer>();ArrayList<Double>dataMatrixMulweights = new ArrayList<Double>(); for (inti = 0; i< n; i++) {weights.add(1.0);weightstmp.add(1.0); }dataMatrixMulweights.add(0.0); double error = 0.0; for (int j = 0; j <numberIter; j++) { // 产生0到99的数组for (int p = 0; p < m; p++) {dataIndex.add(p); } // 进行每一次的训练for (inti = 0; i< m; i++) { alpha = 4 / (1.0 + i + j) + 0.0001;randIndex = (int) (Math.random() * dataIndex.size());dataIndex.remove(randIndex); double temp = 0.0; for (int k = 0; k < n; k++) { temp = temp + Double.parseDouble(dataSet.data.get(randIndex).get(k)) * weights.get(k); }dataMatrixMulweights.set(0, temp); h = sigmoid(dataMatrixMulweights); error = Double.parseDouble(classLabels.get(randIndex)) - h.get(0); double tempweight = 0.0; for (int p = 0; p < n; p++) {tempweight = alpha * Double.parseDouble(dataSet.data.get(randIndex).get(p)) * error;weights.set(p, weights.get(p) + tempweight); } } } return weights; }public static CreateDataSetreadFile(String fileName) { File file = new File(fileName);BufferedReader reader = null;CreateDataSetdataSet = new CreateDataSet(); try { reader = new BufferedReader(new FileReader(file)); String tempString = null; // 一次读入一行,直到读入null为文件结束while ((tempString = reader.readLine()) != null) { // 显示行号String[] strArr = tempString.split("\t");ArrayList<String> as = new ArrayList<String>();as.add("1"); for (inti = 0; i<strArr.length - 1; i++) {as.add(strArr[i]); }dataSet.data.add(as);dataSet.lables.add(strArr[strArr.length - 1]); }reader.close(); } catch (IOException e) {e.printStackTrace(); } finally { if (reader != null) { try {reader.close(); } catch (IOException e1) { } } } return dataSet; }}
1.4测试结果
输入部分测试数据如下:
0 83 0 -0.7 0
0 77.39996 0 -6.3 0
1 83 0 -0.7 0
0 82.29999 0 -1.4 0
1 66.89996 0 -16.8 0
0 81 0 -2.7 0
0 87.39996 1 3.699999 0
0 82.79999 0 -0.9 0
0 84.29999 0 0.6 0
1 80.69995 0 -3 0
0 88.5 0 4.799999 0
0 80.09998 0 -3.6 0
0 83.19995 0 -0.5 0
0 88.5 0 4.799999 0
0 79.39996 0 -4.3 0
0 82.29999 0 -1.4 0
0 78.59998 0 -5.1 0
0 82.09998 0 -1.6 0
0 84.59998 0 0.9 0
0 78.19995 0 -5.5 0
0 83.69995 1 0 0
0 73.89996 0 -9.8 0
0 89.5 1 5.799999 0
0 81.29999 0 -2.4 0
0 83.09998 0 -0.6 0
第1列为lebels标签,第2、3、4、5列为属性。
输出结果为:
回归与分类
回归问题通常是用来预测一个值,如预测房价、未来的天气情况等等,例如一个产品的实际价格为500元,通过回归分析预测值为499元,我们认为这是一个比较好的回归分析。一个比较常见的回归算法是线性回归算法(LR)。另外,回归分析用在神经网络上,其最上层是不需要加上softmax函数的,而是直接对前一层累加即可。回归是对真实值的一种逼近预测。
分类问题是用于将事物打上一个标签,通常结果为离散值。例如判断一幅图片上的动物是一只猫还是一只狗,分类通常是建立在回归之上,分类的最后一层通常要使用softmax函数进行判断其所属类别。分类并没有逼近的概念,最终正确结果只有一个,错误的就是错误的,不会有相近的概念。最常见的分类方法是逻辑回归,或者叫逻辑分类。
- 逻辑回归的数学推导及java代码实现
- 十七、逻辑回归公式的数学推导
- Logistic回归代价函数的数学推导及实现
- 逻辑回归及其数学推导
- 逻辑回归概念及推导
- 线性回归的推导与java代码
- logistic回归的数学推导
- 逻辑回归的相关问题及java实现
- 逻辑回归模型推导及梯度下降
- 逻辑回归原理及推导过程
- logsit回归代码的推导
- logsit回归代码的推导
- 逻辑回归推导
- 逻辑回归推导
- 机器学习:逻辑回归原理及实现代码
- 机器学习逻辑回归:原理解析及代码实现
- 《PRML》Logistic回归(逻辑回归,LR)的推导
- 逻辑回归的实现
- 虚拟机安装centos7.2后遗留网络问题解决方法
- 网弧与红鸟的一些bug
- numpy基础教程—矩阵的简单属性和方法
- Intellij IDEA 中一次性折叠所有Java代码的快捷键设置 collapse all
- IntentService与HandlerThread源码解析
- 逻辑回归的数学推导及java代码实现
- Python的30个编程技巧
- StringBuffer类的常用方法
- Django框架学习笔记(13.获取单表单数据的三种方式)
- PHP并行查询MySQL
- Meterpreter常用命令介绍
- SpringBoot配置Redis连接池
- 关于对equals的源码分析
- java语言选择排序详解