Machine Learning 1: KNN

Source: Internet | Editor: 程序博客网 (Programming Blog Network) | Time: 2024/06/03 19:52

Learning the KNN algorithm
MATLAB code 1

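%% KNN
clear all
clc
%% data
trainData = [1.0,2.0;1.2,0.1;0.1,1.4;0.3,3.5];   % one training sample per row
trainClass = [1,1,2,2];                           % class of each training row
testData = [0.5,2.3];
k = 3;
%% distance (Euclidean distance from the test point to each training point)
row = size(trainData,1);
col = size(trainData,2);
test = repmat(testData,row,1);
dis = zeros(1,row);
for i = 1:row
    sqSum = 0;
    for j = 1:col
        sqSum = sqSum + (test(i,j) - trainData(i,j)).^2;
    end
    dis(1,i) = sqSum.^0.5;
end
%% sort (row 1: distance, row 2: class, ordered by increasing distance)
jointDis = [dis;trainClass];
sortDis = sortrows(jointDis');
sortDisClass = sortDis';
%% find (majority class among the k nearest neighbors)
kClass = sortDisClass(2,1:k);     % classes of the k nearest points
member = unique(kClass);
num = numel(member);
maxCount = 0;
for i = 1:num
    count = sum(kClass == member(i));   % votes for this class
    if count > maxCount
        maxCount = count;
        label = member(i);
    end
end
disp('Final classification result:');
fprintf('%d\n',label)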

After running it, the final classification result is 2, which matches the expected result.
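By hand, the distances from (0.5, 2.3) are sqrt(0.34) ≈ 0.583 (class 1), sqrt(5.33) ≈ 2.309 (class 1), sqrt(0.97) ≈ 0.985 (class 2) and sqrt(1.48) ≈ 1.217 (class 2), so the three nearest points vote 1, 2, 2 and the label is 2. A vectorized sketch of code 1 on the same data reproduces this. It is not the original author's code: it uses bsxfun in place of the double loop and swaps the counting loop for the built-in mode, which breaks ties toward the smaller label.

```matlab
% Vectorized sketch of MATLAB code 1 on the same toy data (illustration only).
trainData  = [1.0,2.0; 1.2,0.1; 0.1,1.4; 0.3,3.5];   % one sample per row
trainClass = [1,1,2,2];
testData   = [0.5,2.3];
k = 3;

% Euclidean distance from testData to every training row, as a 1-by-4 row vector
dis = sqrt(sum(bsxfun(@minus, trainData, testData).^2, 2))';

% indices of the training points ordered by increasing distance
[~, order] = sort(dis);

% classes of the k nearest points, then a majority vote with mode
kClasses = trainClass(order(1:k));
label = mode(kClasses);
fprintf('distances: %s\n', mat2str(dis, 4));
fprintf('k nearest classes: %s -> label %d\n', mat2str(kClasses), label);
```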

MATLAB code 2 (three functions; in MATLAB each one is saved in its own .m file named after the function):

function y = knn(X, X_train, y_train, K)
%KNN k-Nearest Neighbors Algorithm.
%
%   INPUT:  X:         testing sample features, P-by-N_test matrix.
%           X_train:   training sample features, P-by-N matrix.
%           y_train:   training sample labels, 1-by-N row vector of
%                      non-negative integers (required by recog).
%           K:         the k in k-Nearest Neighbors
%
%   OUTPUT: y    : predicted labels, 1-by-N_test row vector.
%
% Author: Ren Kan

[~,N_test] = size(X);

predicted_label = zeros(1,N_test);
for i=1:N_test
    % calculate the K nearest neighbors (the distances are not needed here).
    [~, neighbors] = top_K_neighbors(X_train,y_train,X(:,i),K);
    % recognize the label of the test vector by majority vote.
    predicted_label(i) = recog(y_train(neighbors),max(y_train));
end

y = predicted_label;

end
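Unlike code 1, knn expects one sample per column (P-by-N), and recog assumes non-negative integer labels. A minimal self-contained sketch of that convention, assuming code 1's training data transposed into this layout plus a second, hypothetical test point (1.1, 0.3), inlines the same three steps: squared Euclidean distance, sort, and a counting vote indexed by label + 1.

```matlab
% Sketch of knn's column-per-sample convention with the helper steps inlined.
% Training data: code 1's data transposed to P-by-N. Second test column: hypothetical.
X_train = [1.0, 1.2, 0.1, 0.3;    % feature 1 of each training sample
           2.0, 0.1, 1.4, 3.5];   % feature 2 of each training sample
y_train = [1, 1, 2, 2];           % non-negative integer labels, 1-by-N
X       = [0.5, 1.1;              % two test samples, one per column
           2.3, 0.3];
K = 3;

N_test = size(X, 2);
y = zeros(1, N_test);
for i = 1:N_test
    % squared Euclidean distance from test column i to every training column
    d = sum(bsxfun(@minus, X_train, X(:, i)).^2, 1);
    [~, idx] = sort(d);
    % counting vote over the K nearest labels, indexed by label+1 as in recog
    counts = zeros(1, max(y_train) + 1);
    for lbl = y_train(idx(1:K))
        counts(lbl + 1) = counts(lbl + 1) + 1;
    end
    [~, best] = max(counts);
    y(i) = best - 1;
end
disp(y)
```

For the hypothetical point the three nearest training columns have labels 1, 2, 1, so it is assigned label 1, while the first column reproduces code 1's answer of 2.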

The part of the code that finds the K nearest neighbors:

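function [dists,neighbors] = top_K_neighbors( X_train,y_train,X_test,K )
% Author: Ren Kan
%   Input:
%   X_test the test vector with P*1
%   X_train and y_train are the train data set
%   K is the K neighbor parameter
%   Output:
%   dists      squared Euclidean distances of the K nearest points
%   neighbors  column indices of the K nearest points in X_train
[~, N_train] = size(X_train);
test_mat = repmat(X_test,1,N_train);

% Convert both operands to double so integer data (e.g. uint8 images) does not saturate.
dist_mat = (double(X_train)-double(test_mat)) .^2;
% Squared Euclidean distance; summing along dimension 1 keeps one value per sample even when P = 1.
dist_array = sum(dist_mat,1);
[dists, neighbors] = sort(dist_array);
% The neighbors are the indices of the top K nearest points.
dists = dists(1:K);
neighbors = neighbors(1:K);

end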
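top_K_neighbors ranks points by squared Euclidean distance. Because the square root is monotonically increasing, skipping it leaves the neighbor order unchanged. The summation dimension matters as well: with a single feature (P = 1), a sum without a dimension argument adds the whole 1-by-N row into one scalar. A small check on code 1's data, with a hypothetical one-feature training set for the second part:

```matlab
X_train = [1.0, 1.2, 0.1, 0.3;
           2.0, 0.1, 1.4, 3.5];
x = [0.5; 2.3];

% squared vs. true Euclidean distances give the same ordering
sq = sum(bsxfun(@minus, X_train, x).^2, 1);   % 1-by-4 squared distances
[~, orderSq]  = sort(sq);
[~, orderEuc] = sort(sqrt(sq));
fprintf('same neighbor order: %d\n', isequal(orderSq, orderEuc));

% one-feature case (P = 1): dimension 1 keeps one distance per training sample
X1 = [0.2, 0.9, 0.4];                         % hypothetical 1-by-3 training set
d1 = sum(bsxfun(@minus, X1, 0.5).^2, 1);      % stays 1-by-3
d0 = sum(bsxfun(@minus, X1, 0.5).^2);         % collapses to a scalar
fprintf('size with dim 1: %s, without: %s\n', mat2str(size(d1)), mat2str(size(d0)));
```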

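The part of the code that predicts the label of a test sample from its neighbors' labels. The relative frequency of each label among the K neighbors serves as its estimated probability, and the most frequent label wins: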

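function result = recog( K_labels,class_num )
%RECOG Majority vote over the labels of the K nearest neighbors.
%   K_labels:  1-by-K labels of the neighbors (non-negative integers).
%   class_num: largest possible label value.
%   Ties go to the smallest label, because max returns the first maximum.
%   Author: Ren Kan
[~,K] = size(K_labels);
class_count = zeros(1,class_num+1);
for i=1:K
    class_index = K_labels(i)+1; % +1 is to avoid the 0 index reference.
    class_count(class_index) = class_count(class_index) + 1;
end
[~,result] = max(class_count);
result = result - 1; % Do not forget -1 !!!
end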
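The loop in recog is a counting vote. Assuming, as recog does, non-negative integer labels, the same count can be sketched with the built-in accumarray. The neighbor labels below are hypothetical and include a tie between labels 0 and 3 to show the tie-breaking rule.

```matlab
K_labels  = [3, 0, 3, 1, 0];   % hypothetical labels of K = 5 neighbors
class_num = 3;                 % largest label value, as max(y_train) in knn

% count each label; index label+1 because MATLAB indices start at 1
counts = accumarray(K_labels(:) + 1, 1, [class_num + 1, 1])';
[~, best] = max(counts);       % first maximum, so the smallest label wins a tie
result = best - 1;
probs = counts / numel(K_labels);   % vote fractions used as class probabilities
fprintf('counts: %s, predicted label: %d\n', mat2str(counts), result);
```

Here labels 0 and 3 each receive two votes, and the first-maximum rule of max selects 0, exactly as recog would.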

Applications include:
handwritten character recognition, digit CAPTCHA recognition, and similar tasks.