机器学习-逻辑回归

来源:互联网 发布:护士资格考试视频软件 编辑:程序博客网 时间:2024/04/30 14:44

1.概念

第一篇机器学习-线性回归中介绍了线性回归的概念和模型以及java的实现。其实,逻辑回归的本质上是线性回归,只是在特征到结果的映射中加入了一层函数映射,即先把特征线性求和,然后使用一个函数g(z)进行计算。g(z)函数可以将连续值映射到0和1上。也就是说线性回归的输出时y=f(x)=wx,而逻辑回归的输出是y=g(f(x))=g(wx)。

线性回归和逻辑回归的结构图如下所示:
这里写图片描述

逻辑回归于线性回归的不同点在于:将线性回归的输出范围,例如从负无穷到正无穷,压缩到0和1之间;把大值压缩到这个这个范围还有个很好的用处,就是可以消除特别冒尖的变量的影响。

2.函数模型

刚刚提到线性回归的输出时y=f(x)=wx,而逻辑回归的输出y=g(f(x))=g(wx)。这里的g(z)函数就是Logistic函数也叫Sigmoid函数,其函数形式如下:
这里写图片描述
Sigmoid函数又个很漂亮的S形,如下图所示:
这里写图片描述

其函数的实际描述如下:
给定n个特征x={x1,x2,x3,…,xn},设条件概率p(y=1|x)位观测样本y相对于事件因素x发生的概率,用Sigmoid函数表示如下:
这里写图片描述
其中,这里写图片描述

那么在x条件下y不发生的概率为:
这里写图片描述

假设现在又m个独立的观测事件:y=(y1 ,y2,…,ym ),则一个事件yi发生的概率为:
这里写图片描述

因此,m个独立事件出现的似然函数为(因为每个样本都是独立的,所以m个样本出现的概率就是它们各自出现的概率乘积):
这里写图片描述

然后,我们的目标就是求出使这一似然函数的值最大的参数估计,最大似然估计就是求出参数θ,使得L(θ)取得最大值,对函数L(θ)取对数得到:
这里写图片描述

最大似然估计就是求使得L(θ)取得最大值时的θ,其实这里可以使用梯度上升法求解,求得的θ就是要求最佳参数。我们可以乘以一个负的系数-1/m,转成梯度下降法求解,L(θ)转成J(θ):
这里写图片描述

所以,取J(θ)最小值时的θ为要求的最佳参数。

3.梯度下降求解

类似于上一篇写的线性回归的求解过程,这里也选用梯度下降法求解θ。
θ的更新过程如下:
这里写图片描述

因此,θ的更新可以写成:
这里写图片描述

接下来的过程可以参考线性回归的知识点,使用批量梯度下降或者随机梯度下降,代码的实现过程,只需要在线性回归的基础上加一个sigmoid函数,稍作修改即可。

4.正则化

对于线性回归或逻辑回归的损失函数构成的模型,可能有些权重很大,有些权重很小,而导致过拟合(通俗的讲就是过分拟合了训练数据),这使得模型的复杂度提高,泛华能力较差(可以理解为模型的一般性,对未知数据的预测能力)。

如下图从左往右分别是欠拟合,合适的拟合,过拟合。

这里写图片描述

过拟合问题的主要原因是拟合了过多的特征。其解决的主要方法如下:
(1)减少特征数量:
可用人工选择需要保留的特征,或者采用模型选择算法选取特征,例如spark里面的特征选择方法,基于卡方检验的特征选择)。减少特征会失去一些信息,即使特征选的很好。
(2)正则化:
保留所有的特征,但是减少θ的大小。正则化是结构化风险最小话策略的实现,是在经验风险上加一个正则化项或者惩罚项。正则化项一般是模型复杂度的单调递增函数,模型越复杂,正则化项就越大。

为了增强模型的泛化能力,防止训练模型过拟合,特别是对于大量稀疏特征的模型,模型的复杂度比较高,需要进行降维处理,我们需要保证在训练误差最小化的基础上,通过加上正则化项减少模型的复杂度。在逻辑回归中,支持L1和L2正则化。

损失函数如下:
这里写图片描述

5.java实现

