SPOJ COT2 树上的莫队算法,树上区间查询

来源:互联网 发布:如何测试网络丢包 编辑:程序博客网 时间:2024/05/17 01:52

题意:n个节点形成的一棵树。每个节点有一个值。m次查询,求出(u,v)路径上出现了多少个不同的数。

树上的莫队算法,同样将树分成siz=sqrt(n)块,然后离线操作。先对树dfs一遍,每当子树节点个数num>=siz,就将这num个分成一块。读取所有的查询按左端点所在块排序。

重点在于怎么进行区间转移,对路径的lca特殊处理,参考博客http://blog.csdn.net/kuribohg/article/details/41458639  


用倍增法求lca单次要用logn复杂度,要跑3200ms。有个地方可以优化,就是知道了所有的查询,也就是事先知道了转移路径,可以用离线的方法求O(n)求出所有需要用到的lca,这个写起来比较麻烦,不过可以优化到1800ns。代码写的比较挫。。。。

logn求lca:3200+ms

#include <iostream>#include <cstdio>#include <cstring>#include <cmath>#include <cctype>#include <string>#include <vector>#include <map>#include <set>#include <vector>#include <queue>#include <stack>#include <algorithm>using namespace std;const int maxn=4e4+10;const int maxm=1e5+10;int n,m, siz;vector<int> g[maxn];int a[maxn], b[maxn], ans[maxm];int tot[maxn], in[maxn];int fa[maxn][20], dep[maxn];struct Query{    int l, r, id;    int st,ed;    bool operator <(const Query& a) const    {        return st!=a.st? st<a.st: ed<a.ed; //先按左端点所在块先后排序,其次考虑又右端点所在块    }};Query q[maxm];int tag, bel[maxn];int st[maxn], top;int dfs(int u, int par, int d, int &cnt){    dep[u]=d; fa[u][0]=par;    int num=0;    for(int i=0; i<g[u].size(); i++){        int v=g[u][i];        if(v!=par){            num+=dfs(v, u, d+1, cnt);            if(num>=siz){ //子树大小>=sqrt(n),分成一块                for(int i=0; i<num; i++)                    bel[st[--top]]=tag;                tag++;                num=0;            }        }    }    st[top++]=u;//记录子树遍历的点    return num+1;}void init(){    for(int i=0; i<=n; i++) g[i].clear();    memset(tot, 0, sizeof(tot));    memset(in, 0, sizeof(in));    siz=sqrt(n);    for(int i=1;i<=n; i++) scanf("%d",&a[i]), b[i]=a[i];    sort(b+1, b+n+1);    for(int i=1; i<=n; i++)        a[i]=lower_bound(b+1, b+n+1, a[i])-b;    for(int i=0; i<n-1; i++){        int u,v;        scanf("%d%d", &u, &v);        g[u].push_back(v);        g[v].push_back(u);    }    int cnt=0; tag=top=0;    int num=dfs(1, -1, 0, cnt);    for(int i=0; i<num; i++)        bel[st[--top]]=tag; //最后剩下的数也分成一块    for(int i=1; i<20; i++){        for(int u=1; u<=n; u++)            if(fa[u][i-1]==-1)                fa[u][i]=-1;            else fa[u][i]=fa[fa[u][i-1]][i-1];    }    for(int i=0; i<m; i++){        scanf("%d%d", &q[i].l, &q[i].r);        if(bel[q[i].l]>bel[q[i].r])            swap(q[i].l, q[i].r);        q[i].id=i;        q[i].st=bel[q[i].l];        q[i].ed=bel[q[i].r];    }    sort(q, q+m);}int lca(int u, int v){    if(dep[u]>dep[v]) swap(u, v);    for(int i=0; i<20; i++)        if((dep[v]-dep[u])>>i&1)            v=fa[v][i];    if(u==v) return u;    for(int i=19; i>=0; i--){        if(fa[u][i]!=fa[v][i]){            u=fa[u][i];            v=fa[v][i];        }    }    return fa[u][0];}void solve(){    int res=0;    int cu=1, cv=1;    for(int i=0; i<m; i++){        int nu=q[i].l, nv=q[i].r;        int par=lca(cu, nu);        while(cu!=par){            if(in[cu]){                if(--tot[a[cu]]==0)                    res--;            }            else if(++tot[a[cu]]==1)                res++;            in[cu]^=1;            cu=fa[cu][0];        }        cu=nu;        while(cu!=par){            if(in[cu]){                if(--tot[a[cu]]==0)                    res--;            }            else if(++tot[a[cu]]==1)                res++;            in[cu]^=1;            cu=fa[cu][0];        }        cu=nu;        par=lca(cv, nv);        while(cv!=par){            if(in[cv]){                if(--tot[a[cv]]==0)                    res--;            }            else if(++tot[a[cv]]==1)                res++;            in[cv]^=1;            cv=fa[cv][0];        }        cv=nv;        while(cv!=par){            if(in[cv]){                if(--tot[a[cv]]==0)                    res--;            }            else if(++tot[a[cv]]==1)                res++;            in[cv]^=1;            cv=fa[cv][0];        }        cv=nv;        par=lca(cu, cv);        ans[q[i].id]=res+(!tot[a[par]]);    }}int main(){    while(scanf("%d%d", &n, &m)==2){        init();        solve();        for(int i=0; i<m; i++)            printf("%d\n", ans[i]);    }    return 0;}


离线查询lca:1800+ms

