HDU 5293 树形DP:在一棵树上选取任意多条互不共点的树链,使权值之和最大

来源:互联网 发布:测辐射的软件 编辑:程序博客网 时间:2024/04/29 03:29
// Tree-chain selection (the page title names it HDU 5293).
//
// Input : a test-case count, then per test case
//           n m
//           n-1 lines "x y"   -- undirected tree edges, nodes numbered 1..n
//           m   lines "u v w" -- a chain (the tree path u..v) with weight w
// Output: per test case, the maximum total weight of a set of chains in which
//         no two chains share a vertex.
//
// Method: root the tree at node 1 and run a post-order DP.
//   d[x]   = best total using only chains that lie inside the subtree of x
//   sum[x] = sum of d[c] over the children c of x
// A chain whose LCA is x may be taken at x; taking it yields
//   w + sum[x] + SUM over path nodes y != x of (sum[y] - d[y]).
// That path sum is read in O(log n) from two Fenwick trees indexed by DFS
// timestamps (see solve()).
//
// The edges are assumed to form a tree; this is not verified.

// MSVC-only request for a large stack, because dfs() and solve() recurse once
// per tree level. Other compilers ignore this pragma, so on those toolchains a
// very deep tree depends on the platform's stack limit.
#pragma comment(linker, "/STACK:1024000000,1024000000")

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <vector>
#include <map>
#include <set>
#include <time.h>
#include <string>
#include <climits>

typedef long long LL;

// Capacity of every per-node array. The Fenwick trees are indexed by DFS
// timestamps in 1..2n, so the largest accepted n is MAXN / 2 - 1.
const int MAXN = 200050;
// Number of binary-lifting levels (2^20 exceeds any accepted depth).
const int LOG = 20;

// Adjacency lists stored as linked edge records; -1 terminates a list.
int head[MAXN], nx[MAXN * 2], to[MAXN * 2];
int ecnt;

// Entry / exit timestamps of each node in the DFS, both taken from `cid`.
int tin[MAXN], tout[MAXN];
int cid = 0;

// fa[x][k] is the 2^k-th ancestor of x; the root is its own ancestor.
int fa[MAXN][LOG];
int dep[MAXN];

// One weighted chain together with the LCA of its endpoints.
struct node {
    int u, v, lca, val;
} p[MAXN];

// e[x] lists the indices (into p) of the chains whose LCA is x.
std::vector<int> e[MAXN];

int n, m;

// DP values and the two Fenwick trees (cs accumulates sum[], cd accumulates
// d[]). All 64-bit: a Fenwick prefix adds one term per ancestor and can exceed
// INT_MAX even when the final answer fits in 32 bits.
LL sum[MAXN], d[MAXN];
LL cs[MAXN], cd[MAXN];

// Adds the undirected edge x-y to the adjacency lists.
void addedge(int x, int y) {
    to[ecnt] = y;
    nx[ecnt] = head[x];
    head[x] = ecnt++;

    to[ecnt] = x;
    nx[ecnt] = head[y];
    head[y] = ecnt++;
}

// Depth-first traversal from `cur` that assigns timestamps and depths and
// fills the ancestor table. fa[cur][0] must already be set by the caller.
void dfs(int cur, int parent, int depth) {
    tin[cur] = ++cid;
    dep[cur] = depth;
    for (int k = 1; k < LOG; ++k)
        fa[cur][k] = fa[fa[cur][k - 1]][k - 1];

    for (int i = head[cur]; i != -1; i = nx[i]) {
        const int child = to[i];
        if (child == parent)
            continue;
        fa[child][0] = cur;
        dfs(child, cur, depth + 1);
    }
    tout[cur] = ++cid;
}

// Returns the lowest common ancestor of x and y by binary lifting.
int getlca(int x, int y) {
    if (dep[x] < dep[y])
        std::swap(x, y);

    // Raise the deeper node to the depth of the other one.
    for (int k = LOG - 1; k >= 0; --k)
        if (dep[fa[x][k]] >= dep[y])
            x = fa[x][k];
    if (x == y)
        return x;

    // Raise both while their ancestors differ; they end just below the LCA.
    for (int k = LOG - 1; k >= 0; --k)
        if (fa[x][k] != fa[y][k]) {
            x = fa[x][k];
            y = fa[y][k];
        }
    return fa[x][0];
}

