Logistic Regression (LR)


A Java implementation of the logistic regression algorithm: it loads a comma-separated data set, splits it 80/20 into training and test sets with stratified sampling, trains the weights by batch gradient descent on the cross-entropy loss, and reports confusion-matrix-based metrics on the test set.
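
Concretely, the code below fits the standard binary logistic model. Each sample's decision value is the sigmoid of the weighted attribute sum plus a bias (stored as the last weight), the loss is the averaged cross-entropy, and each epoch applies batch updates with learning rate alpha:

    h(x)  = 1 / (1 + exp(-(w · x + b)))                                   (decision value)
    loss  = -(1/m) * Σ_i [ y_i * log h(x_i) + (1 - y_i) * log(1 - h(x_i)) ]
    w_j  += alpha * Σ_i (y_i - h(x_i)) * x_{i,j}                          (weight update)
    b    += alpha * Σ_i (y_i - h(x_i))                                    (bias update)

Here m is the training-set size, y_i is the 0/1 label, and the sums run over the whole training set (batch updates, at most 100 epochs in train()).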


package LogisticRegression;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.Random;


public class LR {

    // A single sample: its attribute values, its label, and the decision value produced by LR.
    public class Sample {
        private double[] attributes;   // attribute values
        private int label;             // class label (0 or 1)
        private double decisionValue;  // decision value computed by LR

        Sample(double[] attributes, int label) {
            this.attributes = attributes.clone(); // copy, so the caller's buffer can be reused and any feature count works
            this.label = label;
            this.decisionValue = label;
        }

        public double[] getAttributes() {  // attribute values of the sample
            return this.attributes;
        }

        public int getLabel() {            // label of the sample
            return this.label;
        }

        public double getDecisionValue() { // decision value of the sample
            return this.decisionValue;
        }

        public void setDecisionValue(double value) { // set the decision value
            this.decisionValue = value;
        }
    }

    // Load the data set from the given path; each line holds num_features comma-separated
    // attribute values followed by an integer class label.
    public void load_datas(String path, ArrayList<Sample> sampleSet, int num_features) {
        String line;                                     // one line read from the data file
        String[] s;                                      // fields of the line after splitting
        double[] attributes = new double[num_features];  // attribute values of one sample
        FileReader fr = null;
        BufferedReader bufr = null;
        try {
            fr = new FileReader(path);
            bufr = new BufferedReader(fr);
            while ((line = bufr.readLine()) != null) {
                s = line.split(",");
                for (int i = 0; i < num_features; i++) { // parse the attribute values as doubles
                    attributes[i] = Double.parseDouble(s[i]);
                }
                sampleSet.add(new Sample(attributes, Integer.parseInt(s[s.length - 1].trim())));
            }
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            try {
                if (bufr != null)
                    bufr.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    // Split the original data set into a training set and a test set (stratified 80/20 split).
    public void divid_dataSet(ArrayList<Sample> sampleSet, ArrayList<Sample> trainSet, ArrayList<Sample> testSet) {
        ArrayList<Sample> posSet = new ArrayList<Sample>(); // positive samples
        ArrayList<Sample> negSet = new ArrayList<Sample>(); // negative samples
        int pos = 0;                 // number of positive samples
        int neg = 0;                 // number of negative samples
        int train_instances = 0;     // number of training samples
        int train_instances_pos = 0; // number of positive samples in the training set
        int i = 0;

        Collections.shuffle(sampleSet); // shuffle the original data set into a random order
        Iterator<Sample> it_sampleSet = sampleSet.iterator();
        while (it_sampleSet.hasNext()) { // count the positive and negative samples
            Sample s = it_sampleSet.next();
            if (s.getLabel() == 1) {
                posSet.add(s); // samples labelled 1 go to the positive set
                pos++;
            } else {
                negSet.add(s); // samples labelled 0 go to the negative set
                neg++;
            }
        }

        train_instances = (int) (sampleSet.size() * 0.8); // use 80% of the data for training
        // stratified sampling: keep the positive/negative ratio of the data in the training set
        train_instances_pos = (int) (train_instances * ((pos * 1.0) / (pos + neg)));

        i = 0;
        Iterator<Sample> it_posSet = posSet.iterator();
        while (it_posSet.hasNext()) { // distribute the positive samples
            Sample s = it_posSet.next();
            if (i < train_instances_pos) { // first part goes to the training set
                trainSet.add(s);
                i++;
            } else {                       // the rest goes to the test set
                testSet.add(s);
            }
        }

        i = 0;
        Iterator<Sample> it_negSet = negSet.iterator();
        while (it_negSet.hasNext()) { // distribute the negative samples
            Sample s = it_negSet.next();
            if (i < train_instances - train_instances_pos) { // first part goes to the training set
                trainSet.add(s);
                i++;
            } else {                                          // the rest goes to the test set
                testSet.add(s);
            }
        }
    }

    // Initialize the weights with random values.
    public void init_weights(double[] weights) {
        Random r = new Random();
        for (int i = 0; i < weights.length; i++) {
            weights[i] = r.nextInt(10) / 10.0; // random multiple of 0.1 in [0, 0.9]
        }
    }

    // Weighted sum of a sample's attributes plus the bias (stored as the last weight).
    public double productSum(Sample s, double[] weights) {
        int i;
        double productSum = 0;
        for (i = 0; i < s.getAttributes().length; i++) {
            productSum += s.getAttributes()[i] * weights[i];
        }
        productSum += weights[i]; // add the bias term
        return productSum;
    }


    // Train the weights by batch gradient descent on the cross-entropy loss.
    public void train(double[] weights, ArrayList<Sample> trainSet, double threshold, double alpha) {
        int num = 100, i, j; // num: maximum number of epochs
        Sample s;
        double productSum = 0, decisionValue, loss, sum, y_sum = 0; // loss: cross-entropy loss
        for (i = 0; i < num; i++) {
            loss = 0;
            Iterator<Sample> it_trainSet = trainSet.iterator();
            while (it_trainSet.hasNext()) { // compute the loss and each sample's decision value
                s = it_trainSet.next();
                productSum = productSum(s, weights);                 // weighted sum plus bias
                decisionValue = 1.0 / (1.0 + Math.exp(-productSum)); // sigmoid gives the decision value
                s.setDecisionValue(decisionValue);
                loss += (s.getLabel() * Math.log(decisionValue)
                        + (1 - s.getLabel()) * Math.log(1 - decisionValue)); // log-likelihood term
            }
            loss = (-1.0 / trainSet.size()) * loss;
            if (loss < 1E-8) {
                break;
            } else { // update the weights
                for (j = 0; j < weights.length - 1; j++) { // one pass over the training set per weight
                    sum = 0;
                    y_sum = 0;
                    it_trainSet = trainSet.iterator();
                    while (it_trainSet.hasNext()) {
                        s = it_trainSet.next();
                        // accumulated (label - decision value), used below for the bias update
                        y_sum += (s.getLabel() - s.getDecisionValue());
                        sum += ((s.getLabel() - s.getDecisionValue()) * s.getAttributes()[j]);
                    }
                    weights[j] += (alpha * sum); // update weight j
                }
                weights[j] += (alpha * y_sum); // update the bias (last weight)
            }
        }
    }

    // Classify the test set and return the confusion-matrix counts {TP, FP, TN, FN}.
    public double[] test(double[] weights, ArrayList<Sample> testSet, double threshold, double alpha) throws IOException {
        double tp = 0, fp = 0, tn = 0, fn = 0; // confusion-matrix counts
        double[] perform = new double[4];      // returned confusion-matrix counts
        int flag;
        Sample s;
        double productSum = 0, decisionValue;
        // write each sample's true label and decision value to a file, for plotting a ROC curve in MATLAB
        BufferedWriter bufw = new BufferedWriter(new FileWriter("./data/ROC.txt"));
        Iterator<Sample> it_testSet = testSet.iterator();
        while (it_testSet.hasNext()) {
            s = it_testSet.next();
            productSum = productSum(s, weights);
            decisionValue = 1.0 / (1.0 + Math.exp(-productSum));
            if (decisionValue > threshold) {
                flag = 1;
            } else {
                flag = 0;
            }
            s.setDecisionValue(decisionValue);
            bufw.write(s.getLabel() + "  " + decisionValue);
            bufw.newLine();
            if (s.getLabel() == 1) {
                if (flag == 1)
                    tp++;
                else
                    fn++;
            }
            if (s.getLabel() == 0) {
                if (flag == 1)
                    fp++;
                else
                    tn++;
            }
        }
        bufw.close();
        perform[0] = tp;
        perform[1] = fp;
        perform[2] = tn;
        perform[3] = fn;
        return perform;
    }
    // Print the performance metrics derived from the confusion matrix.
    public void evaluation(double tp, double fp, double tn, double fn) {
        double accuracy = (tp + tn) / (tp + tn + fp + fn);
        double specificity = tn / (fp + tn);
        double precision = tp / (tp + fp);
        double recall = tp / (tp + fn);
        double sensitivity = tp / (tp + fn);
        double mcc = (tp * tn - fp * fn)
                / Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        double f1 = (2 * precision * recall) / (precision + recall);
        System.out.println("*********************************Result*****************************************");
        System.out.println("TP=" + tp + ",FP=" + fp + ",TN=" + tn + ",FN=" + fn);
        System.out.println("Accuracy=" + accuracy + ",Specificity=" + specificity);
        System.out.println("Precision=" + precision + ",Recall=" + recall);
        System.out.println("Sensitivity=" + sensitivity + ",MCC=" + mcc);
        System.out.println("F1-measure=" + f1);
        System.out.println("********************************************************************************");
    }


    // Print the learned weights (the last entry is the bias).
    public void showWeights(double[] weights) {
        System.out.println();
        System.out.println("*******************************Weights***************************************");
        for (int i = 0; i < weights.length; i++) {
            System.out.print(weights[i] + "  ");
        }
        System.out.println();
        System.out.println("*************************************************************************");
    }
    public static void main(String[] args) throws IOException {
        LR lr = new LR();
        int num_features = 4; // number of features per sample
        ArrayList<Sample> sampleSet = new ArrayList<Sample>(); // original data set
        ArrayList<Sample> trainSet = new ArrayList<Sample>();  // training set
        ArrayList<Sample> testSet = new ArrayList<Sample>();   // test set
        double[] weights = new double[num_features + 1];       // weights plus bias
        double threshold = 0.5, alpha = 0.2;                   // decision threshold and learning rate

        lr.load_datas("./data/dataSet.txt", sampleSet, num_features); // load the data set
        lr.divid_dataSet(sampleSet, trainSet, testSet);               // split into training and test sets
        lr.init_weights(weights);                                     // initialize the weights
        lr.train(weights, trainSet, threshold, alpha);
        double[] perform = lr.test(weights, testSet, threshold, alpha);
        lr.evaluation(perform[0], perform[1], perform[2], perform[3]);
        lr.showWeights(weights);
    }
}
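
load_datas expects each line of ./data/dataSet.txt to hold num_features comma-separated attribute values followed by an integer 0/1 label. The post does not ship a data file, so the following is only a minimal sketch of how a compatible file could be produced; the MakeDataSet class name and the synthetic Gaussian values are assumptions made purely so the demo can run end to end (the ./data directory must already exist, as the original code also assumes when it writes ./data/ROC.txt):

package LogisticRegression;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Random;

// Hypothetical helper: writes a synthetic data set in the format load_datas expects
// (4 comma-separated doubles followed by a 0/1 label, one sample per line).
public class MakeDataSet {
    public static void main(String[] args) throws IOException {
        Random r = new Random(42);
        BufferedWriter w = new BufferedWriter(new FileWriter("./data/dataSet.txt"));
        for (int i = 0; i < 200; i++) {
            int label = r.nextBoolean() ? 1 : 0;
            double shift = (label == 1) ? 2.0 : 0.0; // shift the positive class so it is separable
            StringBuilder line = new StringBuilder();
            for (int j = 0; j < 4; j++) {
                line.append(r.nextGaussian() + shift).append(",");
            }
            line.append(label);
            w.write(line.toString());
            w.newLine();
        }
        w.close();
    }
}

Running MakeDataSet once and then LR.main should print the confusion-matrix metrics and the learned weights described above.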