Hand-writing logistic regression in Java, including L1 and L2 regularization


As a machine learning enthusiast, I find logistic regression a good algorithm to start with: the idea is simple, but writing it by hand raises a few points worth noting:

First, the step size (learning rate): choose it according to the size and scale of your data.

Second, you can decide for yourself whether or not to add a constant (intercept) term to the features used for training.
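If you do add the constant term, one way to do it is to prepend a fixed feature of 1.0 to every sample when the data is loaded, so that the first weight acts as the bias. A minimal sketch (the helper name withIntercept is mine and does not appear in the code below):

// Hypothetical helper: prepend a constant 1.0 feature so that w[0] becomes the bias term.
private static List<Float> withIntercept(float x1, float x2) {
    return Arrays.asList(1.0f, x1, x2);
}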

Third, the code actually solves the problem with gradient ascent. When the algorithm is derived it is usually presented as gradient descent on the negative log-likelihood, but in engineering practice it is more convenient to do gradient ascent on the log-likelihood directly; the two are equivalent.
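Concretely, writing the model's predicted probability for sample i as

\[ \hat{y}_i = \sigma(w^{\top} x_i) = \frac{1}{1 + e^{-w^{\top} x_i}}, \]

the batch gradient-ascent update on the log-likelihood that the code below uses is

\[ w_d \leftarrow w_d + \alpha \sum_{i=1}^{m} (y_i - \hat{y}_i)\, x_{i,d}, \]

where α is the step size and m the number of samples; gradient descent on the negative log-likelihood gives exactly the same update.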

Fourth is regularization: you can add L1 or L2 regularization to your implementation.
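With a regularization strength λ (the hard-coded 0.01 in the commented-out lines of the code below), the standard penalized updates are

\[ \text{L2:}\quad w_d \leftarrow w_d + \alpha\Big(\textstyle\sum_i (y_i - \hat{y}_i)\,x_{i,d} - \lambda\, w_d\Big), \]

\[ \text{L1:}\quad w_d \leftarrow w_d + \alpha\Big(\textstyle\sum_i (y_i - \hat{y}_i)\,x_{i,d} - \lambda\, \operatorname{sign}(w_d)\Big), \]

i.e. the L2 penalty shrinks each weight toward zero in proportion to its value, while the L1 penalty pushes each weight toward zero by a constant amount, which is what produces sparse solutions.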

Fifth is the stopping condition. In practice you can stop after a fixed number of iterations, stop when the change in the weights between iterations falls below a threshold, or combine the two; a sketch of the combined form follows below.
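A sketch of how the two stopping rules could be combined, assuming the training pass is factored out; maxIter, epsilon and onePass are made-up names for this illustration, not part of the code below:

import java.util.function.UnaryOperator;

// Keep training until either the iteration cap is reached or the weights stop moving.
static float[] trainUntilConverged(float[] w, UnaryOperator<float[]> onePass,
                                   int maxIter, float epsilon) {
    float change = Float.MAX_VALUE;
    for (int iter = 0; iter < maxIter && change > epsilon; iter++) {
        float[] wOld = w.clone();
        w = onePass.apply(w);            // one full gradient-ascent pass over the data
        float sq = 0;
        for (int d = 0; d < w.length; d++) {
            sq += (w[d] - wOld[d]) * (w[d] - wOld[d]);
        }
        change = (float) Math.sqrt(sq);  // Euclidean distance between consecutive weight vectors
    }
    return w;
}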

Sixth is the optimization algorithm: batch gradient, stochastic gradient, or a quasi-Newton method; the first two in particular are straightforward.
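For contrast with the batch update above, the stochastic variant (case 2 in the code below) applies the same rule after every single sample i instead of summing over all of them:

\[ w_d \leftarrow w_d + \alpha\,(y_i - \hat{y}_i)\,x_{i,d}. \]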


That is basically it; additions from more experienced readers are welcome. Below is my Java implementation. The data comes from the book "Machine Learning in Action" (the Python one). The Java version does not use a matrix library; once you understand the matrix operations behind the algorithm, plain lists work just as well. Enough talk, here is the code.


First, the code for reading the data:

package com.wanda.logistic;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ReadData {

    public static final String PATH = "d:\\wilson.zhou\\Desktop\\logistic.txt";

    // Feature vectors (two columns per sample) and their 0/1 labels.
    public static List<List<Float>> dataList = new ArrayList<List<Float>>();
    public static List<Float> labelList = new ArrayList<Float>();

    static {
        try {
            init();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    // Reads the tab-separated file: columns 0 and 1 are features, column 2 is the label.
    private static void init() throws IOException {
        BufferedReader buff = new BufferedReader(new InputStreamReader(
                new FileInputStream(new File(PATH))));
        String str = buff.readLine();
        while (str != null) {
            String[] arr = str.split("\t");
            labelList.add(Float.parseFloat(arr[2]));
            dataList.add(Arrays.asList(Float.parseFloat(arr[0]),
                    Float.parseFloat(arr[1])));
            str = buff.readLine();
        }
        buff.close();
    }
}




The logistic regression code:

package com.wanda.logistic;

import java.util.Arrays;
import java.util.List;

public class LogRegression {

    public static void main(String[] args) {
        LogRegression lr = new LogRegression();
        ReadData instances = new ReadData();
        lr.train(instances, 0.001f, 1); // step size 0.001, type 1 = batch gradient ascent
    }

    public void train(ReadData instances, float step, int type) {
        List<List<Float>> datas = instances.dataList;
        List<Float> labels = instances.labelList;
        int size = datas.size();
        int dim = datas.get(0).size();
        float[] w = new float[dim]; // weights initialized to zero
        float change = Float.MAX_VALUE;
        int iteration = 0;
        switch (type) {
        case 1: // batch gradient ascent
            while (change > 0.0001) {
                float[] wClone = w.clone();
                float[] out = new float[size];
                // forward pass: sigmoid of the inner product for every sample
                for (int s = 0; s < size; s++) {
                    float score = innerProduct(w, datas.get(s));
                    out[s] = sigmoid(score);
                }
                // update each weight with the gradient summed over all samples
                for (int d = 0; d < dim; d++) {
                    float sum = 0;
                    for (int s = 0; s < size; s++) {
                        sum += (labels.get(s) - out[s]) * datas.get(s).get(d);
                    }
                    float q = w[d];
                    w[d] = (float) (q + step * sum);
                    // w[d] = (float) (q + step * (sum - 0.01 * q));              // L2 regularization
                    // w[d] = (float) (q + step * (sum - 0.01 * Math.signum(q))); // L1 regularization
                }
                change = weightChange(wClone, w);
                iteration++;
                System.out.println("iteration: " + iteration + "  weights: " + Arrays.toString(w));
            }
            break;
        case 2: // stochastic gradient ascent: update after every single sample
            while (change > 0.0001) {
                float[] wClone = w.clone();
                for (int s = 0; s < size; s++) {
                    float score = innerProduct(w, datas.get(s));
                    float out = sigmoid(score);
                    float error = labels.get(s) - out;
                    for (int d = 0; d < dim; d++) {
                        w[d] += step * error * datas.get(s).get(d);
                    }
                }
                change = weightChange(wClone, w);
                iteration++;
                System.out.println("iteration: " + iteration + "  weights: " + Arrays.toString(w));
            }
            break;
        default:
            break;
        }
    }

    // Euclidean distance between the weight vectors of two consecutive iterations.
    private float weightChange(float[] wClone, float[] w) {
        float change = 0;
        for (int i = 0; i < w.length; i++) {
            change += Math.pow(w[i] - wClone[i], 2);
        }
        return (float) Math.sqrt(change);
    }

    private float innerProduct(float[] w, List<Float> x) {
        float sum = 0;
        for (int i = 0; i < w.length; i++) {
            sum += w[i] * x.get(i);
        }
        return sum;
    }

    private float sigmoid(float src) {
        return (float) (1.0 / (1 + Math.exp(-src)));
    }
}
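The class above only trains and prints the weights; to actually use the model you still need a prediction step. A minimal sketch of what that could look like inside LogRegression (the method classify is my own addition, not from the original post), assuming train is changed to return or store the final weight vector:

// Hypothetical helper: predict the class of one sample by thresholding the
// model's probability at 0.5.
private int classify(float[] w, List<Float> x) {
    float p = sigmoid(innerProduct(w, x)); // probability that the label is 1
    return p >= 0.5f ? 1 : 0;
}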




The data (tab-separated: feature x1, feature x2, 0/1 label):

-0.017612	14.053064	0
-1.395634	4.662541	1
-0.752157	6.538620	0
-1.322371	7.152853	0
0.423363	11.054677	0
0.406704	7.067335	1
0.667394	12.741452	0
-2.460150	6.866805	1
0.569411	9.548755	0
-0.026632	10.427743	0
0.850433	6.920334	1
1.347183	13.175500	0
1.176813	3.167020	1
-1.781871	9.097953	0
-0.566606	5.749003	1
0.931635	1.589505	1
-0.024205	6.151823	1
-0.036453	2.690988	1
-0.196949	0.444165	1
1.014459	5.754399	1
1.985298	3.230619	1
-1.693453	-0.557540	1
-0.576525	11.778922	0
-0.346811	-1.678730	1
-2.124484	2.672471	1
1.217916	9.597015	0
-0.733928	9.098687	0
-3.642001	-1.618087	1
0.315985	3.523953	1
1.416614	9.619232	0
-0.386323	3.989286	1
0.556921	8.294984	1
1.224863	11.587360	0
-1.347803	-2.406051	1
1.196604	4.951851	1
0.275221	9.543647	0
0.470575	9.332488	0
-1.889567	9.542662	0
-1.527893	12.150579	0
-1.185247	11.309318	0
-0.445678	3.297303	1
1.042222	6.105155	1
-0.618787	10.320986	0
1.152083	0.548467	1
0.828534	2.676045	1
-1.237728	10.549033	0
-0.683565	-2.166125	1
0.229456	5.921938	1
-0.959885	11.555336	0
0.492911	10.993324	0
0.184992	8.721488	0
-0.355715	10.325976	0
-0.397822	8.058397	0
0.824839	13.730343	0
1.507278	5.027866	1
0.099671	6.835839	1
-0.344008	10.717485	0
1.785928	7.718645	1
-0.918801	11.560217	0
-0.364009	4.747300	1
-0.841722	4.119083	1
0.490426	1.960539	1
-0.007194	9.075792	0
0.356107	12.447863	0
0.342578	12.281162	0
-0.810823	-1.466018	1
2.530777	6.476801	1
1.296683	11.607559	0
0.475487	12.040035	0
-0.783277	11.009725	0
0.074798	11.023650	0
-1.337472	0.468339	1
-0.102781	13.763651	0
-0.147324	2.874846	1
0.518389	9.887035	0
1.015399	7.571882	0
-1.658086	-0.027255	1
1.319944	2.171228	1
2.056216	5.019981	1
-0.851633	4.375691	1
-1.510047	6.061992	0
-1.076637	-3.181888	1
1.821096	10.283990	0
3.010150	8.401766	1
-1.099458	1.688274	1
-0.834872	-1.733869	1
-0.846637	3.849075	1
1.400102	12.628781	0
1.752842	5.468166	1
0.078557	0.059736	1
0.089392	-0.715300	1
1.825662	12.693808	0
0.197445	9.744638	0
0.126117	0.922311	1
-0.679797	1.220530	1
0.677983	2.556666	1
0.761349	10.693862	0
-2.168791	0.143632	1
1.388610	9.341997	0
0.317029	14.739025	0

