USING SAS TO FIND THE BEST K FOR K-NEAREST-NEIGHBOR CLASSIFICATION

来源:互联网 发布:个人记工资软件 编辑:程序博客网 时间:2024/06/06 06:58
/***************************************************************************
 * Find the best k for k-nearest-neighbor (k-NN) classification.
 *
 *   (1) Explore nearest neighbors of SASHELP.IRIS with PROC MODECLUS
 *   (2) Partition data into training / validation sets     (%partition)
 *   (3) Build a user-defined k-NN function            (%knn_macro, knn())
 *   (4) Apply knn() to the IRIS and CARS data                (%errorchk)
 *   (5) Visualize the k-NN classification result for CARS
 *   (6) Classify CARS with a generalized logit model for comparison
 *   (7) Loop over k to find the best k                          (%findk)
 ***************************************************************************/

/****** (1) USE NONPARAMETRIC CLUSTERING (PROC MODECLUS) TO FIND NEAREST NEIGHBORS ******/
/* K = 4 neighbors; the NEIGHBOR option produces the neighbor table, which */
/* is captured as _test2. The OUT= data set supplies DENSITY and CLUSTER.  */
proc modeclus data = sashelp.iris m = 1 k = 4 out = _test1 neighbor;
   var petallength petalwidth sepallength sepalwidth;
   ods output neighbor = _test2;
run;

ods html style = harvest image_dpi = 400;

/* Estimated density by species, each point labelled with its cluster. */
proc sgplot data = _test1;
   scatter y = density x = species / datalabel = cluster;
run;

/* Carry the last non-missing ID forward so that every neighbor row is   */
/* tagged with the observation it belongs to (assumes ID is filled only  */
/* on the first neighbor row of each observation).                       */
data _test3;
   set _test2;
   retain _tmpid;
   if missing(id) = 0 then _tmpid = id;
   else id = _tmpid;
run;

/* Number the neighbors within each ID: _test4 keeps every row, _test5 */
/* keeps only the last row per ID (i.e. the neighbor count).           */
data _test4 _test5;
   set _test3;
   by id notsorted;
   if first.id then neighbor = 0;
   neighbor + 1;
   output _test4;
   if last.id then output _test5;
run;

ods graphics / width = 6in height = 1in;

/* Distance to each neighbor, stacked by neighbor number. */
proc sgplot data = _test4;
   vbar id / response = distance group = neighbor;
   xaxis display = none grid;
run;

/* Number of neighbors found for each observation. */
proc sgplot data = _test5;
   series x = id y = neighbor;
   xaxis display = none grid;
   yaxis values = (1 to 6) label = 'No. of neighbors';
run;

/****** (2) PARTITION RAW DATASET INTO TRAINING AND VALIDATION DATASETS ******/
%macro partition(data = , target = , smpratio = ,
                 seed = , train = , validate = );
/*************************************************************
 * MACRO:  partition()
 * GOAL:   split a data set into training and validation sets by
 *         stratified sampling on the target variable, so both sets
 *         keep the original proportions of the target levels
 * PARAMETERS:
 *   data     = raw data set
 *   target   = target (class) variable, used as the stratum
 *   smpratio = sampling rate given to PROC SURVEYSELECT; the sampled
 *              rows go to the VALIDATE set, the rest to TRAIN
 *   seed     = random seed for sampling
 *   train    = output training data set
 *   validate = output validation data set
 * OUTPUT: TRAIN and VALIDATE keep all numeric variables plus target.
 * NOTE:   deletes every data set whose name starts with an underscore
 *         from the default (WORK) library when it finishes.
 *************************************************************/
   ods select none;

   /* Collect the names of all numeric variables into num_var. */
   ods output variables = _varlist;
   proc contents data = &data;
   run;

   proc sql noprint;
      select variable into :num_var separated by ' '
         from _varlist
         where lowcase(type) = 'num';
   quit;

   /* SURVEYSELECT needs the input sorted by the STRATA variable. */
   proc sort data = &data out = _tmp1;
      by &target;
   run;

   /* Proportional allocation samples each target level at the same rate. */
   proc surveyselect data = _tmp1 samprate = &smpratio
                     out = _tmp2 seed = &seed outall;
      strata &target / alloc = prop;
   run;

   /* OUTALL keeps every row; SELECTED = 0 marks the rows not sampled. */
   data &train &validate;
      set _tmp2;
      keep &num_var &target;
      if selected = 0 then output &train;
      else output &validate;
   run;

   proc datasets nolist;
      delete _:;
   quit;

   ods select all;
%mend partition;

%partition(data = sashelp.iris, target = species, smpratio = 0.5,
           seed = 20110901, train = iris_train, validate = iris_validate);
%partition(data = sashelp.cars, target = origin, smpratio = 0.5,
           seed = 20110901, train = cars_train, validate = cars_validate);

/****** (3) BUILD A USER-DEFINED FUNCTION TO IMPLEMENT K-NN ******/
options mstored sasmstore = sasuser;