#include <iostream>#include <cstdio>#include <cstring>#include <cmath>#include <cctype>#include <string>#include <vector>#include <map>#include <set>#include <vector>#include <queue>#include <stack>#include <algorithm>using namespace std;#pragma comment(linker, "/STACK:1024000000,1024000000")typedef pair<int,int> P;#define fir first#define sec secondconst int maxn=4e4+10;const int maxm=1e5+10;int n,m, siz;vector<int> g[maxn];int first[maxn],ltot=0, nxt[6*maxm];P lq[6*maxm];//所有需要查询的lca,lq[i].first保存v,second保存查询的idint a[maxn], b[maxn], ans[maxm];int tot[maxn], in[maxn], fa1[maxn];int fa[maxn], lca[3*maxm], col[maxn];int bel[maxn],st[maxn],top=0;struct Query{    int l, r, id;    int st,ed;    bool operator <(const Query& a) const    {        return st!=a.st? st<a.st: ed<a.ed;    }};Query q[maxm];int tag;int dfs(int u, int par, int &cnt)//分块{    fa1[u]=par;    int num=0;    for(int i=0; i<g[u].size(); i++){        int v=g[u][i];        if(v!=par)            num+=dfs(v, u, cnt);        if(num>=siz){            for(int i=0; i<num; i++)                bel[st[--top]]=tag;            tag++;            num=0;        }    }    st[top++]=u;    return num+1;}int find(int u){    return fa[u]==u?u:(fa[u]=find(fa[u]));}int unite(int x, int y){    x=fa[x];    y=fa[y];    fa[y]=x;}void dfs2(int u, int par)//离线查询所有lca{    col[u]=1;    for(int i=first[u]; i!=-1; i=nxt[i]){        int v=lq[i].fir, id=lq[i].sec;        if(!col[v]) continue;        else if(col[v]==1){            lca[id]=v;        }        else{            lca[id]=find(v);        }    }    for(int i=0; i<g[u].size(); i++){        int v=g[u][i];        if(v!=par)            dfs2(v, u);    }    col[u]=2;    unite(par, u);}void add(int u, int v, int id)//查询m<=1e5,数比较多所以用前向星实现优化{    lq[ltot]=P(v,id);    nxt[ltot]=first[u];    first[u]=ltot++;}void init(){    for(int i=0; i<=n; i++) g[i].clear();    memset(tot, 0, sizeof(tot));    memset(in, 0, sizeof(in));    siz=sqrt(n);    for(int i=1;i<=n; i++) scanf("%d", a+i), b[i]=a[i];    sort(b+1, b+n+1);    for(int i=1; i<=n; i++)        a[i]=lower_bound(b+1, b+n+1, a[i])-b;    for(int i=0; i<n-1; i++){        int u,v;        scanf("%d%d", &u, &v);        g[u].push_back(v);        g[v].push_back(u);    }    int cnt=0; top=0; tag=0;    int num=dfs(1, -1, cnt);    for(int i=0; i<num; i++)        bel[st[--top]]=tag;    for(int i=0; i<m; i++){        scanf("%d%d", &q[i].l, &q[i].r);        if(bel[q[i].l]>bel[q[i].r])            swap(q[i].l, q[i].r);        q[i].id=i;        q[i].st=bel[q[i].l];        q[i].ed=bel[q[i].r];    }    sort(q, q+m);    cnt=0; ltot=0;    memset(first, -1, sizeof(first));    add(1, q[0].l, cnt);    add(q[0].l, 1, cnt++);    add(1, q[0].r, cnt);    add(q[0].r, 1, cnt++);    add(q[0].r, q[0].l, cnt);    add(q[0].l, q[0].r, cnt++);    //add(q[0].r, q[0].l, cnt++);    for(int i=0; i<m-1; i++){    add(q[i].l, q[i+1].l, cnt);//第i个查询左端点向第i+1个左端点转移,所以需要它们之间的lca    add(q[i+1].l, q[i].l, cnt++);    add(q[i].r, q[i+1].r, cnt);//第i个查询右端点向第i+1个右端点转移    add(q[i+1].r, q[i].r, cnt++);    add(q[i+1].r, q[i+1].l, cnt);//左端点和右端点的lca    add(q[i+1].l, q[i+1].r,cnt++);    }    for(int i=0; i<=n; i++) fa[i]=i;    memset(col, 0, sizeof(col));    dfs2(1, 0);}void solve(){    int res=0;    int cu=1, cv=1;    for(int i=0; i<m; i++){        int nu=q[i].l, nv=q[i].r;        //cout<<lca[i*3]<<' '<<lca[i*3+1]<<' '<<lca[i*3+2]<<endl;        int par=lca[i*3];        while(cu!=par){            if(in[cu]){                if(--tot[a[cu]]==0)                    res--;            }            else if(++tot[a[cu]]==1)                res++;            in[cu]^=1;            cu=fa1[cu];        }        cu=nu;        while(cu!=par){            if(in[cu]){                if(--tot[a[cu]]==0)                    res--;            }            else if(++tot[a[cu]]==1)                res++;            in[cu]^=1;            cu=fa1[cu];        }        cu=nu;        par=lca[i*3+1];        while(cv!=par){            if(in[cv]){                if(--tot[a[cv]]==0)                    res--;            }            else if(++tot[a[cv]]==1)                res++;            in[cv]^=1;            cv=fa1[cv];        }        cv=nv;        while(cv!=par){            if(in[cv]){                if(--tot[a[cv]]==0)                    res--;            }            else if(++tot[a[cv]]==1)                res++;            in[cv]^=1;            cv=fa1[cv];        }        cv=nv;        par=lca[i*3+2];        ans[q[i].id]=res+(!tot[a[par]]);    }}int main(){    while(cin>>n>>m){        init();        solve();        for(int i=0; i<m; i++)            printf("%d\n", ans[i]);    }    return 0;}


0 0
原创粉丝点击