hiho 1167Advanced Theoretical Computer Science LCA

来源:互联网 发布:高阶矩阵求逆 编辑:程序博客网 时间:2024/05/16 08:28

题意:给n与p,给一棵树n-1条边,p个小松鼠,每个小松鼠有一个活动的路径为a~b的路径,如果两个小松鼠的路径相交了那么说明这两个小松鼠是朋友,问有多少对朋友

思路:如果两个路径相交,那么必然的其中一个路径的lca一定在相交的部分上,这个画一下树就可以明白了。

那么我们就可以统计这些路径上的每个点被多少条路径经过,每个点充当了多少次lca,分别用d[]与sum[]记录

如果该点充当lca那么就不被d记录进去,然后多少对朋友就用ans+=d[i]*sum[i]+sum[i]*(sum[i]-1)/2

前者是这个点不是lca的情况与是lca的相乘,后者是该点充当不同路径的lca时的朋友数

d可以用树上的差分数列的思想来统计(可看2014上海网络赛1003tree那个文章),sum直接累加即可。

#include <bits/stdc++.h>using namespace std;const int maxn=100005;struct node{int v,next;};node edge[maxn<<1];int head[maxn],cnt;struct node2{int v,id,next;};node2 qu[maxn<<1];int ask[maxn],ecnt;int a[maxn],b[maxn];int n;int p;int fa[maxn],lca[maxn],d[maxn],sum[maxn],vis[maxn];void add(int a,int b){    edge[cnt].v=b;    edge[cnt].next=head[a];    head[a]=cnt++;}void add2(int a,int b,int id){    qu[ecnt].v=b;    qu[ecnt].next=ask[a];    qu[ecnt].id=id;    ask[a]=ecnt++;}int Find(int x){    if(x==fa[x]) return x;    return fa[x]=Find(fa[x]);}void init(){        memset(head,-1,sizeof(head));        memset(ask,-1,sizeof(ask));        cnt=ecnt=0;}void Lca(int u){    vis[u]=1;    fa[u]=u;    for(int i=ask[u];i!=-1;i=qu[i].next){        int v=qu[i].v,id=qu[i].id;        if(vis[v]) lca[id]=Find(v);    }    for(int i=head[u];i!=-1;i=edge[i].next){        int v=edge[i].v;        if(vis[v]) continue;        Lca(v);        fa[v]=u;    }}void dfs(int u,int fa){        for(int i=head[u];i!=-1;i=edge[i].next){            int v=edge[i].v;            if(v==fa) continue;            dfs(v,u);            d[u]+=d[v];        }}int main(){    int x,y;    init();    scanf("%d%d",&n,&p);    for(int i=1;i<n;i++){        scanf("%d%d",&x,&y);        add(x,y);add(y,x);    }    for(int i=1;i<=p;i++){        scanf("%d%d",&x,&y);        d[x]++,d[y]++;        add2(x,y,i);add2(y,x,i);    }    Lca(1);//    for(int i=1;i<=p;i++){//        printf("lca: %d\n",lca[i]);//    }    for(int i=1;i<=p;i++){        d[lca[i]]-=2;        sum[lca[i]]++;    }    dfs(1,0);    long long ans=0;    for(int i=1;i<=n;i++){        ans+=(long long)d[i]*sum[i]+(long long)sum[i]*(sum[i]-1)/2;    }    printf("%lld\n",ans);return 0;}




0 0
原创粉丝点击