poj1741 Tree 点分治

来源:互联网 发布:软件需求分为哪几类 编辑:程序博客网 时间:2024/05/16 20:31

    第一次写点分治啊,果然还是黄学长的代码框架好。。。

    这绝对是点分治的最经典的题目了。大概写一下自己的理解吧。

    首先,对于一棵树,求出其重心并作为根节点。然后链就可以分为两类:经过根节点的和不经过根节点的。对于不经过根节点的,在子树中递归调用即可。由于每次根节点都取树的重心,所以递归一次点的个数至少除以2,递归层数不超过logN层。另一方面,每一层都可以大致看成有O(N)级别个点,对于这些点的操作时间是O(NlogN)级别的。因此总的时间复杂度O(Nlog^2N)。

    对于经过根节点的链,我们进行如下操作。首先将根节点深度定为0,然后求出子树中每个节点的深度dep[i]。然后将dep[]排序,如果dep[x]+dep[y]<=k,显然就是满足条件的链。排序后O(N)扫一遍即可。

    但是上述操作仍有许多值得考究的地方。首先,仅仅只有dep[x]+dep[y]<=k是不能充分说明经过根节点的,还有可能x和y是在用一个子树中的。但如果直接强行求出经过根节点即x和y在不同子树同时满足dep[x]+dep[y]<=k的话是很满发的。所以可以先求出满足dep[x]+dep[y]<=k的链的个数sum,然后对于每一刻子树,再求一遍满足dep[x]+dep[y]<=k的个数tmp,在sum中减去所有的tmp即可。这样算法就被正确地实现了。

    可是O(Nlog^2N)的时间复杂度好像不是很好看啊。实际上,多出来的一个logN主要是耗时在sort上。在求平面最近点对时我们就可以用在子节点(差不多这个意思)中排序然后再归并的方法。这里对于点分治的dep数组也是一样的,可以在子树中排完序然后再归并。这样连求dep的数组都不需要了。不过这个归并实现比较麻烦,因为有多个子树,可能要用堆维护,再考虑到常数的差距可能反而更慢,以及实现的复杂度。我这么弱暂时不考虑。

下面给出AC代码:

#include<iostream>#include<cstdio>#include<cstring>#include<algorithm>#define inf 1000000000#define N 300005using namespace std;int n,m,cnt,tot,rt,sum,ans,fst[N],pnt[N],len[N],nxt[N],c[N],d[N],sz[N],f[N];bool vis[N];int read(){int x=0; char ch=getchar();while (ch<'0' || ch>'9') ch=getchar();while (ch>='0' && ch<='9'){ x=x*10+ch-'0'; ch=getchar(); }return x;}void add(int aa,int bb,int cc){pnt[++tot]=bb; nxt[tot]=fst[aa]; len[tot]=cc; fst[aa]=tot;}void dfs(int x,int last){sz[x]=f[x]=1; int p;for (p=fst[x]; p; p=nxt[p]){int q=pnt[p]; if (q==last || vis[q]) continue;dfs(q,x); sz[x]+=sz[q];f[x]=max(f[x],sz[q]);}f[x]=max(f[x],sum-sz[x]); if (f[x]<f[rt]) rt=x;}void getdep(int x,int last){c[++cnt]=d[x]; int p;for (p=fst[x]; p; p=nxt[p]){int q=pnt[p]; if (q==last || vis[q]) continue;d[q]=d[x]+len[p]; getdep(q,x);}}int work(int x,int dep){d[x]=dep; cnt=0;getdep(x,0); sort(c+1,c+cnt+1);int tmp=0,l,r=cnt;for (l=1; l<r; l++){while (l<r && c[l]+c[r]>m) r--;tmp+=r-l;}return tmp;}void solve(int x){ans+=work(x,0);vis[x]=1; int p;for (p=fst[x]; p; p=nxt[p]){int q=pnt[p]; if (vis[q]) continue;ans-=work(q,len[p]);rt=0; sum=sz[q]; dfs(q,x); solve(rt);}}int main(){while (n=read()){m=read(); int i;tot=ans=0; memset(fst,0,sizeof(fst));memset(vis,0,sizeof(vis));for (i=1; i<n; i++){int x=read(),y=read(),z=read();add(x,y,z); add(y,x,z);}sum=n; f[rt=0]=inf; dfs(1,0);solve(rt); printf("%d\n",ans);}return 0;}


2015.11.15

by lych

0 0