HDOJ 5021 Revenge of kNN II

来源:互联网 发布:淘宝卖家中心联系电话 编辑:程序博客网 时间:2024/06/11 15:31

5021 Revenge of kNN II

Time Limit: 8000/5000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others)
Total Submission(s): 196    Accepted Submission(s): 56


Problem Description
In pattern recognition, the k-Nearest Neighbors algorithm (or k-NN for short) is a non-parametric method used for classification and regression. In both cases, the input consists of the k closest training examples in the feature space.
In k-NN regression, the output is the property value for the object. This value is the average of the values of its k nearest neighbors.
---Wikipedia

Today, kNN takes revenge on you, again. You have to handle a kNN case in a one-dimensional coordinate system. There are N points, each with a position Xi and a value Vi. Then there are M kNN queries: for the point with index i, recalculate its value by averaging the values of its k nearest neighbors. Note that you have to replace the value of the i-th point with the newly calculated value. If there is a tie while choosing the k nearest neighbors, choose the one with the minimal index first.
(Have you ever tried the problem “Revenge of kNN”? They are twin problems!)
 

Input
The first line contains a single integer T, indicating the number of test cases. 

Each test case begins with two integers N and M. Then N lines follow, each containing two integers Xi and Vi. Then M lines follow, each containing the queried index Qi and Ki, where Ki indicates the number of nearest neighbors to use.

[Technical Specification]
1. 1 <= T <= 5
2. 2 <= N <= 100 000
3. 1 <= M <= 100 000
4. 1 <= Vi <= 1 000
5. 1 <= Xi <= 1 000 000 000, and no two Xi are identical.
6. 1 <= Qi <= N
7. 1 <= Ki <= N - 1
 

Output
For each test case, output the sum of the results of all queries, rounded to three fractional digits.
 

Sample Input
1
5 3
1 2
2 3
3 6
4 8
5 8
2 2
3 2
4 2
 

Sample Output
17.000
Hint
For the first query, the 2-NN for point 2 is point 1 and 3, so the new value is (2 + 6) / 2 = 4. For the second query, the 2-NN for point 3 is point 2 and 4, and the value of point 2 is changed to 4 by the last query, so the new value is (4 + 8) / 2 = 6. Huge input, faster I/O method is recommended.
 

官方思路:
考虑如何快速求出距离最近的k个点的权值之和,这里的距离具有明显的二分性。这样可以在log(MAXX)的时间内求出k个点的坐标范围。求出之后的问题是,区间求和,单点更新,树状数组足够解决这个问题了。
在二分的时候注意K和K+1可能都是符合条件的,如果算出K+1被舍弃的话,减小Distance可能得到的是K-1,并不连续,所以要判断一下这种情况。


代码如下:(二分搜索有些难写,汗!!)
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;

const int MAXN = 100005;

// One input point.
typedef struct node {
    int id;  // 1-based position of the point in the input
    int x;   // coordinate
    int v;   // initial value
} node;

node a[MAXN];    // points, 1-based, sorted by x after reading
int n, m;        // number of points / number of queries
// pos[id] = index of input point `id` inside the sorted array `a`.
// (Deliberately not called `index`: glibc's <cstring> declares the legacy
// function index(), and a global array of that name fails to compile.)
int pos[MAXN];
double C[MAXN];  // Fenwick tree over the current values, in sorted order
int L, R;        // result of findLR(): inclusive window [L, R] in sorted order
int q, k;        // current query: sorted index of the point, neighbour count

// Orders points by ascending coordinate.
bool cmp(const node &lhs, const node &rhs) {
    return lhs.x < rhs.x;
}

// Lowest set bit of x.
int lowbit(int x) {
    return x & (-x);
}

// Adds v to the value stored at sorted index i (1-based).
void add(int i, double v) {
    while (i <= n) {
        C[i] += v;
        i += lowbit(i);
    }
}

// Returns the sum of the values at sorted indices 1..i.
double Sum(int i) {
    double res = 0;
    while (i > 0) {
        res += C[i];
        i -= lowbit(i);
    }
    return res;
}