/*************************************************************
 * MACRO:  knn_macro  (stored compiled macro, run through RUN_MACRO)
 * GOAL:   nonparametric k-NN discriminant analysis (PROC DISCRIM,
 *         METHOD = NPAR) and the overall misclassification rate on
 *         the validation set
 * INPUT MACRO VARIABLES (supplied by the FCMP function knn()):
 *   k, train, validate, target, input
 * OUTPUT MACRO VARIABLE:
 *   misc = number of validation rows whose predicted class differs
 *          from the target, divided by the number of validation rows;
 *          left untouched when an input check fails
 * SIDE EFFECTS: creates _scored and _classifiedtestclass in WORK.
 *************************************************************/
%macro knn_macro / store source;
   /* Character arguments arrive quoted from RUN_MACRO; remove the quotes. */
   %let target   = %sysfunc(dequote(&target));
   %let input    = %sysfunc(dequote(&input));
   %let train    = %sysfunc(dequote(&train));
   %let validate = %sysfunc(dequote(&validate));

   %let error = 0;
   %if %length(&k) = 0 %then %do;
      %put ERROR: Value for K is missing ;
      %let error = 1;
   %end;
   /* Check for an all-digit K before the numeric comparison, so a        */
   /* non-integer K is reported here instead of breaking the integer-only */
   /* macro comparison below.                                             */
   %else %if %sysfunc(notdigit(%sysfunc(strip(&k)))) > 0 %then %do;
      %put ERROR: Value for K is invalid ;
      %let error = 1;
   %end;
   %else %if &k le 0 %then %do;
      %put ERROR: Value for K is invalid ;
      %let error = 1;
   %end;
   %if %length(&target) = 0 %then %do;
      %put ERROR: Value for target is missing ;
      %let error = 1;
   %end;
   %if %length(&input) = 0 %then %do;
      %put ERROR: Value for INPUT is missing ;
      %let error = 1;
   %end;
   %if %sysfunc(exist(&train)) = 0 %then %do;
      %put ERROR: Training dataset does not exist ;
      %let error = 1;
   %end;
   %if %sysfunc(exist(&validate)) = 0 %then %do;
      %put ERROR: validation dataset does not exist ;
      %let error = 1;
   %end;
   %if &error = 1 %then %goto finish;

   ods output classifiedtestclass = _classifiedtestclass;
   proc discrim data = &train test = &validate testout = _scored
                method = npar k = &k testlist;
      class &target;
      var &input;
   run;

   /* Misclassification rate: rows whose predicted class (_INTO_) differs */
   /* from the target, divided by the number of scored rows.              */
   data _null_;
      set _scored nobs = nobs end = eof;
      /* Explicit initial value, so the count does not depend on how a */
      /* bare RETAIN and the sum statement combine.                    */
      retain count 0;
      if &target ne _into_ then count + 1;
      if eof then do;
         misc = count / nobs;
         call symputx('misc', misc);
      end;
   run;

%finish:;
%mend knn_macro;

proc fcmp outlib = sasuser.knn.funcs;
/*************************************************************
 * FUNCTION: knn()
 * GOAL:   apply k-nearest-neighbor classification via knn_macro
 * INPUT:  k        = number of nearest neighbors
 *         train    = training data set
 *         validate = validation data set
 *         target   = target variable
 *         input    = input variables (space-separated list)
 * OUTPUT: overall misclassification rate; missing when RUN_MACRO
 *         returns a nonzero code or knn_macro rejects its inputs
 *************************************************************/
   function knn(k, train $, validate $, target $, input $);
      misc = .;   /* stays missing unless knn_macro sets it */
      rc = run_macro('knn_macro', k, train, validate, target, input, misc);
      if rc eq 0 then return(misc);
      else return(.);
   endsub;
run;

/****** (4) APPLY K-NN FUNCTION TO CLASSIFY IRIS AND CARS DATA ******/
%macro errorchk(train = , validate = , target = , input = , k = );
/*************************************************************
 * MACRO: errorchk()
 * GOAL:  call knn() once and plot real versus classified counts,
 *        with the overall misclassification rate as an inset
 * PARAMETERS:
 *   train    = training data set
 *   validate = validation data set
 *   target   = target variable
 *   input    = input variables
 *   k        = number of nearest neighbors
 *************************************************************/
   options cmplib = (sasuser.knn) mstored sasmstore = sasuser;

   data _null_;
      misc_rate = knn(&k, symget('train'), symget('validate'),
                      symget('target'), symget('input'));
      call symputx('misc_rate', misc_rate);
   run;

   /* varlist1: target levels, used as column names of the DISCRIM table. */
   /* varlist2: quoted lowercase levels, used to keep only the class rows  */
   /* (drops any summary rows in that table).                              */
   proc sql noprint;
      select distinct &target into :varlist1 separated by ' '
         from &validate;
      select distinct cats("'", lowcase(&target), "'")
         into :varlist2 separated by ','
         from &validate;
   quit;

   proc transpose data = _classifiedtestclass out = _out1;
      by from&target notsorted;
      var &varlist1;
   run;

   data _out2;
      set _out1;
      where lowcase(from&target) in (&varlist2);
      label _name_ = 'Level';
   run;

   proc sgplot data = _out2;
      vbar from&target / response = col1 group = _name_;
      xaxis label = 'Real';
      yaxis label = 'Classified ';
      inset "Overall misclassification rate is:   %sysfunc(putn(&misc_rate, percent8.2))"
            / position = topright;
   run;
