人工智能(AI)之KNN的基本实现

来源:互联网 发布:全站仪数据采集的步骤 编辑:程序博客网 时间:2024/04/27 15:15

数据集下载地址点我下载
本文主要介绍KNN的实现思想:

  1. KNN的主要思想是:对每一条测试样本,计算它与训练集中每一条样本之间的距离(欧氏距离、余弦距离、曼哈顿距离等),然后取出最相似的前K个训练样本,用它们的标签对该测试样本进行预测
  2. 测试后发现,就本次数据集而言,把余弦距离与欧氏距离加权后用来选取近邻,预测结果较好;但这一结论仅适用于本次的训练数据,换一个数据集未必成立
  3. KNN中还有很多细节可以优化,比如对数据集做归一化;归一化的方法有很多,具体选哪一种要看当前的数据集,适合的才是最好的
  4. 总之对于预测,找好模型才是最重要的,框架确定之后,再来讨论具体的优化会更有效果
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <map>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <string.h>
#include <utility>
#include <vector>

using namespace std;

/*
 * K-nearest-neighbour emotion regression over bag-of-words documents.
 *
 * Pipeline (see main): build the vocabulary from Dataset_words.txt, remove
 * stop words, load the gold emotion values of the training documents, turn
 * every document into a binary word-presence vector, then predict each test
 * document's six emotion values as the mean of its nearest training
 * documents' values and write the predictions to one file per emotion.
 */

// Emotion indices: first subscript of proba / newproba.
enum
{
    ANGER = 0,
    DISGUST = 1,
    FEAR = 2,
    JOY = 3,
    SAD = 4,
    SURPRISE = 5
};

static const int EMOTION_COUNT = 6;   // ANGER .. SURPRISE
static const int TRAIN_DOCS = 246;    // data rows 0 .. 245 carry gold values
static const int TEST_DOCS = 1000;    // data rows 246 .. 1245 are predicted
static const int MAX_DOCS = 2000;     // row capacity of the document matrices
static const int MAX_WORDS = 4000;    // column capacity of the document matrices

// Directories used by the program; every file path is one of these plus a file name.
static const char* const DATA_DIR = "G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/";
static const char* const PREDICT_DIR = "C:/Users/windowos 7/Desktop/AILab/predict_test/";

// File-scope state. Several of these are not used by the functions below;
// they are kept so that the set of global names is unchanged.
char c[300];
priority_queue<double, vector<double>, greater<double> > q;
map<double, int> map1;                     // ascending by key
map<double, int, greater<double> > map2;   // descending by key; the space in "> >" is required before C++11
const string Str1 = "train", Str2 = "test";
set<string> sets;                          // vocabulary; a word's rank in this sorted set is its column index
bool vector_old[MAX_DOCS][MAX_WORDS];      // vector_old[row][col]: document `row` contains vocabulary word `col`
double vector2[MAX_DOCS][MAX_WORDS];       // rows of vector_old scaled so that each non-empty row sums to 1
double proba[9][2000];                     // proba[emotion][train row]: gold value read from gold_train/
double newproba[9][2000];                  // newproba[emotion][test index]: predicted value
double dis_save[2000];                     // squared distance from the most recently scored test row to each training row
double K;                                  // number of neighbours, read from stdin in main
int num1 = 0;

// Reads up to TRAIN_DOCS lines of DATA_DIR/gold_train/<file_name> into
// proba[emotion]. Each line is "<label> <value>"; the label is discarded.
// A line whose value cannot be parsed stores 0 so later rows keep their index.
static void read_gold_labels(const char* file_name, int emotion)
{
    const string path = string(DATA_DIR) + "gold_train/" + file_name;
    ifstream in(path.c_str());
    if (!in)
    {
        cerr << "cannot open " << path << endl;
        return;
    }

    string line;
    for (int row = 0; row < TRAIN_DOCS && getline(in, line); ++row)
    {
        istringstream fields(line);
        string label;   // first token on the line is not used
        double value = 0.0;
        if (!(fields >> label >> value))
        {
            value = 0.0;
        }
        proba[emotion][row] = value;
    }
}

// Loads the gold ANGER values of the training documents.
void readanger()
{
    read_gold_labels("anger_train.txt", ANGER);
}

// Loads the gold DISGUST values of the training documents.
void readdisgust()
{
    read_gold_labels("disgust_train.txt", DISGUST);
}

// Loads the gold FEAR values of the training documents.
void readfear()
{
    read_gold_labels("fear_train.txt", FEAR);
}

// Loads the gold JOY values of the training documents.
void readjoy()
{
    read_gold_labels("joy_train.txt", JOY);
}

// Loads the gold SAD values of the training documents.
void readsad()
{
    read_gold_labels("sad_train.txt", SAD);
}

// Loads the gold SURPRISE values of the training documents.
void readsurprise()
{
    read_gold_labels("surprise_train.txt", SURPRISE);
}

