qscoj 喵哈哈村与哗啦啦村的大战(四)(树形DP求非严格路径数量)

来源:互联网 发布:java在线 编辑:程序博客网 时间:2024/04/30 03:23

喵哈哈村与哗啦啦村的大战(四)

发布时间: 2017年3月28日 20:03   最后更新: 2017年3月28日 20:06   时间限制: 1000ms   内存限制: 128M

喵哈哈村因为和哗啦啦村争夺稀有的水晶资源,展开了激烈的战斗!

喵哈哈村的部落可以视为由n个节点组成,其中有n-1条边连接这n个节点,使得任意两个节点都会有一条路径相连接。每个节点上都有一个点权a[i]。

如果说存在一条路径上的权值满足非严格的单调递增或者非严格的单调递减的话,就说这条路径是一条好路径。

现在问题来了,给你一棵树,问你这棵树上有多少条路径是好路径。

本题包含若干组测试数据。
第一行一个n,表示有n个节点。
第二行n个整数a[i],表示节点的权值。
接下来n-1行每行两个整数x,y。表示x,y节点之间有一条边相互连接。

满足:1<=n<=100000 1<=a[i]<=1e9 1<=x,y<=n

输出好路径的个数。

 复制
41 7 1 91 31 4 2 1
5
 复制
61 1 2 2 3 31 22 33 44 55 6
15

第一次求出子节点的递增递减相等路径数量,第二次计算子路劲叠加而成的路径数量,中间过程注意去重


#include <iostream>#include <cstdio>#include <cstring>#include <cmath>#include <algorithm>#include <vector>#include <map>using namespace std;const int N = 110000;int a[N];vector<int>p[N];int d[N][4], ans;void dfs(int u,int pre){    int up=0, down=0, mid=0;//比子节点大,比子节点小,和子节点相等;    for(int i=0;i<p[u].size();i++)    {        int v=p[u][i];        if(v==pre) continue;        dfs(v,u);        if(a[u]>=a[v]) d[u][0]+=(d[v][0]+1),up+=(d[v][0]+1);        if(a[u]<=a[v]) d[u][1]+=(d[v][1]+1),down+=(d[v][1]+1);        if(a[u]==a[v]) d[u][2]+=(d[v][2]+1),mid+=(d[v][2]+1);    }    ans=ans+up+down-mid;    for(int i=0;i<p[u].size();i++)    {        int v=p[u][i];        if(v==pre) continue;        if(a[u]>a[v])        {            up-=(d[v][0]+1);            ans+=(d[v][0]+1)*down;        }        else if(a[u]<a[v])        {            down-=(d[v][1]+1);            ans+=(d[v][1]+1)*up;        }        else        {            up-=(d[v][0]+1);            down-=(d[v][1]+1);            mid-=(d[v][2]+1);            ans+=(d[v][1]+1)*up;            ans+=(d[v][0]+1)*down;            ans-=(d[v][2]+1)*mid;        }    }    return ;}int main(){    int n;    while(scanf("%d", &n)!=EOF)    {        for(int i=1;i<=n;i++)        {            scanf("%d", &a[i]);            p[i].clear();        }        for(int i=1;i<n;i++)        {            int x, y;            scanf("%d %d", &x, &y);            p[x].push_back(y),p[y].push_back(x);        }        memset(d,0,sizeof(d));        ans=0;        dfs(1,-1);        printf("%d\n",ans);    }    return 0;}

0 0
原创粉丝点击