2016计蒜客初赛第一场 青云的机房组网方案(困难):图论+虚树+容斥

来源:互联网 发布:大数据赚钱 编辑:程序博客网 时间:2024/04/28 02:36

题目链接:http://nanti.jisuanke.com/t/11135


题目概述:给一棵n个节点的树,每个节点有一个初始值ai。1<=n<=100000,1<=ai<=100000。求树上任意两个值互质点距离的和。


思路概述:

  1. 枚举 因数x,x是每种质因子至多有一个的数,记录一下x有几种质因子,方便之后容斥。
  2. 把所有x的倍数的权值的点找出来,预处理下可以做到找出来的点的dfs序是从小到大的,预处理也可以使得每次找x的倍数的权值的点不必线性扫一遍。
  3. 然后对这些点 O(n) 建虚树,具体操作是相邻两个点加进去 lca,用一个栈维护下父亲链即可。[bzoj3572]是一道典型的虚树的题目。
  4. 构建好树后在树上 dfs 两次可以求出所有x的倍数的权值的点对之间的距离和,就是第一遍dfs记录以节点u为根的子树中,有多少个x倍数的点(可能有一些是虚树添加进来的lca点),第二遍dfs其实是枚举每条边,计算(u,v)这条边的总价值,就是它出现的次数乘以它的权值;它出现的次数就是它子树中x倍数的点的个数,乘以不在它子树中x倍数的点的个数。
  5. 最后容斥下就可以求出答案。

由于所有步骤均是线性的,而所有虚树加起来的总点数也是线性乘上一个常数的,所以复杂度为 O(nK),K<=128。