// Loads the gold values of all six emotions into proba.
void get_proba()
{
    readanger();
    readdisgust();
    readfear();
    readsad();
    readjoy();
    readsurprise();
}

// Builds the vocabulary `sets` from Dataset_words.txt and writes it, one word
// per line in sorted order, to anger.txt. The first line of the data file is
// a header and the first token of every other line is the document id; both
// are skipped.
void get_word()
{
    const string in_path = string(DATA_DIR) + "Dataset_words.txt";
    const string out_path = string(DATA_DIR) + "anger.txt";
    ifstream in(in_path.c_str());
    ofstream out(out_path.c_str());

    if (in && out)
    {
        string line;
        getline(in, line);   // header row
        while (getline(in, line))
        {
            istringstream tokens(line);
            string word;
            tokens >> word;   // document id, not a vocabulary word
            // Extraction as the loop condition stops exactly at the last
            // token, so trailing whitespace never re-inserts or invents one.
            while (tokens >> word)
            {
                sets.insert(word);
            }
        }
    }
    else
    {
        cerr << "open in or out file error" << endl;
    }

    for (set<string>::const_iterator it = sets.begin(); it != sets.end(); ++it)
    {
        out << *it << '\n';
    }
}

// Removes every word listed in "Foxstoplist (1).txt" from the vocabulary and
// echoes the stop words, one per line, to Foxstoplistout.txt.
void clear_stopwords()
{
    const string in_path = string(DATA_DIR) + "Foxstoplist (1).txt";
    const string out_path = string(DATA_DIR) + "Foxstoplistout.txt";
    ifstream in(in_path.c_str());
    ofstream out(out_path.c_str());
    if (!in)
    {
        cerr << "cannot open " << in_path << endl;
        return;
    }

    string word;
    while (in >> word)
    {
        out << word << '\n';
        sets.erase(word);   // no-op when the word is not in the vocabulary
    }
}

// Fills vector_old from Dataset_words.txt: row r is the r-th data line (the
// header is not counted) and column c is set when the line contains the c-th
// word of the sorted vocabulary. Also writes the column header line to
// vector.txt. Rows beyond MAX_DOCS and words beyond MAX_WORDS are ignored
// with a message on stderr instead of being written past the arrays.
void vector_out()
{
    const string in_path = string(DATA_DIR) + "Dataset_words.txt";
    const string out_path = string(DATA_DIR) + "vector.txt";
    ifstream in(in_path.c_str());
    ofstream out(out_path.c_str());

    // Word -> column lookup built once, replacing a scan of the set per token.
    map<string, int> column_of;
    int next_column = 0;
    for (set<string>::const_iterator it = sets.begin();
         it != sets.end() && next_column < MAX_WORDS; ++it)
    {
        column_of[*it] = next_column++;
    }
    if (sets.size() > static_cast<size_t>(MAX_WORDS))
    {
        cerr << "vocabulary has " << sets.size() << " words; only the first "
             << MAX_WORDS << " are used" << endl;
    }

    if (!in)
    {
        cerr << "cannot open " << in_path << endl;
    }
    else
    {
        string line;
        getline(in, line);   // header row
        int row = 0;
        while (getline(in, line))
        {
            if (row >= MAX_DOCS)
            {
                cerr << "more than " << MAX_DOCS << " documents; the rest are ignored" << endl;
                break;
            }
            istringstream tokens(line);
            string word;
            tokens >> word;   // document id
            while (tokens >> word)
            {
                const map<string, int>::const_iterator hit = column_of.find(word);
                if (hit != column_of.end())
                {
                    vector_old[row][hit->second] = true;
                }
            }
            ++row;
        }
    }

    if (!out)
    {
        cerr << "cannot open " << out_path << endl;
    }
    out << "文本编号 ";
    for (set<string>::const_iterator it = sets.begin(); it != sets.end(); ++it)
    {
        out << *it << " ";
    }
}

// Orders candidates by descending score; equal scores fall back to the lower
// training row so the ranking is deterministic.
static bool better_match(const pair<double, int>& a, const pair<double, int>& b)
{
    if (a.first != b.first)
    {
        return a.first > b.first;
    }
    return a.second < b.second;
}

// Converts the user-supplied K into a usable neighbour count in [1, TRAIN_DOCS].
static int neighbour_count(double k)
{
    if (!(k >= 1.0))   // also true for NaN
    {
        return 1;
    }
    if (k >= TRAIN_DOCS)
    {
        return TRAIN_DOCS;
    }
    return static_cast<int>(k);
}

