机器学习实战逻辑回归的Java实现

来源:互联网 发布:反电信网络诈骗宣传片 编辑:程序博客网 时间:2024/06/11 20:22
<pre name="code" class="java">package com.haolidong.Logistic;import java.util.ArrayList;/** *  * @author haolidong * @Description: [该类主要用于保存特征信息] * @parameter data: [主要保存特征矩阵] */public class Matrix {<span style="white-space:pre"></span>public ArrayList<ArrayList<String>> data;<span style="white-space:pre"></span>public Matrix() {<span style="white-space:pre"></span>// TODO Auto-generated constructor stub<span style="white-space:pre"></span>data = new ArrayList<ArrayList<String>>();<span style="white-space:pre"></span>}}
<pre name="code" class="java">package com.haolidong.Logistic;import java.util.ArrayList;/** *  * @author haolidong * @Description: [该类主要用于保存特征信息以及标签值] * @parameter labels: [主要保存标签值] */public class CreateDataSet extends Matrix {public ArrayList<String> labels;public CreateDataSet() {// TODO Auto-generated constructor stubsuper();labels = new ArrayList<String>();}/** * @author haolidong * @Description: [机器学习实战逻辑回归第一个案例的数据] */public void initTest() {}}
package com.haolidong.Logistic;import java.io.BufferedReader;import java.io.File;import java.io.FileReader;import java.io.IOException;import java.util.ArrayList;public class Logistic {public static void main(String[] args) {colicTest();}/** * @author haolidong * @Description: [逻辑回归的简单测试] */public static void LogisticTest() {// TODO Auto-generated method stubCreateDataSet dataSet = new CreateDataSet();dataSet = readFile("I:\\machinelearninginaction\\Ch05\\testSet.txt");ArrayList<Double> weights = new ArrayList<Double>();weights = gradAscent1(dataSet, dataSet.labels, 150);for (int i = 0; i < 3; i++) {System.out.println(weights.get(i));}System.out.println();}/** * @param inX * @param weights * @return * @author haolidong * @Description: [sigmod分类] */public static String classifyVector(ArrayList<String> inX, ArrayList<Double> weights) {ArrayList<Double> sum = new ArrayList<>();sum.clear();sum.add(0.0);for (int i = 0; i < inX.size(); i++) {sum.set(0, sum.get(0) + Double.parseDouble(inX.get(i)) * weights.get(i));}if (sigmoid(sum).get(0) > 0.5)return "1";elsereturn "0";}/** * @author haolidong * @Description: [预测马的疝气病的死亡率] */public static void colicTest() {CreateDataSet trainingSet = new CreateDataSet();CreateDataSet testSet = new CreateDataSet();trainingSet = readFile("I:\\machinelearninginaction\\Ch05\\horseColicTraining.txt");testSet = readFile("I:\\machinelearninginaction\\Ch05\\horseColicTest.txt");ArrayList<Double> weights = new ArrayList<Double>();weights = gradAscent1(trainingSet, trainingSet.labels, 500);int errorCount = 0;for (int i = 0; i < testSet.data.size(); i++) {if (!classifyVector(testSet.data.get(i), weights).equals(testSet.labels.get(i))) {errorCount++;}System.out.println(classifyVector(testSet.data.get(i), weights) + "," + testSet.labels.get(i));}System.out.println(1.0 * errorCount / testSet.data.size());}/** * @param inX * @return * @author haolidong * @Description: [sigmod函数] */public static ArrayList<Double> sigmoid(ArrayList<Double> inX) 
{ArrayList<Double> inXExp = new ArrayList<Double>();for (int i = 0; i < inX.size(); i++) {inXExp.add(1.0 / (1 + Math.exp(-inX.get(i))));}return inXExp;}/** * @param dataSet * @param classLabels * @param numberIter * @return * @author haolidong * @Description: [改进的随机梯度上升算法] */public static ArrayList<Double> gradAscent1(Matrix dataSet, ArrayList<String> classLabels, int numberIter) {int m = dataSet.data.size();int n = dataSet.data.get(0).size();double alpha = 0.0;int randIndex = 0;ArrayList<Double> weights = new ArrayList<Double>();ArrayList<Double> weightstmp = new ArrayList<Double>();ArrayList<Double> h = new ArrayList<Double>();ArrayList<Integer> dataIndex = new ArrayList<Integer>();ArrayList<Double> dataMatrixMulweights = new ArrayList<Double>();for (int i = 0; i < n; i++) {weights.add(1.0);weightstmp.add(1.0);}dataMatrixMulweights.add(0.0);double error = 0.0;for (int j = 0; j < numberIter; j++) {// 产生0到99的数组for (int p = 0; p < m; p++) {dataIndex.add(p);}// 进行每一次的训练for (int i = 0; i < m; i++) {alpha = 4 / (1.0 + i + j) + 0.0001;randIndex = (int) (Math.random() * dataIndex.size());dataIndex.remove(randIndex);double temp = 0.0;for (int k = 0; k < n; k++) {temp = temp + Double.parseDouble(dataSet.data.get(randIndex).get(k)) * weights.get(k);}dataMatrixMulweights.set(0, temp);h = sigmoid(dataMatrixMulweights);error = Double.parseDouble(classLabels.get(randIndex)) - h.get(0);double tempweight = 0.0;for (int p = 0; p < n; p++) {tempweight = alpha * Double.parseDouble(dataSet.data.get(randIndex).get(p)) * error;weights.set(p, weights.get(p) + tempweight);}}}return weights;}/** * @param dataSet * @param classLabels * @return * @author haolidong * @Description: [随机梯度上升算法] */public static ArrayList<Double> gradAscent0(Matrix dataSet, ArrayList<String> classLabels) {int m = dataSet.data.size();int n = dataSet.data.get(0).size();ArrayList<Double> weights = new ArrayList<Double>();ArrayList<Double> weightstmp = new ArrayList<Double>();ArrayList<Double> h = new 
ArrayList<Double>();double error = 0.0;ArrayList<Double> dataMatrixMulweights = new ArrayList<Double>();double alpha = 0.01;for (int i = 0; i < n; i++) {weights.add(1.0);weightstmp.add(1.0);}h.add(0.0);double temp = 0.0;dataMatrixMulweights.add(0.0);for (int i = 0; i < m; i++) {temp = 0.0;for (int k = 0; k < n; k++) {temp = temp + Double.parseDouble(dataSet.data.get(i).get(k)) * weights.get(k);}dataMatrixMulweights.set(0, temp);h = sigmoid(dataMatrixMulweights);error = Double.parseDouble(classLabels.get(i)) - h.get(0);double tempweight = 0.0;for (int p = 0; p < n; p++) {tempweight = alpha * Double.parseDouble(dataSet.data.get(i).get(p)) * error;weights.set(p, weights.get(p) + tempweight);}}return weights;}/** * @param dataSet * @param classLabels * @return * @author haolidong * @Description: [全部数据的梯度上升算法] */public static ArrayList<Double> gradAscent(Matrix dataSet, ArrayList<String> classLabels) {int m = dataSet.data.size();int n = dataSet.data.get(0).size();ArrayList<Double> weights = new ArrayList<Double>();ArrayList<Double> weightstmp = new ArrayList<Double>();ArrayList<Double> h = new ArrayList<Double>();ArrayList<Double> error = new ArrayList<Double>();ArrayList<Double> dataMatrixMulweights = new ArrayList<Double>();double alpha = 0.001;int maxCycles = 500;for (int i = 0; i < n; i++) {weights.add(1.0);weightstmp.add(1.0);}for (int i = 0; i < m; i++) {h.add(0.0);error.add(0.0);dataMatrixMulweights.add(0.0);}double temp;for (int i = 0; i < maxCycles; i++) {for (int j = 0; j < m; j++) {temp = 0.0;for (int k = 0; k < n; k++) {temp = temp + Double.parseDouble(dataSet.data.get(j).get(k)) * weights.get(k);}dataMatrixMulweights.set(j, temp);}h = sigmoid(dataMatrixMulweights);for (int q = 0; q < m; q++) {error.set(q, Double.parseDouble(classLabels.get(q)) - h.get(q));}double tempweight = 0.0;for (int p = 0; p < n; p++) {tempweight = 0.0;for (int q = 0; q < m; q++) {tempweight = tempweight + alpha * Double.parseDouble(dataSet.data.get(q).get(p)) * 
error.get(q);}weights.set(p, weights.get(p) + tempweight);}}return weights;}/** * @param fileName *            读入的文件名 * @return * @author haolidong * @Description: [根据读入的文件名形成特征集以及标签] */public static CreateDataSet readFile(String fileName) {File file = new File(fileName);BufferedReader reader = null;CreateDataSet dataSet = new CreateDataSet();try {reader = new BufferedReader(new FileReader(file));String tempString = null;// 一次读入一行,直到读入null为文件结束while ((tempString = reader.readLine()) != null) {// 显示行号String[] strArr = tempString.split("\t");ArrayList<String> as = new ArrayList<String>();as.add("1");for (int i = 0; i < strArr.length - 1; i++) {as.add(strArr[i]);}dataSet.data.add(as);dataSet.labels.add(strArr[strArr.length - 1]);}reader.close();} catch (IOException e) {e.printStackTrace();} finally {if (reader != null) {try {reader.close();} catch (IOException e1) {}}}return dataSet;}}



                                             
0 0