// Smallest sorted index whose coordinate is >= x.
// Callers pass x <= a[q].x, so such an index exists; n is returned otherwise.
int findL(int x) {
    int lo = 1, hi = n, best = n;
    while (lo <= hi) {
        int mid = (lo + hi) >> 1;
        if (a[mid].x >= x) {
            best = mid;
            hi = mid - 1;
        } else {
            lo = mid + 1;
        }
    }
    return best;
}

// Largest sorted index whose coordinate is <= x.
// Callers pass x >= a[q].x, so such an index exists; 1 is returned otherwise.
int findR(int x) {
    int lo = 1, hi = n, best = 1;
    while (lo <= hi) {
        int mid = (lo + hi) >> 1;
        if (a[mid].x <= x) {
            best = mid;
            lo = mid + 1;
        } else {
            hi = mid - 1;
        }
    }
    return best;
}

// Sets the globals L and R so that [L, R] (sorted indices) contains point q
// and exactly its k nearest neighbours.
//
// The search is a binary search on the radius d. The number of other points
// within distance d of q is non-decreasing in d and, because coordinates are
// distinct integers, grows by at most 2 when d grows by 1. Hence some radius
// yields exactly k or k+1 neighbours. In the k+1 case the farther endpoint is
// dropped; on a distance tie the endpoint with the larger input index is
// dropped, which keeps the one with the minimal index.
void findLR() {
    // Radii are distances, not coordinates: the smallest useful radius is 1
    // and the largest possible one is the full span of the points.
    int lo = 1;
    int hi = a[n].x - a[1].x;
    while (lo <= hi) {
        int mid = lo + ((hi - lo) >> 1);
        L = findL(a[q].x - mid);
        R = findR(a[q].x + mid);
        int others = R - L;  // points in the window besides q itself
        if (others < k) {
            lo = mid + 1;
        } else if (others > k + 1) {
            hi = mid - 1;
        } else {
            if (others == k + 1) {
                int distL = a[q].x - a[L].x;
                int distR = a[R].x - a[q].x;
                if (distL == distR) {
                    if (a[L].id < a[R].id)
                        R--;
                    else
                        L++;
                } else if (distL < distR) {
                    R--;
                } else {
                    L++;
                }
            }
            return;
        }
    }
}

// Reads T test cases; for each, processes the M queries in order (each query
// replaces the queried point's value by the average of its k nearest
// neighbours) and prints the sum of all computed averages with 3 decimals.
int main() {
    int T;
    double ans;
    if (scanf("%d", &T) != 1) return 0;
    while (T--) {
        scanf("%d %d", &n, &m);
        for (int i = 1; i <= n; i++) {
            scanf("%d %d", &a[i].x, &a[i].v);
            a[i].id = i;
        }
        sort(a + 1, a + n + 1, cmp);
        memset(C, 0, sizeof(C));
        for (int i = 1; i <= n; i++) {
            pos[a[i].id] = i;
            add(i, a[i].v);
        }
        ans = 0.0;
        while (m--) {
            scanf("%d %d", &q, &k);
            q = pos[q];  // translate input index to sorted index
            findLR();
            double s = Sum(R) - Sum(L - 1);  // window sum, includes q itself
            double t = Sum(q) - Sum(q - 1);  // current value of q
            double avg = (s - t) / k;
            add(q, avg - t);                 // point update: value becomes avg
            ans += avg;
        }
        printf("%.3f\n", ans);
    }
    return 0;
}

另一种二分搜索代码:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;

const int MAXN = 100005;

// One input point.
typedef struct node {
    int id;  // 1-based position of the point in the input
    int x;   // coordinate
    int v;   // initial value
} node;

node a[MAXN];    // points, 1-based, sorted by x after reading
int n, m;        // number of points / number of queries
// pos[id] = index of input point `id` inside the sorted array `a`.
// (Deliberately not called `index`: glibc's <cstring> declares the legacy
// function index(), and a global array of that name fails to compile.)
int pos[MAXN];
double C[MAXN];  // Fenwick tree over the current values, in sorted order
int L, R;        // result of findLR(): inclusive window [L, R] in sorted order
int q, k;        // current query: sorted index of the point, neighbour count

