感知机的对偶形式实现java版

来源:互联网 发布:网络部招新面试问题 编辑:程序博客网 时间:2024/05/29 19:12
import java.util.ArrayList;
import java.util.List;


public class Perceptron3 {

    /**
     * Builds the Gram matrix of the training points.
     *
     * <p>Reference: Li Hang, "Statistical Learning Methods" (1st ed., March 2012), section 2.3.3
     * (dual form of the perceptron). Entry {@code [i][j]} is the inner product of point
     * {@code i} and point {@code j} over the two features {@code x1, x2}.
     *
     * <p>Side effect: the matrix is printed to standard output, one row per line, each value
     * followed by a single space.
     *
     * @param list the training points
     * @return a {@code list.size() x list.size()} matrix of pairwise inner products
     */
    public double[][] getGramMatrix(List<Dot> list) {
        int size = list.size();
        double[][] gramMatrix = new double[size][size];
        for (int i = 0; i < size; i++) {
            Dot a = list.get(i);
            for (int j = 0; j < size; j++) {
                Dot b = list.get(j);
                gramMatrix[i][j] = a.x1 * b.x1 + a.x2 * b.x2;
            }
        }
        printMatrix(gramMatrix);
        return gramMatrix;
    }

    /** Prints {@code matrix} to standard output, one row per line, values separated by spaces. */
    private static void printMatrix(double[][] matrix) {
        for (double[] row : matrix) {
            StringBuilder line = new StringBuilder();
            for (double value : row) {
                line.append(value).append(' ');
            }
            System.out.println(line);
        }
    }

    /**
     * Evaluates the dual-form decision value for training point {@code i}:
     * {@code sum_j alpha[j] * y[j] * gramMatrix[j][i] + b}.
     *
     * @param alpha      dual coefficients, one per training point
     * @param y          labels; must have at least {@code alpha.length} entries
     * @param gramMatrix Gram matrix as returned by {@link #getGramMatrix(List)}
     * @param i          index of the training point being evaluated
     * @param b          bias term
     * @return the decision value; its sign is the predicted class of point {@code i}
     */
    public double getGram(double[] alpha, double[] y, double[][] gramMatrix, int i, double b) {
        double sum = 0.0;
        for (int j = 0; j < alpha.length; j++) {
            sum += alpha[j] * y[j] * gramMatrix[j][i];
        }
        return sum + b;
    }

    /**
     * Applies one dual-form update for a misclassified point:
     * {@code alpha[i] += yita} and {@code b += yita * yi}.
     *
     * @param param model state, updated in place; {@code param.yita} is the step size
     * @param yi    label of the misclassified point
     * @param i     index of the misclassified point
     */
    public void adjust(Param3 param, double yi, int i) {
        param.alpha[i] += param.yita;
        param.b += param.yita * yi;
    }

    /**
     * Trains the perceptron in dual form until every training point is classified correctly,
     * then stores the recovered weight vector in {@code param.w}.
     *
     * <p>After each update the scan restarts from the first point. This method does not
     * terminate if the data are not linearly separable; use
     * {@link #train(List, Param3, long)} to bound the number of updates.
     *
     * @param list  training points
     * @param param initial model state, updated in place; {@code param.alpha} must have exactly
     *              one entry per training point
     * @throws IllegalArgumentException if {@code param.alpha} is null or has the wrong length
     */
    public void train(List<Dot> list, Param3 param) {
        // Long.MAX_VALUE updates is effectively unbounded: train until convergence.
        runTraining(list, param, Long.MAX_VALUE);
    }

    /**
     * Same as {@link #train(List, Param3)} but performs at most {@code maxUpdates} parameter
     * updates. {@code param.w} is filled in from the current {@code alpha} either way.
     *
     * @param list       training points
     * @param param      initial model state, updated in place
     * @param maxUpdates maximum number of updates to perform (must be non-negative)
     * @return {@code true} if all points are classified correctly at the end, otherwise
     *         {@code false} (the update budget ran out)
     * @throws IllegalArgumentException if {@code maxUpdates} is negative, or {@code param.alpha}
     *                                  is null or has the wrong length
     */
    public boolean train(List<Dot> list, Param3 param, long maxUpdates) {
        if (maxUpdates < 0) {
            throw new IllegalArgumentException("maxUpdates must be >= 0: " + maxUpdates);
        }
        return runTraining(list, param, maxUpdates);
    }