%mend errorchk;

ods graphics / width = 400px height = 300px;
%errorchk(train = iris_train, validate = iris_validate, target = species,
          input = petallength petalwidth sepallength sepalwidth, k = 5);
%errorchk(train = cars_train, validate = cars_validate, target = origin,
          input = invoice wheelbase length, k = 5);

/****** (5) VISUALIZE CLASSIFICATION RESULT FOR CARS DATA ******/
/* _scored is the TESTOUT= data set of the most recent knn() call above */
/* (CARS data, k = 5).                                                  */
proc rank data = _scored groups = 4 out = _out3;
   var invoice;
   ranks q;
run;

proc sort data = _out3 out = _out3;
   by q invoice;
run;

/* CNTLIN data set for the numeric format QVAR: one row per invoice */
/* quartile holding its START/END range and a display LABEL.        */
data _out4;
   set _out3;
   by q;
   length label $35;
   retain fmtname "qvar" start end;
   if first.q then start = invoice;
   if last.q then end = invoice;
   if last.q;
   q + 1;   /* RANK groups are numbered 0-3; display them as 1-4 */
   label = cat(q, 'Qu.:', '$', start, '-', '$', end);
run;

proc format cntlin = _out4 fmtlib;
run;

/* Stack the classified and the real origin so they can be paneled. */
data _out5(keep = level name invoice wheelbase length q qfmt);
   set _out3;
   qfmt = put(invoice, qvar.);
   level = _into_;
   name = '2-classified';
   output;
   level = origin;
   name = '1-real';
   output;
run;

ods graphics / width = 700px height = 500px;
proc sgpanel data = _out5;
   panelby qfmt name / layout = lattice onepanel novarname;
   scatter x = Wheelbase y = length / group = level;
run;

/****** (6) USE LOGISTIC REGRESSION TO CLASSIFY CARS DATA ******/
/* Generalized logit model on the training set, scored on the validation set. */
proc logistic data = cars_train;
   model origin = invoice wheelbase length / link = glogit;
   score data = cars_validate out = logitscored;
run;

/* F_ORIGIN (observed) versus I_ORIGIN (predicted). */
proc freq data = logitscored;
   table f_origin*i_origin / nocol nocum nopercent;
run;

/****** (7) RUN LOOPS ON K-NN TO FIND THE BEST K VALUE ******/
%macro findk(train = , validate = , target = , input = , maxk = );
/*************************************************************
 * MACRO: findk()
 * GOAL:  call knn() for k = 1 to maxk and plot the misclassification
 *        rate against k, with a reference line at the minimum rate
 *        labelled with every k that reaches it
 * PARAMETERS:
 *   train    = training data set
 *   validate = validation data set
 *   target   = target variable
 *   input    = input variables
 *   maxk     = maximum value of k
 * NOTE:  deletes every data set whose name starts with an underscore
 *        from the default (WORK) library when it finishes.
 *************************************************************/
   options cmplib = (sasuser.knn) mstored sasmstore = sasuser;
   ods select none;

   data _tmp3;
      do k = 1 to &maxk;
         misc_rate = knn(k, symget('train'), symget('validate'),
                         symget('target'), symget('input'));
         output;
      end;
   run;

   /* min_misc: lowest rate; bestk: comma-separated list of k reaching it. */
   proc sql noprint;
      select min(misc_rate) into :min_misc
         from _tmp3;
      select k into :bestk separated by ', '
         from _tmp3
         having misc_rate = min(misc_rate);
   quit;

   ods select all;

   proc sgplot data = _tmp3;
      series x = k y = misc_rate;
      xaxis grid values = (1 to &maxk by 1)
            label = 'k number of nearest neighbours';
      yaxis grid values = (0 to 0.5 by 0.05)
            label = 'Misclassification rate';
      refline &min_misc / transparency = 0.3
              label = "k = &bestk";
      format misc_rate percent8.1;
   run;

   proc datasets nolist;
      delete _:;
   quit;
%mend findk;

ods html style = htmlblue;
%findk(train = iris_train, validate = iris_validate, target = species,
       input = petallength petalwidth sepallength sepalwidth, maxk = 20);
%findk(train = cars_train, validate = cars_validate, target = origin,
       input = invoice wheelbase length, maxk = 40);

/****** END OF ALL CODING ******/

原创粉丝点击