HDU 5401(计数dp)

来源:互联网 发布:淘宝物流宝 编辑:程序博客网 时间:2024/05/21 17:53

题意描述:

原先设定第0颗树只有一个节点0,现在要生成第i颗数,选  ai, bi, (ai < i, bi< i) 中两个节点(ci , di)相连接,构成一个新的树,且ai中节点的编号不变, bi中的所有节点编号都要在原来的基础上+ai树的大小,这样保证编号连续,对于每颗树T而言 ,,F(T)=n1i=0n1j=i+1d(vi,vj)(d(vi,vj)
即任意两点之间距离总和。


这是多校题解:

考虑爆搜,树iii生成后,两两点对路径分成两部分,一部分不经过中间的边,那么就是aia_iaibib_ibi的答案,如果经过中间的边,首先计算中间这条边出现的次数,也就是ai,bia_i,b_iai,bi子树大小的乘积。对于aia_iai,对答案的贡献为所有点到cic_ici的距离和乘上bib_ibi的子树大小。bib_ibi同理。

那么转化为计算在树iii中,所有点到某个点jjj的距离和。假设jjjaia_iai内,那么就转化成了aia_iaijjj这个点的距离总和加上bib_ibi内所有点到did_idi的总和加上did_idijjj的距离乘上子树bib_ibi的大小,称作第一类询问。

这样就化成了在树iii中两个点jjjkkk的距离,如果在同一棵子树中,可以递归下去,否则假设jjjaia_iaikkkbib_ibi中,那么距离为jjjcic_ici的距离加上kkkdid_idi的距离加上lil_ili,称作第二类询问。

然后对两类询问全都记忆化搜索即可。

接着考虑计算一下复杂度。

对于第二类询问,可以考虑询问的过程类似于线段树,只会有两个分支,中间的部分已经记忆化下来,不用再搜,时间复杂度O(m)O(m)O(m)

我们分析一下复杂度,首先对于第一类询问,在bib_ibi中到did_idi的点距离和已经由前面的询问得到,那么就转化为一个第一类询问和一个第二类询问,最多会被转化成O(m)O(m)O(m)个第二类询问。

所以每个询问复杂度是O(m2)O(m^2)O(m2),总复杂度O(m3)O(m^3)O(m3)

复杂度计算思考:

对于第一类询问,只会例如sum(a[i], c[i])递归计算时,每个会分成两个第一类询问和一个第二类询问,而两个第一类询问必有一个已经被计算过(可以手动分解看看前后关系)

,所以每次分解成一个第一类和一个第二类,复杂度为m*m。

dis计算也同理。

被记忆的也不会很多,每次最多多记录m*m个。

#include <iostream>#include <cstdlib>#include <cstdio>#include <algorithm>#include <cstring>#include <map>#include <set>#include <vector>#include <cctype>#include <cmath>#include <queue>#define ls rt<<1#define rs rt<<1|1#define lson l,m,rt<<1#define rson m+1,r,rt<<1|1#define mem(a,n) memset(a,n,sizeof(a))#define rep(i,n) for(int i=0;i<(int)n;i++)#define rep1(i,x,y) for(int i=x;i<=(int)y;i++)using namespace std;#pragma comment(linker, "/STACK:102400000,102400000")typedef pair<int,int> pii;typedef long long ll;const int inf = 0x3f3f3f3f;const ll oo = 1e12;typedef pair<ll,ll> pll;const int N = 65;const int mod = 1e9+7;map<pll,ll> M[N];map<ll,ll> M2[N];int n;ll a[N],b[N],c[N],d[N],siz[N],ms[N],l[N],ans[N];void init(){  for(int i = 0; i < N;i++)    M[i].clear(),M2[i].clear();  M[0][pll(0,0)]=0;  M2[0][0] = 0;  siz[0] = ms[0] = 1;}ll dis(int i,ll j,ll k){   if(j > k) swap(j,k);   if(M[i].count(pll(j,k))) return M[i][pll(j,k)];   if(k < siz[a[i]]) return M[i][pll(j,k)] = dis(a[i],j,k);   if(j >= siz[a[i]]) return M[i][pll(j,k)] = dis(b[i],j-siz[a[i]],k-siz[a[i]]);   return  M[i][pll(j,k)] = (dis(a[i],j,c[i])+l[i]+dis(b[i],d[i],k-siz[a[i]]))%mod;}ll sum(int i,ll j){   if(M2[i].count(j)) return M2[i][j];   if(j<siz[a[i]]) return  M2[i][j]=(sum(a[i],j)+(l[i]+dis(a[i],j,c[i]))*ms[b[i]]+sum(b[i],d[i]))%mod;   if(j>=siz[a[i]]) return M2[i][j]=(sum(a[i],c[i])+(l[i]+dis(b[i],j-siz[a[i]],d[i]))*ms[a[i]]+sum(b[i],j-siz[a[i]]))%mod;}ll cal(int i){   siz[i] = siz[a[i]]+siz[b[i]];   ms[i] = siz[i]%mod;   ans[i] = ans[a[i]]+ans[b[i]]+ms[a[i]]*ms[b[i]]%mod*l[i]%mod+ms[b[i]]*sum(a[i],c[i])+ms[a[i]]*sum(b[i],d[i]);   ans[i]=ans[i]%mod;   return ans[i];}int main(){   while(scanf("%d",&n)==1){      init();      for(int i=1;i<=n;i++){         scanf("%I64d %I64d %I64d %I64d %I64d",&a[i],&b[i],&c[i],&d[i],&l[i]);         printf("%I64d\n",cal(i));      }   }   return 0;}





0 0