// Predicts the six emotion values of every test document (rows TRAIN_DOCS ..
// TRAIN_DOCS + TEST_DOCS - 1 of vector_old) and stores them in newproba.
//
// K: requested number of neighbours. It is truncated to an integer and
//    clamped to [1, TRAIN_DOCS]; the prediction is the mean over the
//    neighbours actually used.
//
// For each test document every training document is scored as
// 0.8 * cosine similarity + 0.2 * squared Euclidean distance, and the
// highest-scoring training documents are the neighbours.
// NOTE(review): the distance term rewards documents that are farther away;
// the accompanying write-up says this weighting was chosen empirically on
// this data set, so the formula is kept unchanged.
//
// Side effects: fills vector2, leaves the last test document's squared
// distances in dis_save, prints "i:<index>" once per test document and
// "happy" when done.
void compute_dis(double K)
{
    const int total_docs = TRAIN_DOCS + TEST_DOCS;
    int word_count = MAX_WORDS;
    if (sets.size() < static_cast<size_t>(MAX_WORDS))
    {
        word_count = static_cast<int>(sets.size());
    }

    // present[row] = number of vocabulary words in the document. Because the
    // vectors are 0/1 this is also the squared length of the row.
    vector<int> present(total_docs, 0);
    for (int row = 0; row < total_docs; ++row)
    {
        for (int col = 0; col < word_count; ++col)
        {
            if (vector_old[row][col])
            {
                ++present[row];
            }
        }
        // A set entry implies present[row] >= 1, so an empty document yields
        // zeros here rather than 0/0.
        for (int col = 0; col < word_count; ++col)
        {
            vector2[row][col] = vector_old[row][col] ? 1.0 / present[row] : 0.0;
        }
    }

    const int neighbours = neighbour_count(K);
    vector<pair<double, int> > ranked(TRAIN_DOCS);

    for (int i = 0; i < TEST_DOCS; ++i)
    {
        const int test_row = TRAIN_DOCS + i;
        const double test_len_sq = present[test_row];

        for (int j = 0; j < TRAIN_DOCS; ++j)
        {
            int shared = 0;   // words present in both documents (the dot product)
            for (int col = 0; col < word_count; ++col)
            {
                if (vector_old[test_row][col] && vector_old[j][col])
                {
                    ++shared;
                }
            }

            const double train_len_sq = present[j];
            dis_save[j] = test_len_sq + train_len_sq - 2 * shared;

            // Cosine is undefined for an empty document; treat it as 0 so the
            // score stays finite and the ordering below is well defined.
            double cosine = 0.0;
            if (test_len_sq > 0 && train_len_sq > 0)
            {
                cosine = shared / (sqrt(test_len_sq) * sqrt(train_len_sq));
            }
            ranked[j] = make_pair(0.8 * cosine + 0.2 * dis_save[j], j);
        }

        // Every training document stays a candidate, including ones whose
        // scores are equal; only the best `neighbours` need to be ordered.
        partial_sort(ranked.begin(), ranked.begin() + neighbours, ranked.end(), better_match);

        cout << "i:" << i << endl;

        // The ranking does not depend on the emotion, so it is shared by all six.
        for (int mood = 0; mood < EMOTION_COUNT; ++mood)
        {
            double total = 0;
            for (int n = 0; n < neighbours; ++n)
            {
                total += proba[mood][ranked[n].second];
            }
            newproba[mood][i] = total / neighbours;
        }
    }
    cout << "happy" << endl;
}

// Writes the TEST_DOCS predictions of each emotion, one value per line, to
// PREDICT_DIR/<emotion>_predict.txt.
void print()
{
    // Indexed by the emotion constants ANGER .. SURPRISE.
    static const char* const FILE_NAMES[EMOTION_COUNT] = {
        "anger_predict.txt",
        "disgust_predict.txt",
        "fear_predict.txt",
        "joy_predict.txt",
        "sad_predict.txt",
        "surprise_predict.txt"
    };

    for (int mood = 0; mood < EMOTION_COUNT; ++mood)
    {
        const string path = string(PREDICT_DIR) + FILE_NAMES[mood];
        ofstream f(path.c_str());
        if (!f)
        {
            cerr << "cannot open " << path << endl;
            continue;
        }
        for (int j = 0; j < TEST_DOCS; ++j)
        {
            f << newproba[mood][j] << '\n';
        }
    }
}

// Reads K from stdin, runs the pipeline and prints a step number after each
// stage, followed by the final vocabulary size. Returns 1 when K is not a
// number >= 1.
int main()
{
    cout << "请输入k:" << endl;
    if (!(cin >> K) || K < 1)
    {
        cerr << "k must be a number >= 1" << endl;
        return 1;
    }

    get_word();
    cout << 0 << endl;
    clear_stopwords();
    cout << 1 << endl;
    get_proba();
    cout << 2 << endl;
    vector_out();
    cout << 3 << endl;
    compute_dis(K);
    cout << 4 << endl;
    print();
    cout << 5 << endl;
    cout << sets.size() << endl;
    return 0;
}
1 0
原创粉丝点击