package xudong.Regression;import java.io.BufferedReader;import java.io.File;import java.io.FileNotFoundException;import java.io.FileReader;import java.io.IOException;public class LogRegression {    /**     * 逻辑回归的实现     *      * 这里的逻辑回归的实现并未加入正则化项     * @author xudong     * @since 2017-8-8     */    private double [][] trainData;//数据集矩阵    private int row;    private int column;    private double [] theta;//参数    private double alpha;//学习速率    private int iteration;//迭代次数    public LogRegression(String filename){        int rowFile=getRowNumber(filename);//获取输入训练数据的行数        int columnFile=getColumnNumber(filename);//获取训练数据的列数        trainData=new double[rowFile][columnFile+1];//加了一个特征x0  x0==1        this.row=rowFile;        this.column=columnFile;        this.alpha=0.001;        this.iteration=1000;        this.theta=new double[column-1];        initialize_theta();        loadTrainDataFromFile(filename,rowFile,columnFile);    }    public LogRegression(String filename,double alpha,int iteration){        int rowFile=getRowNumber(filename);//获取输入训练数据的行数        int columnFile=getColumnNumber(filename);//获取训练数据的列数        trainData=new double[rowFile][columnFile+1];//加了一个特征x0  x0==1        this.row=rowFile;        this.column=columnFile;        this.alpha=alpha;        this.iteration=iteration;        this.theta=new double[column-1];        initialize_theta();        loadTrainDataFromFile(filename,rowFile,columnFile);    }    /**     * 从文件中加载数据集到trainData中     * @param filename     * @param rowFile     * @param columnFile     */    private void loadTrainDataFromFile(String filename, int rowFile,int columnFile) {        for(int i=0;i<row;i++){//trainData第一例全是0            trainData[i][0]=1.0;        }               File file=new File(filename);        BufferedReader br=null;        try {            br=new BufferedReader(new FileReader(file));            String temp=null;            int counter=0;            while((counter < row)&&(temp=br.readLine())!=null){                String[] tempData=temp.split(" ");                for(int i=0;i<column;i++){                    trainData[counter][i+1]=Double.parseDouble(tempData[i]);                }                counter++;            }            br.close();        } catch (FileNotFoundException e) {            e.printStackTrace();        } catch (IOException e) {            e.printStackTrace();        }    }    /**     * 初始化参数theta     */    private void initialize_theta() {        for(int i=0;i<theta.length;i++){            theta[i]=1.0;        }    }    /**     * 获取数据集的行数     * @param filename     * @return     */    private int getRowNumber(String filename){        int count=0;        File file=new File(filename);        BufferedReader br=null;        try {            br=new BufferedReader(new FileReader(file));            while(br.readLine()!=null){                count++;            }            br.close();        } catch (FileNotFoundException e) {            e.printStackTrace();        } catch (IOException e) {            e.printStackTrace();        } finally {            if(br!=null){                try {                    br.close();                } catch (IOException e) {                    e.printStackTrace();                }            }        }        return count;    }    /**     * 获取数据集的列数(特征维度)     * @param filename     * @return     */    private int getColumnNumber(String filename){        int count=0;        File file=new File(filename);        BufferedReader br=null;        try {            br=new BufferedReader(new FileReader(file));            String temp=br.readLine();            if(temp!=null){                count=temp.split(" ").length;            }            br.close();        } catch (FileNotFoundException e) {            e.printStackTrace();        } catch (IOException e) {            e.printStackTrace();        } finally {            if(br!=null){                try {                    br.close();                } catch (IOException e) {                    e.printStackTrace();                }            }        }        return count;    }    /**     * 训练模型并计算theta     */    public void trainTheta(){        int iteration=this.iteration;        while(iteration-- >  0){            //对每一个thetai 求偏导            double[] partial_derivative=compute_partial_derivative();            for(int i=0;i<theta.length;i++){                theta[i]-=alpha*partial_derivative[i];            }        }    }    private double[] compute_partial_derivative() {        double[] partial_derivative=new double[theta.length];        for(int j=0;j<theta.length;j++){            partial_derivative[j]= compute_partial_derivative_for_theta(j);        }        return partial_derivative;    }    private double compute_partial_derivative_for_theta(int j) {        double sum=0.0;        for(int i=0;i<row;i++){            sum+=h_theta_x_i_minus_y_i_times_x_j_i(i,j);        }        return sum/row;    }    private double h_theta_x_i_minus_y_i_times_x_j_i(int i, int j) {        double[] oneRow=getRow(i);//取一行数据,前面是feature最后一个是y        double result=0.0;        for(int k=0;k<(oneRow.length-1);k++){            result+=theta[k]*oneRow[k];        }        //这个地方加入了sigmoid函数        result=sigmoid(result);        //        result-=oneRow[oneRow.length-1];        result*=oneRow[j];        return result;    }    private double[] getRow(int i) {                return trainData[i];    }    /**     * sigmoid函数     * @param x     */    private double sigmoid(double x){        return (double)(1.0/(1+Math.exp(-x)));    }    //这里是主方法    public static void main(String[] args) {        String filename="";        LogRegression log=new LogRegression(filename);        log.trainTheta();        //预测的函数可以在这里写,就是一个根据训练完的参数theta求解的过程        //过程省略,请自行补上,也可以打印出theta来查看    }}