// Lowest set bit of x.
int lowbit(int x) {
    return x & (-x);
}

// Fenwick tree: adds `delta` at position `pos` (valid positions are 1..2n).
void add(int pos, LL delta, LL *tree) {
    for (; pos <= 2 * n; pos += lowbit(pos))
        tree[pos] += delta;
}

// Fenwick tree: returns the sum of positions 1..pos.
LL getsum(int pos, const LL *tree) {
    LL total = 0;
    for (; pos > 0; pos -= lowbit(pos))
        total += tree[pos];
    return total;
}

// Post-order DP over the subtree of `cur`; fills sum[cur] and d[cur].
void solve(int cur, int parent) {
    sum[cur] = 0;
    for (int i = head[cur]; i != -1; i = nx[i]) {
        const int child = to[i];
        if (child == parent)
            continue;
        solve(child, cur);
        sum[cur] += d[child];
    }

    // Option 1: no chain passes through cur.
    d[cur] = sum[cur];

    // Option 2: take one chain whose LCA is cur. A finished node y contributes
    // +value at tin[y] and -value at tout[y], so a prefix query at tin[u]
    // totals exactly the finished nodes whose timestamp interval contains
    // tin[u]: the nodes on the path from u up to, but not including, cur.
    for (std::size_t k = 0; k < e[cur].size(); ++k) {
        const node &chain = p[e[cur][k]];
        const LL along_path = getsum(tin[chain.u], cs) + getsum(tin[chain.v], cs)
                            - getsum(tin[chain.u], cd) - getsum(tin[chain.v], cd);
        d[cur] = std::max(d[cur], along_path + sum[cur] + chain.val);
    }

    // Publish cur's values so that chains handled higher up can see them.
    add(tin[cur], d[cur], cd);
    add(tout[cur], -d[cur], cd);
    add(tin[cur], sum[cur], cs);
    add(tout[cur], -sum[cur], cs);
}

// True when x is a valid node id for the current test case.
static bool valid_node(int x) {
    return x >= 1 && x <= n;
}

// Reads all test cases from stdin and prints one answer per line.
// Returns 1 as soon as the input is truncated, malformed or out of range.
int main() {
    int tt;
    if (scanf("%d", &tt) != 1)
        return 1;

    while (tt-- > 0) {
        if (scanf("%d%d", &n, &m) != 2)
            return 1;
        // Timestamps reach 2n, which must stay inside cs[] / cd[].
        if (n < 1 || n >= MAXN / 2 || m < 0 || m > MAXN)
            return 1;

        cid = 0;
        ecnt = 0;
        for (int i = 0; i <= 2 * n; ++i)
            cd[i] = cs[i] = 0;
        for (int i = 0; i <= n; ++i) {
            head[i] = -1;
            e[i].clear();
        }

        for (int i = 1; i < n; ++i) {
            int x, y;
            if (scanf("%d%d", &x, &y) != 2)
                return 1;
            if (!valid_node(x) || !valid_node(y))
                return 1;
            addedge(x, y);
        }

        // The root is its own ancestor, so lifting never leaves the tree.
        fa[1][0] = 1;
        dfs(1, -1, 0);

        for (int i = 0; i < m; ++i) {
            if (scanf("%d%d%d", &p[i].u, &p[i].v, &p[i].val) != 3)
                return 1;
            if (!valid_node(p[i].u) || !valid_node(p[i].v))
                return 1;
            p[i].lca = getlca(p[i].u, p[i].v);
            e[p[i].lca].push_back(i);
        }

        solve(1, -1);
        printf("%lld\n", d[1]);
    }
    return 0;
}

/*
Sample data left by the original author (line breaks reconstructed; it looks
like two test cases without the leading test-case count -- TODO confirm):

3 2
1 2
1 3
1 2 1
1 3 2
7 3
1 2
1 3
2 4
2 5
3 6
3 7
2 3 4
4 5 3
6 7 3
*/

0 0
原创粉丝点击