#include <bits/stdc++.h>using namespace std;typedef long long ll;typedef pair<int,int> pii;const int maxp = 316;const int maxn = 100100;const int maxl = 17;bool isprime[maxp+5];int prime[maxp+5],pnum = 0;int a[maxn];int anc[maxn][maxl+1],dep[maxn],cur[maxn],Stack[maxn<<1],dfn[maxn];int n,dfs_seq=0;vector<int> arr[maxn];vector<int> factor[maxn];int rongchi[maxn];struct Node{    int head[maxn],nex[maxn<<1],point[maxn<<1],weight[maxn<<1],siz[maxn];    bool selected[maxn];    map<int,int> label;    int ne,total,nl;    ll ans;    void init(int tot)    {        label.clear();        total = tot;        ne = 0;        ans = 0;        nl = 0;    }    void addedge(int u,int v, int w)    {        if(label.count(u)) u = label[u];        else        {            label[u] = ++nl;            u = nl;            head[u] = -1;            selected[u] = false;        }        if(label.count(v)) v = label[v];        else        {            label[v] = ++nl;            v = nl;            head[v] = -1;            selected[v] = false;        }        point[ne] = v;        nex[ne] = head[u];        weight[ne] = w;        head[u] = ne++;        point[ne] = u;        nex[ne] = head[v];        weight[ne] = w;        head[v] = ne++;    }    void set(int x)    {        if(!label.count(x))        {            label[x] = ++nl;            x = nl;        }        else x = label[x];        head[x] = -1;        selected[x] = true;    }    void dfs1(int root,int fa)    {        siz[root] = selected[root]?1:0;        for(int i=head[root]; i!=-1; i=nex[i])        {            if(point[i] == fa) continue;            dfs1(point[i],root);            siz[root] += siz[point[i]];        }    }    void dfs2(int root,int fa)    {        for(int i=head[root]; i!=-1; i=nex[i])        {            if(point[i]==fa) continue;            ans += 1LL*weight[i]*siz[point[i]]*(total-siz[point[i]]);            dfs2(point[i],root);        }    }} g1,g2;void dfs(int root){    int top = 0;    dep[root] = 1;    for(int i=0; i<=maxl; i++)    {        anc[root][i] = root;    }    Stack[++top] = root;    memcpy(cur,g1.head,sizeof(cur));    while(top)    {        int x = Stack[top];        if(x != root)        {            for(int i=1; i<=maxl; i++)            {                int y = anc[x][i-1];                anc[x][i] = anc[y][i-1];            }        }        for(int &i = cur[x]; i!= -1; i=g1.nex[i])        {            int y = g1.point[i];            if(y != anc[x][0])            {                dep[y] = dep[x]+1;                anc[y][0] = x;                Stack[++top] = y;            }        }        while(top && cur[Stack[top]] == -1) top--;    }}void swim(int &x,int H){    for(int i=0; H>0; i++)    {        if(H&1) x = anc[x][i];        H >>= 1;    }}int lca(int x,int y){    int i;    if(dep[x] > dep[y]) swap(x,y);    swim(y,dep[y]-dep[x]);    if(x == y) return x;    while(true)    {        for(i=0; anc[x][i] != anc[y][i]; i++);        if(i == 0) return anc[x][0];        x = anc[x][i-1];        y = anc[y][i-1];    }    return -1;}void getfactor(int pointer,int cur,int acc,const vector<pii> &tmp,int num){    if(pointer >= tmp.size()){        if(cur > 1) factor[num].push_back(cur);        return;    }    if(acc == 0) getfactor(pointer,cur*tmp[pointer].first,acc+1,tmp,num);    getfactor(pointer+1,cur,0,tmp,num);}void init(int ma){    memset(isprime,true,sizeof(isprime));    memset(rongchi,0,sizeof(rongchi));    int maxpp = (int)sqrt(ma+1.0);    for(int i=2; i<=maxpp; i++)    {        if(isprime[i]) prime[pnum++] = i;        for(int j=0; j<pnum; j++)        {            if(i*prime[j] > maxpp) break;            isprime[i*prime[j]] = false;            if(i%prime[j] == 0) break;        }    }    vector<pii> tmp;    for(int i=2; i<=ma; i++)    {        tmp.clear();        int ti = i;        for(int j=0; j<pnum && ti > 1; j++)        {            int cnt = 0;            while(ti%prime[j] == 0)            {                ti /= prime[j];                cnt++;                rongchi[i]++;            }            if(cnt) tmp.push_back(make_pair(prime[j],cnt));        }        if(ti > 1) tmp.push_back(make_pair(ti,1)),rongchi[i]++;        getfactor(0,1,0,tmp,i);    }}void dfs_dfn(int index){    dfn[index] = dfs_seq++;    int fsiz = factor[a[index]].size();    for(int i=0; i<fsiz; i++)    {        arr[factor[a[index]][i]].push_back(index);    }    for(int i=g1.head[index]; i!=-1; i=g1.nex[i])    {        if(g1.point[i] == anc[index][0]) continue;        dfs_dfn(g1.point[i]);    }}void work(int index){    int top=0;    int cnt = arr[index].size();    g2.init(cnt);    for (int i=0; i<cnt; i++)    {        g2.set(arr[index][i]);        if (!top)        {            Stack[++top]=arr[index][i];            continue;        }        int u=lca(Stack[top],arr[index][i]);        while (dfn[u]<dfn[Stack[top]])        {            if (dfn[u]>=dfn[Stack[top-1]])            {                g2.addedge(u,Stack[top],dep[Stack[top]]-dep[u]);                if (Stack[--top]!=u) Stack[++top]=u;                break;            }            g2.addedge(Stack[top-1],Stack[top],dep[Stack[top]]-dep[Stack[top-1]]),top--;        }        Stack[++top]=arr[index][i];    }    while (top>1) g2.addedge(Stack[top-1],Stack[top],dep[Stack[top]]-dep[Stack[top-1]]),top--;    g2.dfs1(1,0);    g2.dfs2(1,0);}int main(){    int u,v,ma = -1;    scanf("%d",&n);    g1.init(n);    for(int i=1; i<=n; i++) scanf("%d",a+i), g1.set(i), ma = max(ma,a[i]);    for(int i=0; i<n-1; i++)    {        scanf("%d%d",&u,&v);        g1.addedge(u,v,1);    }    init(ma);    dfs(1);    dfs_dfn(1);    g1.dfs1(1,0);    g1.dfs2(1,0);    ll ans = g1.ans;    for(int i=2; i<=ma; i++)    {        if(arr[i].size())        {            work(i);            if(rongchi[i]&1) ans -= g2.ans;            else ans += g2.ans;        }    }    cout<<ans<<endl;    return 0;}


0 0