LCA和RMQ

来源:互联网 发布:侠客行 知乎 编辑:程序博客网 时间:2024/06/06 14:21

这两天复习了 LCA 和 RMQ 相关的题目。两道题分别使用了离线的 Tarjan 算法，以及先把 LCA 转化为 RMQ 问题、再用 ST 算法求解的方法。

Tarjan 算法和 ST 算法都比较好写。LCA 转 RMQ 因为多了一步转换（先求欧拉序），最后的代码量会多出不少；但 ST 算法的主体似乎比 Tarjan 算法更直观，也不容易写错。

容易写错的地方主要还是下标和初始化，比如并查集用 f[k]==k 还是 f[k]==0 表示根，边从下标 0 还是 1 开始存储，以及对应的 head 数组用 0 还是 -1 表示链表结束。

POJ 1986

// POJ 1986 "Distance Queries": offline LCA with Tarjan's algorithm.
// d[x] is the distance from node 1, and
//   dist(x, y) = d[x] + d[y] - 2 * d[lca(x, y)].
// Edges and queries live in 1-based singly linked lists (index 0 ends a list).
// Query i (1-based) occupies slots 2i-1 and 2i, so (j + 1) / 2 maps a slot
// back to its query number.
// The DFS and the union-find lookup are iterative, so a path-shaped tree with
// tens of thousands of nodes cannot overflow the call stack.
#include <iostream>
#include <cstdio>
using namespace std;

struct Edge
{
    int x, y, l, next;
};

const int MAXN = 40100;
const int MAXQ = 10010;

Edge e[MAXN * 2], q[2 * MAXN];
int a[MAXN];          // head of each node's edge list (0 = empty)
int b[MAXN];          // head of each node's query list (0 = empty)
int ans[MAXQ];        // ans[i] = answer to query i (1-based)
bool answered[MAXQ];  // explicit flag instead of using ans[i] == 0 as "unset"
int n, m;             // number of nodes, number of edges
int qn;               // number of queries
int tot;
int qtot;
int ff[MAXN];         // union-find parent; ff[k] == k marks a set representative
int v[MAXN];          // 1 once the DFS has entered the node
int d[MAXN];          // distance from node 1
int stk[MAXN];        // explicit DFS stack of nodes
int it[MAXN];         // next edge to scan for the node at the same stack level

// Find with path compression, iteratively.
int fa(int k)
{
    int root = k;
    while (ff[root] != root)
        root = ff[root];
    while (ff[k] != root)
    {
        int up = ff[k];
        ff[k] = root;
        k = up;
    }
    return root;
}

// Appends the directed edge x -> y with length len to x's edge list.
void addedge(int x, int y, int len)
{
    e[tot].x = x;
    e[tot].y = y;
    e[tot].l = len;
    e[tot].next = a[x];
    a[x] = tot++;
}

// Appends the query (x, y) to x's query list.
void addq(int x, int y)
{
    q[qtot].x = x;
    q[qtot].y = y;
    q[qtot].next = b[x];
    b[x] = qtot++;
}

// Reads the tree and the queries. Lists start at index 1 so that 0 can end a list.
void init()
{
    tot = 1;
    qtot = 1;
    scanf("%d%d", &n, &m);
    char ch;
    int x, y, len;
    for (int i = 0; i < m; i++)
    {
        // The direction letter is part of the input format but is not needed.
        scanf("%d%d%d %c", &x, &y, &len, &ch);
        addedge(x, y, len);
        addedge(y, x, len);
    }
    scanf("%d", &qn);
    for (int i = 0; i < qn; i++)
    {
        scanf("%d%d", &x, &y);
        addq(x, y);
        addq(y, x);
    }
}

// Called once all children of k are finished and merged into k's set.
// For every query partner y that the DFS has already entered, fa(y) is then
// lca(k, y).
void answerQueries(int k)
{
    for (int j = b[k]; j != 0; j = q[j].next)
    {
        int y = q[j].y;
        int id = (j + 1) / 2;
        if (v[y] && !answered[id])
        {
            ans[id] = d[y] + d[k] - 2 * d[fa(y)];
            answered[id] = true;
        }
    }
}

