机器学习基石---第二周PLA

来源:互联网 发布:淘宝店铺处罚 编辑:程序博客网 时间:2024/06/08 19:27
knitr::opts_chunk$set(echo = TRUE)

  台大《机器学习基石》第二周课的笔记,只整理部分重要内容。希望能把课上学的,做一个精简的记录。

变量说明

  存在两类数据,标记为y,取值为1,1。特征向量记为xx=(x0,x1,x2,...,xd)。其中x0为常量1,其余为具体特征值。存在超平面wTx=0,其中w=(w0,w1,...,wd),可以正确分开两类数据。共有N个样本数据。

迭代过程

  PLA采取知错就改的策略。遍历所有样本,如果发现分类错误,采用如下方式如下方式更新w
  

Fort=0,1,...N1.findamistakeofwtcalled(xn(t),yn(t))sign(wTtxn(t))yn(t)2.(tryto)correctthemistakebywt+1wt+yn(t)xn(t)...untilnomoremistakesreturnlastw(calledwPLA)asg

更新理由

这里写图片描述
  判断类别的公式:

sign(wTtxn(t))=sign(wTtxn(t)cos(θ))

  如果正类被误判,则cos(θ)<0,即θ(π2,π),所以要缩小法向量和特征向量之间的夹角。故采用上图方法迭代w的值。

证明

  证明线性可分数据集,PLA算法一定能够经过有限次的迭代,得到一个完美的分割超平面。

每一次迭代wt更接近wf

  1. wf为完美分类器
  2. (xn,yn)为错分的样本
  3. (xn(t),yn(t))为第t次迭代时,wt错分的样本
  因为wf是完美分类器,则一定有:

yn(t)wTfxn(t)minnynwTfxn>0

  利用任意一个错判样本(xn(t),yn(t))进行第t+1次迭代之后,计算:

wTfwt+1wTfwt+1=wTf(wt+yn(t)xn(t))wTfwt+1=wTfwt+yn(t)wTfxn(t)wTfwt+1wTfwt+minnyn(t)wTfxn(t)wTfwt+1>wTfwt+0wTfwt+1=wTfwtwTfwt+1

  从余弦相似度的角度看,通过错判样本对wt的修正,使得迭代后的w更接近于完美的分割超平面。

每一次迭代wt的模增长较小

wt+12=wt+yn(t)xn(t)2=wt2+2yn(t)wTtxn(t)+yn(t)xn(t)2wt2+0+yn(t)xn(t)2wt2+maxnynxn2

迭代次数有限

  假设w0=0,经过T次迭代之后:

wTfwTwfwT=wTf(wT1+yn(T1)xn(T1))wfwT=wTf(wT1+yn(T1)xn(T1))wfwT=wTfwT1+yn(T1)wTfxn(T1)wfwTwTfwT1+minnynwTfxnwfwTwTfwT2+yn(T2)wTfxn(T2)+minnynwTfxnwfwTwTfwT2+2minnynwTfxnwfwTTminnynwTfxnwfwTFurther:wTfwTTminnynwTfxnTwTfwTminnynwTfxnT2(wTfwT)2(minnynwTfxn)2=wf2wT2sin2(θ)(minnynwTfxn)2wf2wT2(minnynwTfxn)2wf2maxynxnn2(minnynwTfxn)2=wf2maxxnn2(minnynwTfxn)2

  所以迭代次数T有上界。

案例

构造数据集

  构造数据集,验证算法。

x11 <- 1:10x21 <- x11 + runif(10, 0, 1) + 3x22 <- x11 - runif(10, 0, 1)example_data <- data.frame(x1 = rep(x11, 2),x2 = c(x21, x22),label = rep(c(1, -1), each = 10))example_data$label <- as.factor(example_data$label)library(ggplot2)ggplot(data = example_data, aes(x = x1,y = x2,color = label,shape = label)) +geom_point()

这里写图片描述

PLA算法

## 参数:数据集、标签名称PLA_f <- function(dataset, label) {  ## 样本数  row_num <-  nrow(dataset)  w <- rep(1, ncol(dataset))  w0 <- matrix(w, 1, 3, byrow = T)  real_label <- as.numeric(as.vector(dataset[, label]))  feature_matrix <-    as.matrix(data.frame(x0 = rep(1, row_num), cbind(dataset[, setdiff(colnames(dataset), label)])))  i <- 1  j <- 0  while (i < row_num & j == 0) {    i <- 1    j <- 0    for (i in 1:row_num) {      ## 判断是否有误判      if (as.vector(feature_matrix[i,] %*% t(w0)) * real_label[i] <= 0) {        ## 存在误判,修正w0        w0 <- w0 + real_label[i] * feature_matrix[i,]        w <- c(w, w0)        j <- 1      }      if(j == 1){        j <- 0        i <- row_num-1        break()}    }  }  w_data <- data.frame(matrix(w,ncol=ncol(dataset),byrow = TRUE))  colnames(w_data) <- paste0("x",0:(ncol(feature_matrix)-1))  w_data <- dplyr::mutate(w_data,                          slope = -x1 / x2,                          intercept = -x0 / x2)  return(w_data)}

求解

w_data <- PLA_f(dataset = example_data, label = "label")w_data
   x0 x1           x2        slope    intercept1   1  1  1.000000000   -1.0000000   -1.00000002   0  0  0.495471116    0.0000000    0.00000003  -1 -1 -0.009057768 -110.4024725 -110.40247254   0  0  4.912654036    0.0000000    0.00000005  -1 -1  4.408125152    0.2268538    0.22685386  -2 -2  3.903596268    0.5123481    0.51234817  -3 -4  1.915120282    2.0886417    1.56648128  -2 -1  8.363856425    0.1195621    0.23912419  -3 -2  7.859327541    0.2544747    0.381712010 -4 -4  5.870851555    0.6813322    0.681332211 -5 -9  1.747566727    5.1500179    2.861121112 -4 -8  6.669278532    1.1995300    0.5997650

动图

library(animation)## 指定ImageMagic目录位置,注意是magick.exe,之前版本貌似一致是convert.exeani.options(convert = "D:/ImageMagic/ImageMagick-7.0.7-Q16/magick.exe")saveGIF(  expr = {    library(ggplot2)    for (i in 1:nrow(w_data)) {plot(      x = example_data$x1[1:10],      y = example_data$x2[1:10],      pch = 15,      col = "red",      xlim = c(0, 20),      ylim = c(0, 15),      xlab = "x1",      ylab = "x2",main = paste0("Picture",i)    )      lines(x = example_data$x1[11:20],            y = example_data$x2[11:20],            type = "p",            pch = 17,            col = "blue")      abline(coef=c(w_data$intercept[i],w_data$slope[i]),lwd=2)      }  },  ## GIF文件名,注意文件后缀名要加上  movie.name = "PLA.gif",  ## 时间间隔  interval = 1,  ## 图形设置  ani.width = 600,  ani.height = 600,  ## 文件输出在当前目录  outdir = getwd())

这里写图片描述

Ref

[1]课程PPT

2017-12-19于杭州

原创粉丝点击