    /** Shared training loop; returns whether the training set ended up separated. */
    private boolean runTraining(List<Dot> list, Param3 param, long maxUpdates) {
        int n = list.size();
        if (param.alpha == null || param.alpha.length != n) {
            throw new IllegalArgumentException(
                    "param.alpha must have one entry per training point: expected " + n
                            + ", got " + (param.alpha == null ? "null" : String.valueOf(param.alpha.length)));
        }
        double[][] gramMatrix = getGramMatrix(list);
        double[] yi = new double[n];
        for (int i = 0; i < n; i++) {
            yi[i] = list.get(i).getY();
        }

        boolean converged = false;
        for (long updates = 0; ; updates++) {
            int wrong = firstMisclassified(param, yi, gramMatrix);
            if (wrong < 0) {
                converged = true;
                break;
            }
            if (updates >= maxUpdates) {
                break;
            }
            adjust(param, yi[wrong], wrong);
        }
        param.w = computeWeights(list, param.alpha);
        return converged;
    }

    /** Returns the index of the first point with {@code y * f(x) <= 0}, or -1 if none. */
    private int firstMisclassified(Param3 param, double[] yi, double[][] gramMatrix) {
        for (int i = 0; i < yi.length; i++) {
            if (yi[i] * getGram(param.alpha, yi, gramMatrix, i, param.b) <= 0) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Recovers the primal weights {@code w = sum_j alpha[j] * y[j] * x_j}. The vector has two
     * components because {@link Dot} has exactly two features.
     */
    private static double[] computeWeights(List<Dot> list, double[] alpha) {
        double[] w = new double[2];
        for (int j = 0; j < alpha.length; j++) {
            Dot d = list.get(j);
            double coeff = alpha[j] * d.getY();
            // Each component uses its own feature (previously both used x1).
            w[0] += coeff * d.getX1();
            w[1] += coeff * d.getX2();
        }
        return w;
    }

    /** Demo: trains on a three-point toy data set and prints the learned parameters. */
    public static void main(String[] args) {
        Perceptron3 p = new Perceptron3();
        // Two positive points and one negative point.
        List<Dot> list = new ArrayList<Dot>();
        list.add(new Dot(3, 3, 1));
        list.add(new Dot(4, 3, 1));
        list.add(new Dot(1, 1, -1));

        // Step size 1, bias 0, one zero-initialised dual coefficient per point.
        Param3 param = new Param3(1, 0, new double[list.size()]);

        p.train(list, param);
        System.out.println(param.toString());
    }

}





/**
 * A training instance: two feature values {@code x1}, {@code x2} and a label {@code y}.
 *
 * <p>The fields are deliberately package-private because other code in this file reads them
 * directly (e.g. when building the Gram matrix).
 */
class Dot {
    double x1;
    double x2;
    double y;

    /** Creates an instance with the given features and label. */
    public Dot(double x1, double x2, double y) {
        this.x1 = x1;
        this.x2 = x2;
        this.y = y;
    }

    public double getX1() {
        return x1;
    }

    public double getX2() {
        return x2;
    }

    public double getY() {
        return y;
    }

    public void setX1(double x1) {
        this.x1 = x1;
    }

    public void setX2(double x2) {
        this.x2 = x2;
    }

    public void setY(double y) {
        this.y = y;
    }

    /** Renders as {@code Dot [x1=..., x2=..., y=...]}. */
    @Override
    public String toString() {
        return new StringBuilder("Dot [x1=").append(x1)
                .append(", x2=").append(x2)
                .append(", y=").append(y)
                .append(']')
                .toString();
    }
}





import java.util.Arrays;


/**
 * 参数 
 * @author Administrator
 *
 */
public class Param3 {
public double yita;
public double b;
public double[] alpha;
public double[] w;


public Param3(double yita, double b, double[] alpha) {
super();
this.yita = yita;
this.b = b;
this.alpha = alpha;
}


public Param3(double yita, double b, double[] alpha, double[] w) {
super();
this.yita = yita;
this.b = b;
this.alpha = alpha;
this.w = w;
}


public double getYita() {
return yita;
}


public void setYita(double yita) {
this.yita = yita;
}


public double getB() {
return b;
}


public void setB(double b) {
this.b = b;
}


public double[] getAlpha() {
return alpha;
}


public void setAlpha(double[] alpha) {
this.alpha = alpha;
}


public double[] getW() {
return w;
}


public void setW(double[] w) {
this.w = w;
}


@Override
public String toString() {
return "Param3 [yita=" + yita + ", b=" + b + ", alpha="
+ Arrays.toString(alpha) + ", w=" + Arrays.toString(w) + "]";
}


}

输出说明：前三行是 Gram 矩阵（各训练点两两之间的内积），最后一行是训练得到的模型参数。

对示例数据，训练结束时 alpha = [2.0, 0.0, 5.0]，由此得到 w = [1.0, 1.0]，b = -3.0。

0 0