// Offline Tarjan LCA over the tree rooted at root, using an explicit stack.
// Visiting order, distance computation and set merging match the usual
// recursive formulation.
void tarjan(int root)
{
    if (v[root]) return;
    int top = 0;
    stk[0] = root;
    it[0] = a[root];
    ff[root] = root;
    v[root] = 1;
    while (top >= 0)
    {
        int k = stk[top];
        int j = it[top];
        if (j != 0)
        {
            it[top] = e[j].next;
            int y = e[j].y;
            if (!v[y])
            {
                d[y] = d[k] + e[j].l;
                ff[y] = y;
                v[y] = 1;
                ++top;
                stk[top] = y;
                it[top] = a[y];
            }
        }
        else
        {
            answerQueries(k);
            --top;
            if (top >= 0)
                ff[k] = stk[top];   // merge k's set into its parent's set
        }
    }
}

// Prints the answers in input order, one per line.
void print()
{
    for (int i = 1; i <= qn; i++)
    {
        printf("%d\n", ans[i]);
    }
}

int main()
{
    init();
    tarjan(1);
    print();
    return 0;
}

POJ 2763

// POJ 2763 "Housewife Wind": path lengths on a weighted tree whose edge
// weights change online.
//  * LCA: Euler tour (ver / R / first) plus a sparse-table (ST)
//    range-minimum query on depth.
//  * Distances: d[x] is the root distance computed by the initial DFS.
//    Changing an edge's weight by delta shifts the root distance of every node
//    in the lower endpoint v's subtree. Those are exactly the nodes whose
//    first[] lies in v's Euler interval [first[v], fin[v]].
//    The shifts are kept in a Fenwick tree (range add, point query), so each
//    update and each lookup costs O(log n) instead of walking the whole
//    subtree on every update.
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;

struct Edge
{
    int u, v, w, next;
};

const int N = 110001;
const int M = 20;

Edge e[N * 2];
int head[N], dep[N], d[N], first[N], fin[N], ver[N * 2], R[N * 2];
int vis[N];
int stk[N], cur[N];   // explicit DFS stack: node and next edge to scan
int bit[N * 2];       // Fenwick tree of distance deltas, indexed by Euler position
int tot;              // edge count while reading, then Euler tour length
int n, q, s;
int dp[N * 2][M];     // dp[i][j] = tour position of min depth in [i, i + 2^j)

// Appends the directed edge x -> y with weight w. Edge k (1-based, in input
// order) ends up at indices 2k-2 and 2k-1, so e[j ^ 1] is the reverse of e[j].
void addedge(int x, int y, int w)
{
    e[tot].u = x;
    e[tot].v = y;
    e[tot].w = w;
    e[tot].next = head[x];
    head[x] = tot;
    tot++;
}

// Iterative DFS from root (depth 1, distance 0) that records the Euler tour.
// ver[i] is the node at tour position i and R[i] is its depth.
// first[u] and fin[u] are u's first and last tour positions, so u's subtree
// covers positions [first[u], fin[u]].
void dfs(int root)
{
    int top = 0;
    stk[0] = root;
    cur[0] = head[root];
    vis[root] = 1;
    dep[root] = 1;
    d[root] = 0;
    tot++;
    first[root] = tot; ver[tot] = root; R[tot] = dep[root];
    while (top >= 0)
    {
        int u = stk[top];
        int j = cur[top];
        if (j != -1)
        {
            cur[top] = e[j].next;
            int v = e[j].v;
            if (!vis[v])
            {
                vis[v] = 1;
                dep[v] = dep[u] + 1;
                d[v] = d[u] + e[j].w;
                tot++;
                first[v] = tot; ver[tot] = v; R[tot] = dep[v];
                ++top;
                stk[top] = v;
                cur[top] = head[v];
            }
        }
        else
        {
            fin[u] = tot;   // nothing for u is appended after this point
            --top;
            if (top >= 0)
            {
                int p = stk[top];
                tot++;
                ver[tot] = p; R[tot] = dep[p];   // tour returns to the parent
            }
        }
    }
}

