SPOJ COT Count on a tree

来源:互联网 发布:java捕获被0除的异常 编辑:程序博客网 时间:2024/05/19 18:10

SPOJ COT Count on a tree

主席树,倍增LCA

题意

求树上A,B两点路径上第K小的数

思路

同样是可持久化线段树,只是这一次我们用它来维护树上的信息。

我们之前已经知道,可持久化线段树实际上是维护的一个前缀和,而前缀和不一定要出现在一个线性表上。

比如说我们从一棵树的根节点进行DFS,得到根节点到各节点的距离dist[x]——这是一个根-x路径上点与根节点距离的前缀和。

利用这个前缀和,我们可以解决一些树上任意路径的问题,比如在线询问[a,b]点对的距离——答案自然是dist[a]+dist[b]-2*dist[lca(a,b)]。

同理,我们可以利用可持久化线段树来解决树上任意路径的问题。

DFS遍历整棵树,然后在每个节点上建立一棵线段树,某一棵线段树的“前一版本”是位于该节点父亲节点fa的线段树。

利用与之前类似的方法插入点权(排序离散)。那么对于询问[a,b],答案就是root[a]+root[b]-root[lca(a,b)]-root[fa[lca(a,b)]]上的第k大。

代码

#include<bits/stdc++.h>#define M(a,b) memset(a,b,sizeof(a))typedef long long LL;const int MAXN=100007;using namespace std;int cnt,root[MAXN],a[MAXN];struct Node{int l, r, sum;}T[MAXN*40];void update(int l,int r,int &x,int y,int pos,int c){    x=++cnt;    T[x]=T[y];T[x].sum+=c;    if(l==r) return;    int mid=(l+r)>>1;    if(pos<=mid) update(l,mid,T[x].l,T[y].l,pos,c);    else update(mid+1,r,T[x].r,T[y].r,pos,c);}int find_Kth(int l,int r,int rx,int ry,int rlca,int rflca,int k){    if(l>=r) return l;    int mid=(l+r)>>1;    int sum=T[T[rx].l].sum+T[T[ry].l].sum-T[T[rlca].l].sum-T[T[rflca].l].sum;    if(sum>=k) return find_Kth(l,mid,T[rx].l,T[ry].l,T[rlca].l,T[rflca].l,k);    else return find_Kth(mid+1,r,T[rx].r,T[ry].r,T[rlca].r,T[rflca].r,k-sum);}int val[MAXN];int s[MAXN];int hs[MAXN];struct Edge{    int to,ne;}e[MAXN<<1];int head[MAXN],ecnt;void addedge(int from,int to){    ecnt++;e[ecnt].to=to,e[ecnt].ne=head[from];head[from]=ecnt;    ecnt++;e[ecnt].to=from,e[ecnt].ne=head[to];head[to]=ecnt;}int parent[MAXN][50] ,depth[MAXN];void dfs(int n,int u,int la,int d){    parent[u][0]=la;    depth[u]=d;    update(1,n,root[u],root[la],hs[u],1);    for(int i=head[u];~i;i=e[i].ne)    {        if(e[i].to!=la)        {            dfs(n,e[i].to,u,d+1);        }    }}void init_lca(int n,int sz){    memset(depth, -1, sizeof depth);    for(int i=1; i<=n; i++)        if(depth[i]<0)            dfs(sz,i, 1, 0);    for(int k=0; k+1<30; k++)    {        for(int i=1; i<=n; i++)        {            if(parent[i][k]<0) parent[i][k+1] = -1;            else            {                parent[i][k+1] = parent[parent[i][k]][k];            }        }    }}LL lca(int u, int v){    if(depth[v] > depth[u]) swap(u, v);    int dis=depth[u]-depth[v],k=0;    while(dis)    {        if(dis&1)            u=parent[u][k];        dis>>=1;        k++;    }    k=0;    while (u!=v)    {        if ( parent[u][k]!= parent[v][k] || (parent[u][k]== parent[v][k] && k ==0) )        {            u=parent[u][k];            v=parent[v][k];            k++;        }        else k--;    }    return u;}int main(){    int n,m;scanf("%d%d",&n,&m);    for(int i=1;i<=n;i++)        scanf("%d",&val[i]),s[i]=val[i];    sort(s+1,s+n+1);    int sz=unique(s+1,s+n+1)-s-1;    for(int i=1;i<=n;i++)        hs[i]=lower_bound(s+1,s+sz+1,val[i])-s;    M(head,-1);ecnt=0;    for(int i=1;i<n;i++)    {        int a,b;scanf("%d%d",&a,&b);        addedge(a,b);    }    init_lca(n,sz);    parent[1][0]=0;    while(m--)    {        int a,b,k;scanf("%d%d%d",&a,&b,&k);        int l=lca(a,b);        //printf("%d %d %d %d\n",a,b,l,parent[l][0]);        //printf("%d %d\n",depth[a],depth[b]);        int res=find_Kth(1,sz,root[a],root[b],root[l],root[parent[l][0]],k);        printf("%d\n",s[res]);    }}
原创粉丝点击