// Orders points by ascending coordinate.
bool cmp(const node &lhs, const node &rhs) {
    return lhs.x < rhs.x;
}

// Lowest set bit of x.
int lowbit(int x) {
    return x & (-x);
}

// Adds v to the value stored at sorted index i (1-based).
void add(int i, double v) {
    while (i <= n) {
        C[i] += v;
        i += lowbit(i);
    }
}

// Returns the sum of the values at sorted indices 1..i.
double Sum(int i) {
    double res = 0;
    while (i > 0) {
        res += C[i];
        i -= lowbit(i);
    }
    return res;
}

// Smallest sorted index whose coordinate is >= x, found by shrinking the
// range [lo, hi] until both ends meet. Callers pass x <= a[q].x, so a
// matching index exists; n is returned if none does.
int findL(int x) {
    int lo = 1, hi = n;
    while (lo < hi) {
        int mid = (lo + hi) >> 1;
        if (a[mid].x >= x)
            hi = mid;
        else
            lo = mid + 1;
    }
    return lo;
}

// Largest sorted index whose coordinate is <= x. The midpoint is rounded up
// so that `lo = mid` always makes progress. Callers pass x >= a[q].x, so a
// matching index exists; 1 is returned if none does.
int findR(int x) {
    int lo = 1, hi = n;
    while (lo < hi) {
        int mid = (lo + hi + 1) >> 1;
        if (a[mid].x <= x)
            lo = mid;
        else
            hi = mid - 1;
    }
    return hi;
}

// Sets the globals L and R so that [L, R] (sorted indices) contains point q
// and exactly its k nearest neighbours.
//
// The search is a binary search on the radius d. The number of other points
// within distance d of q is non-decreasing in d and, because coordinates are
// distinct integers, grows by at most 2 when d grows by 1. Hence some radius
// yields exactly k or k+1 neighbours. In the k+1 case the farther endpoint is
// dropped; on a distance tie the endpoint with the larger input index is
// dropped, which keeps the one with the minimal index.
void findLR() {
    // Radii are distances, not coordinates: the smallest useful radius is 1
    // and the largest possible one is the full span of the points.
    int lo = 1;
    int hi = a[n].x - a[1].x;
    while (lo <= hi) {
        int mid = lo + ((hi - lo) >> 1);
        L = findL(a[q].x - mid);
        R = findR(a[q].x + mid);
        int others = R - L;  // points in the window besides q itself
        if (others < k) {
            lo = mid + 1;
        } else if (others > k + 1) {
            hi = mid - 1;
        } else {
            if (others == k + 1) {
                int distL = a[q].x - a[L].x;
                int distR = a[R].x - a[q].x;
                if (distL == distR) {
                    if (a[L].id < a[R].id)
                        R--;
                    else
                        L++;
                } else if (distL < distR) {
                    R--;
                } else {
                    L++;
                }
            }
            return;
        }
    }
}

// Reads T test cases; for each, processes the M queries in order (each query
// replaces the queried point's value by the average of its k nearest
// neighbours) and prints the sum of all computed averages with 3 decimals.
int main() {
    int T;
    double ans;
    if (scanf("%d", &T) != 1) return 0;
    while (T--) {
        scanf("%d %d", &n, &m);
        for (int i = 1; i <= n; i++) {
            scanf("%d %d", &a[i].x, &a[i].v);
            a[i].id = i;
        }
        sort(a + 1, a + n + 1, cmp);
        memset(C, 0, sizeof(C));
        for (int i = 1; i <= n; i++) {
            pos[a[i].id] = i;
            add(i, a[i].v);
        }
        ans = 0.0;
        while (m--) {
            scanf("%d %d", &q, &k);
            q = pos[q];  // translate input index to sorted index
            findLR();
            double s = Sum(R) - Sum(L - 1);  // window sum, includes q itself
            double t = Sum(q) - Sum(q - 1);  // current value of q
            double avg = (s - t) / k;
            add(q, avg - t);                 // point update: value becomes avg
            ans += avg;
        }
        printf("%.3f\n", ans);
    }
    return 0;
}


0 0
原创粉丝点击