// floor(log2(len)) for len >= 1.
int floorLog2(int len)
{
    int k = 0;
    while ((1 << (k + 1)) <= len)
        k++;
    return k;
}

// Builds the sparse table over tour positions l..r, keyed by depth R[].
void ST(int l, int r)
{
    for (int i = l; i <= r; i++)
        dp[i][0] = i;
    int top = floorLog2(r - l + 1);
    for (int j = 1; j <= top; j++)
    {
        for (int i = l; i + (1 << j) - 1 <= r; i++)
        {
            int a = dp[i][j - 1];
            int b = dp[i + (1 << (j - 1))][j - 1];
            dp[i][j] = (R[a] < R[b]) ? a : b;
        }
    }
}

// Tour position with minimum depth in [l, r], answered by two overlapping blocks.
int RMQ(int l, int r)
{
    int k = floorLog2(r - l + 1);
    int a = dp[l][k];
    int b = dp[r - (1 << k) + 1][k];
    return (R[a] < R[b]) ? a : b;
}

// LCA(u, v) = shallowest tour entry between the first visits of u and v.
int LCA(int u, int v)
{
    int l = first[u];
    int r = first[v];
    if (l > r) swap(l, r);
    return ver[RMQ(l, r)];
}

// Adds delta at tour position i (positions 1..tot).
void bitAdd(int i, int delta)
{
    for (; i <= tot; i += i & -i)
        bit[i] += delta;
}

// Prefix sum over positions 1..i, i.e. the accumulated distance change at i.
int bitSum(int i)
{
    int sum = 0;
    for (; i > 0; i -= i & -i)
        sum += bit[i];
    return sum;
}

// Current distance from the root to u, including all weight updates so far.
int dist(int u)
{
    return d[u] + bitSum(first[u]);
}

// Reads one tree, builds the Euler tour and the sparse table, and resets the
// Fenwick tree.
void init()
{
    int x, y, w;
    memset(head, -1, sizeof(head));
    memset(vis, 0, sizeof(vis));
    memset(d, 0, sizeof(d));
    memset(bit, 0, sizeof(bit));
    tot = 0;
    for (int i = 0; i < n - 1; i++)
    {
        scanf("%d %d %d", &x, &y, &w);
        addedge(x, y, w);
        addedge(y, x, w);
    }
    tot = 0;   // from here on tot counts Euler tour positions
    dfs(1);
    ST(1, tot);
}

// Length of the path u - v, where w = LCA(u, v).
int caldis(int u, int v, int w)
{
    return dist(u) + dist(v) - 2 * dist(w);
}

// Shifts the root distance of every node in u's subtree by delta.
void travel(int u, int delta)
{
    bitAdd(first[u], delta);
    bitAdd(fin[u] + 1, -delta);
}

// Processes q operations:
//   "0 u"    prints the distance from s to u, then moves s to u;
//   "1 k w"  sets the weight of the k-th input edge to w.
void solve()
{
    int op, u;
    for (int i = 0; i < q; i++)
    {
        scanf("%d", &op);
        if (op == 0)
        {
            scanf("%d", &u);
            int w = LCA(s, u);
            printf("%d\n", caldis(s, u, w));
            s = u;
        }
        else if (op == 1)
        {
            int k, nw;
            scanf("%d%d", &k, &nw);
            int j = k * 2 - 1;   // one of the two slots holding edge k
            int delta = nw - e[j].w;
            e[j].w = e[j ^ 1].w = nw;
            int a = e[j].u;
            int b = e[j].v;
            if (dep[a] > dep[b]) swap(a, b);
            travel(b, delta);    // b is the child endpoint
        }
    }
}

int main()
{
    // == 3 rather than != EOF: a malformed header must not loop forever.
    while (scanf("%d%d%d", &n, &q, &s) == 3)
    {
        init();
        solve();
    }
    return 0;
}


0 0
原创粉丝点击