Poj 1741&&CF 161D 点分治入门

来源:互联网 发布:淘宝网银支付最高限额 编辑:程序博客网 时间:2024/06/10 02:01

第一次写点分治,基本上是自己手撸的,由于昨天做了下准备工作(求树的重心),这题就显得很傻逼了。唯一和预想不太一样的是dfs过程记录了father,这样就不用记两个vis了,感觉在代码的简洁性上还是很巧妙的。还有,poj的垃圾多组测试数据要注意初始化,WA了一发。

Poj 1741代码:

#include <cstdio>#include <cstring>#include <algorithm>#include <vector>using namespace std;typedef long long ll;struct Edge{    int to, next, len;}edge[20500];int head[20500];bool vis[20500];int siz[20500];int f[20500];int cnt, rt, sum, n, k;ll ans;ll vs[20500];int vsc;ll v[20500];int vc;void init(){    memset(head, -1, sizeof(head));    memset(vis, false, sizeof(vis));    cnt=0;    ans=0;}void add(int u, int v, int w){    edge[cnt].to=v;    edge[cnt].len=w;    edge[cnt].next=head[u];    head[u]=cnt++;}void getrt(int u, int fa){    siz[u]=1;    f[u]=0;    for(int i=head[u];~i;i=edge[i].next){        if(edge[i].to!=fa&&!vis[edge[i].to]){            getrt(edge[i].to, u);            siz[u]+=siz[edge[i].to];            f[u]=max(f[u], siz[edge[i].to]);        }    }    f[u]=max(f[u], sum-siz[u]);    if(f[u]<f[rt])rt=u;}void getdeep(int u, int fa, ll len){    vs[++vsc]=len;    v[++vc]=len;    for(int i=head[u];~i;i=edge[i].next){        if(edge[i].to!=fa&&!vis[edge[i].to]){            getdeep(edge[i].to, u, len+edge[i].len);        }    }}void solve(int u){    vis[u]=true;    vsc=0;    for(int i=head[u];~i;i=edge[i].next){        if(!vis[edge[i].to]){            vc=0;            getdeep(edge[i].to, u, edge[i].len);            sort(v+1, v+1+vc);            int l=1, r=vc;            while(l<r){                if(v[l]+v[r]<=k){                    ans-=(r-l);                    l++;                }                else r--;            }        }    }    sort(vs+1, vs+1+vsc);    int l=1, r=vsc;    while(l<r){        if(vs[l]+vs[r]<=k){            ans+=(r-l);            l++;        }        else r--;    }    for(int i=1;i<=vsc;i++){        if(vs[i]<=k)ans++;    }    for(int i=head[u];~i;i=edge[i].next){        if(!vis[edge[i].to]){            rt=0, sum=siz[edge[i].to];            getrt(edge[i].to, 0);            solve(rt);        }    }}int main(){    while(~scanf("%d%d", &n, &k)){        if(!n&&!k)break;        init();        for(int i=1;i<n;i++){            int u, v, w;            scanf("%d%d%d", &u, &v, &w);            add(u, v, w);            add(v, u, w);        }        f[0]=30000;        rt=0, sum=n;        getrt(1, 0);        solve(rt);        printf("%lld\n", ans);    }}

又写了一个模板题CF 161D,纯手打,记了下时间,大概35min AC,这速度感觉勉强可以接受吧。

CF 161D代码:

#include <bits/stdc++.h>using namespace std;struct Edge{    int to, next;}e[105000];int head[105000];int cnt;int siz[105000];int f[105000];bool vis[105000];int n, k, sum, rt;long long ans;int v1[105000], v2[105000];int t1, t2;void init(){    memset(head, -1, sizeof(head));    cnt=0;    f[0]=100000<<1;    memset(vis, false, sizeof(vis));}void add(int u, int v){    e[cnt].to=v;    e[cnt].next=head[u];    head[u]=cnt++;    ans=0;}void getdeep(int u, int fa, int len){    v1[++t1]=len;    v2[++t2]=len;    for(int i=head[u];~i;i=e[i].next){        if(e[i].to!=fa&&!vis[e[i].to]){            getdeep(e[i].to, u, len+1);        }    }}void getrt(int u, int fa){    siz[u]=1;    f[u]=0;    for(int i=head[u];~i;i=e[i].next){        if(e[i].to!=fa&&!vis[e[i].to]){            getrt(e[i].to, u);            siz[u]+=siz[e[i].to];            f[u]=max(f[u], siz[e[i].to]);        }    }    f[u]=max(f[u], sum-siz[u]);    if(f[u]<f[rt])rt=u;}void solve(int u){    vis[u]=true;    t1=0;    for(int i=head[u];~i;i=e[i].next){        if(!vis[e[i].to]){            t2=0;            getdeep(e[i].to, u, 1);            sort(v2+1, v2+1+t2);            for(int j=1;j<=t2&&2*v2[j]<=k;j++){                int p=upper_bound(v2+1, v2+1+t2, k-v2[j])-lower_bound(v2+1, v2+1+t2, k-v2[j]);                if(2*v2[j]==k){                    ans-=1LL*p*(p-1)/2;                    break;                }                ans-=p;            }        }    }    sort(v1+1, v1+1+t1);    for(int i=1;i<=t1&&2*v1[i]<=k;i++){        int p=upper_bound(v1+1, v1+1+t1, k-v1[i])-lower_bound(v1+1, v1+1+t1, k-v1[i]);        if(2*v1[i]==k){            ans+=1LL*p*(p-1)/2;            break;        }        ans+=p;    }    for(int i=1;i<=t1;i++)ans+=(v1[i]==k);    for(int i=head[u];~i;i=e[i].next){        if(!vis[e[i].to]){            rt=0;sum=siz[e[i].to];            getrt(e[i].to, 0);            solve(rt);        }    }}int main(){    init();    cin>>n>>k;    for(int i=1;i<n;i++){        int u, v;        cin>>u>>v;        add(u, v);        add(v, u);    }    rt=0;sum=n;    getrt(1, 0);    solve(rt);    cout<<ans<<endl;}
原创粉丝点击