非参方法-K NearestNeighbor(KNN)

来源:互联网 发布:ps -ef|grep java 编辑:程序博客网 时间:2024/05/22 07:45

非参方法-K NearestNeighbor(KNN)

KNN方法作为一种无参方法,使用起来十分简单,方便。更为重要的是它往往能够得到很好的效果。它既可以应用到分类中,也可以应用到回归中,是一种十分重要的方法。

问题

给定一组训练数据(X1,y1),(X2,y2),...(XN,yN), 同时又给定了预测样本Xt, 求取相对应的yt

问题分析

如果我们定义一种判断距离远近的函数,那么能够找到给定训练数据中的Xi(i=1,2,...N)距离Xt最近的一些点,也就是说找到Xt的“邻居”, 那么这些“邻居”所对应的y值应该与yt相差不大。

参考解决方案:

(1) 定义一种距离函数,求出所有训练数据输入值,即Xi(i=1,2,...N)到预测样本Xt的距离值。
(2) 找出这些距离值中最小的K个值,对应于X1,X2,...,XK).
(3) 若所求问题为回归问题(即训练数据的y值为连续的),则

yt=1Ki=1Kyi

若所求问题为分类问题(即y值是离散的, 且为M分类问题),则yt为K个“邻居”中含有数量最多的那个类所对应的值。
matlab 代码:

%************************************************************%               KNN for regression or classification%*************************************************************% the specified parameters are as follows:%          X: the input of train datas, it should be n*m Matrix(n is the%             nums of the data, while m is the dimensionalities)%          y: the output of train datas, t should be n*1 Matrix(n is the%             nums of the data)%          k: top k nearest neighbors%  predict_x: test sample, t should be m*1 Matrix(m is the%             dimensionalities)% regression: 1 denotes regression, 0 denotes classification%% Author: Bai Junyang%  Email: bjyhappy123@gmail.com%************************************************************function result = KNN(X, y, k, predict_x, regression)[n, m] = size(X);%compute the vector of the distancpredict_X = repmat(predict_x', n, 1);%size(X)%size(predict_X)distance = sum((X - predict_X).^2, 2);%find the top-K index:topIndextopIndex = zeros(k, 1);sort_distance = sort(distance);for i = 1:k        topIndex(i) = find(distance == sort_distance(i));end;%compute the resultresult = mean(y(topIndex)); if regression == 0    if result > 0.5        result = 1;    else        result = 0;    end;end;%plot the pointindex = 1:n;index(topIndex) = [];if regression == 1    %plot the predict point    plot(predict_x, result, 'ro', 'MarkerSize', 10);    hold on;    %plot the training data except the top-k data    for i = index        plot(X(i,:), y(i), 'ko', 'MarkerSize', 5);    end;    %plot the top-k data    for i = 1:k        plot(X(topIndex(i),:), y(topIndex(i),:), 'bo', 'MarkerSize', 10);    end;else    %plot the predict point    if result == 0        plot(predict_x(1), predict_x(2), 'ro', 'MarkerSize', 10);        hold on;    else        plot(predict_x(1), predict_x(2), 'r+', 'MarkerSize', 10);        hold on;    end;    %plot the training data except the top-k data    for i = index        if y(i) == 0            plot(X(i, 1), X(i, 2), 'yo', 'MarkerSize', 5);        else            plot(X(i, 1), X(i, 2), 'k+', 'MarkerSize', 5);        end;    end;    %plot the top-k data    for i = 1:k        if y(i) == 0            plot(X(topIndex(i),1), X(topIndex(i),2), 'bo', 'MarkerSize', 10);        else            plot(X(topIndex(i),1), X(topIndex(i),2), 'b+', 'MarkerSize', 10);        end;    end;    hold off;end

绘出图形:
这里写图片描述
红色就代表预测点的值,蓝色代表K个邻居,这里K = 5。

与线性回归的对比

测试所用的数据共97组,其中25组用于测试
代码如下:

function [knnError, lrError] = test(k)data = load('D://ex1data1.txt');X_train = data(1:72, 1);y_train = data(1:72, 2);X_test = data(73:97, 1);y_test = data(73:97, 2);%compute the Linear Regression ErrorlrX_train = [ones(72, 1), X_train];w = pinv(lrX_train)*y_train;lrX_test = [ones(25, 1), X_test];lrError = sum((lrX_test * w - y_test).^2);%compute the KNN ErrorknnError = 0;for i = 1:25    knnError = knnError + (y_test(i) - KNN(X_train, y_train, k, X_test(i), 1))^2;end;

若用平方根误差衡量两种方法,则可以得到下表:

KNN中K的值 KNN的误差值 Linear Regression的误差值 1 30.626 14.219 5 16.134 14.219 10 14.079 14.219

从表格可以看出,若不考虑计算量的大小,KNN可以得到与Linear Regression一样好的效果

KNN的评价

优点:
1.KNN算法思路十分简单,容易理解。
2.KNN算法没有训练的过程,不必求解相关参数。
3.在一般情况下,KNN均能取得不错的预测效果
缺点:
1.虽然不用求解参数,但每次预测均需要较大的计算量,若对于样本数量及其庞大,且对预测时间有较高要求的实际问题中,往往不能适用。
2.同时,K的选择也是其中一个问题,K的值过大,很容易导致计算量成倍地增加,但对于误差的减小贡献有限。例如测试例子中,若将k取20,误差也有13.6699,仅比k = 5时降低了2.5左右,但计算量的增加确很大。
3.KNN预测结果十分依赖于样本数据,若样本数据数据与待预测数据相距较远。例如样本数据的X值大部分位于1附近,但预测点的值在100附近,这样的预测结果准确率会大打折扣。
4.在分类问题中,KNN采用“硬划分”的方法,即对于一个2分类问题,其预测结果不是0便是1。不像逻辑回归(Logistic Regression)可以得到预测结果是1或是0的概率,甚至可以设置不同的概率阈值来得到相关的结果。

0 0
原创粉丝点击