spoj COT 可持久化数据结构 (LCA模版)

来源:互联网 发布:怎么做淘宝客漏洞赚钱 编辑:程序博客网 时间:2024/05/19 19:56

查询树链第K大 。                


每个版本的线段树维护的是 从这个节点到 根的 树链的版本, 由于树链第K大,在统计比X 小的数个数时 是可以 进行加减法运算的,所以  就可以用可持久化数据结构。

维护个数时 , sum = f(a) + f(b) - f(c) -f(d)    : c 为 a,b 的最近公共祖先, d 为 c 的父亲节点。这样就是 四个版本运算。

同时:二分可以直接在树上跑,判断 左半区域的和 是否大于K,大于K 说明第K大的值 还在 左区间, 相反在右区间里查 第K -sum 大的数。

复杂度 O(nlgn) 如果直接二分区间 复杂度是O(nlgnlgn)。

倍增 LCA 算法:

const int K = 18;int d[maxn];int p[maxn][K];void dfs(int rt,int f){      d[rt]=d[f]+1;      p[rt][0]=f;      int pos = mp[num[rt]];    root[rt] = update(pos,1,n,1,root[f]);    for(int i=1;i<K;i++) p[rt][i] = p[p[rt][i-1]][i-1];    for(int i=head[rt];i!=-1;i= edge[i].next){     int son = edge[i].v;     if(son==f)continue;     dfs(son,rt);    }  }  int lca(int a,int b){    if(d[a]>d[b]) swap(a,b);    if(d[a]<d[b]){        int del = d[b]-d[a];        for(int i=0;i<K;i++) if(del &(1<<i)) b= p[b][i];    }    if(a!=b){        for(int i= K-1;i>=0;i--){            if(p[a][i]!= p[b][i]){                a = p[a][i],b = p[b][i];            }        }        a= p[a][0],b = p[b][0];    }    return a;}

代码:

#include <vector>#include <list>#include <map>#include <set>#include <deque>#include <stack>#include <cstring>#include <bitset>#include <algorithm>#include <functional>#include <numeric>#include <utility>#include <sstream>#include <iostream>#include <iomanip>#include <cstdio>#include <cmath>#include <cstdlib>#include <ctime>#include <assert.h>#include <queue>#define REP(i,n) for(int i=0;i<n;i++)#define TR(i,x) for(typeof(x.begin()) i=x.begin();i!=x.end();i++)#define ALLL(x) x.begin(),x.end()#define SORT(x) sort(ALLL(x))#define CLEAR(x) memset(x,0,sizeof(x))#define FILLL(x,c) memset(x,c,sizeof(x))using namespace std;const double eps = 1e-9;#define LL long long #define pb push_backconst int maxn  = 101000;const int K = 18;int n ,m ;int num[maxn];int d[maxn];int p[maxn][K];map<int,int>mp;map<int,int>::iterator it;int idx[maxn];int head[maxn];struct Edge{    int v;    int next;}edge[2*maxn];int tot;void init(){    memset(head,-1,sizeof(head));    CLEAR(d);    CLEAR(p);    tot = 0;}void add(int u,int v){    tot ++;    edge[tot].v= v;    edge[tot].next = head[u];    head[u] = tot;}struct Node{Node *l,*r;int sum;}nodes[maxn*40];Node *root[maxn];Node *null;int C;void inits(){C= 0;null = &nodes[C++];root[0] = null;null->l = null->r = null;null->sum = 0;}Node *update(int pos,int left ,int right,int val,Node *root){ Node *rt = &nodes[C++]; rt->l = root->l; rt->r = root->r; rt->sum = root->sum; if(left ==right){   rt->sum +=val;      return rt; } int mid =(left +right)/2; if(pos<=mid){ rt ->l =update(pos,left,mid,val,root->l); }else{ rt ->r = update(pos,mid+1,right,val,root->r); } rt->sum = rt->l->sum + rt->r->sum; return rt;}int query(int k,int left ,int right,Node *rt,Node *rt2,Node *rt3,Node *rt4){//cout << left << " lr "<<right<<endl; if(left ==right){ return left; } int mid = (left +right)/2;// cout <<rt->sum<<" "<< rt2->sum <<"   "<<rt3->sum<<" "<<rt4->sum<<endl; int sum = rt->l->sum + rt2->l->sum - rt3->l->sum - rt4->l->sum;// cout << sum <<" sum k " << k << " "<<mid << endl; if(sum>=k){   return query(k,left,mid,rt->l,rt2->l,rt3->l,rt4->l); }else{ return query(k-sum,mid+1,right,rt->r,rt2->r,rt3->r,rt4->r); }}int get(int a,int b,int c,int d,int  k){return query(k,1,n,root[a],root[b],root[c],root[d]);}void dfs(int rt,int f){      d[rt]=d[f]+1;      p[rt][0]=f;      int pos = mp[num[rt]];    root[rt] = update(pos,1,n,1,root[f]);    for(int i=1;i<K;i++) p[rt][i] = p[p[rt][i-1]][i-1];    for(int i=head[rt];i!=-1;i= edge[i].next){     int son = edge[i].v;     if(son==f)continue;     dfs(son,rt);    }  }  int lca(int a,int b){if(d[a]>d[b]) swap(a,b);if(d[a]<d[b]){int del = d[b]-d[a];for(int i=0;i<K;i++) if(del &(1<<i)) b= p[b][i];}if(a!=b){for(int i= K-1;i>=0;i--){if(p[a][i]!= p[b][i]){a = p[a][i],b = p[b][i];}}a= p[a][0],b = p[b][0];}return a;}void solve(){    init();inits();for(int i =1;i<n;i++){int u,v;scanf("%d%d",&u,&v);add(u,v);add(v,u);}dfs(1,0);for(int i=1;i<=m;i++){int a,b,k;scanf("%d%d%d",&a,&b,&k);int t1 = lca(a,b);int t2 = p[t1][0];int ans = get(a,b,t1,t2,k);printf("%d\n",idx[ans]);}}int main(){    while(~scanf("%d%d",&n,&m)){    mp.clear();    for(int i=1;i<=n;i++){    scanf("%d",&num[i]);    mp[num[i]] = 1;    }    int tot2 = 0;        for(it = mp.begin();it!=mp.end();it++){        tot2 ++ ;        it->second = tot2;        //cout << tot2 << "  "<<it->first<<endl;        idx[tot2] = it->first;         }                   solve();    }    return 0;}


原创粉丝点击