HDU

来源:互联网 发布:手机投资黄金软件 编辑:程序博客网 时间:2024/06/05 08:07

题目链接


题意:

n个节点的树,节点的点权为ai,要求找出有多少个二元组(u,v)满足

1:u是v的祖先且u!=v

2:a[u]*a[v]<=K


思路:

   先把2转化一下:a[u] <=k/a[v] 因为都是整数所以不整除也没影响. 那么就是对每一个v找到他所有祖先里满足

上面那个a[u] <=k/a[v]不等式即可.

这个过程可以考虑为一个dfs过程,我们一边dfs一边查询即可。这里可以用一个树状数组,每遍历到一个就把他对应

树状数组的值+1,遍历v结点时,要查询他的祖先只需要查询有多少在k/a[v]

(其实就是快速查询当前有多少个点的值小于val的问题,BIT每个点的值对应BIT的下标)

但是这里A很大我们还是要离散化一下,有一个小问题就是如果仅仅把a[i]离散化了那么k/a[i]的值就会改变,所以我们这里考虑将k/a[i] 也一起加入进来离散化.(做除法记得判断a是否为0,为0的直接设为inf,不过数据好像没有?)

另外一个问题就是dfs过程中,每个点可能会受到其兄弟子树的影响,所以对于每个点为根的子树我们查询完后再给他清0,排除对其他子树的影响.


#pragma comment(linker, "/STACK:1024000000,1024000000")  #include<bits/stdc++.h>using namespace std;typedef long long ll;const int mod=1e9+7;const int maxn=2e5+10;const ll inf = 1e15;ll s[maxn],a[maxn],Hash[maxn],in[maxn];vector<vector<int> >vt(maxn); int n,m;ll k,cnt;int lowbit(int x){return x&-x;}void add(int x,int d){while(x < maxn){s[x] += d;x += lowbit(x);}return ;}int sum(int x){int res = 0;while(x > 0){res += s[x];x -= lowbit(x);}return res;}void dfs(int x){int l = lower_bound(Hash,Hash+m,a[x] ? k/a[x] : inf)-Hash+1;cnt += sum(l);int r = lower_bound(Hash,Hash+m,a[x])-Hash+1;add(r,1);int len = vt[x].size();for(int i = 0;i < len;++i){int v = vt[x][i];dfs(v);}add(r,-1);return ;}int main(){int _;cin>>_;while(_--){memset(in,0,sizeof in);memset(s,0,sizeof s);for(int i = 1;i <= n;i++)vt[i].clear();m = 0;scanf("%d %lld",&n,&k);for(int i = 1;i <= n;++i){scanf("%lld",&a[i]);Hash[m++] = a[i];if(a[i] == 0)Hash[m++] = inf;elseHash[m++] = k/a[i];}sort(Hash,Hash+m);m = unique(Hash,Hash+m) - Hash; for(int i = 1;i < n;++i){int u,v;scanf("%d %d",&u,&v);vt[u].push_back(v);in[v]++;}cnt = 0;for(int i = 1;i <= n;++i){if(!in[i]){dfs(i);}}printf("%lld\n",cnt);}return 0;}

PS:

此题当时本队做麻烦了,,,用了dfs序+主席树做的。。。

#include<bits/stdc++.h>using namespace std;typedef long long ll;const int maxn = 1e5+5;vector<int> g[maxn];int n, a[maxn], in[maxn*4], out[maxn*4], fa[maxn*4], tot1, tot2;int lson[maxn<<5], rson[maxn<<5], sum[maxn<<5];int T[maxn], Hash[maxn];int cb[maxn], cur, d;ll k;int build(int l, int r){    int rt = ++tot2;    sum[rt] = 0;    if(l < r)    {        int mid = (l+r)/2;        lson[rt] = build(l, mid);        rson[rt] = build(mid+1, r);    }    return rt;}int update(int pre, int l, int r, int x){    int rt = ++tot2;    lson[rt] = lson[pre], rson[rt] = rson[pre], sum[rt] = sum[pre]+1;    if(l < r)    {        int mid = (l+r)/2;        if(x <= mid)            lson[rt] = update(lson[pre], l, mid, x);        else            rson[rt] = update(rson[pre], mid+1, r, x);    }    return rt;}int query(int u, int v, int l, int r, int k){    if(l >= r) return sum[v]-sum[u];    int mid = (l+r)/2;    int ans = 0;    if(k <= mid)        ans += query(lson[u], lson[v], l, mid, k);    else    {        ans += sum[lson[v]]-sum[lson[u]];        ans += query(rson[u], rson[v], mid+1, r, k);    }    return ans;}void dfs(int u, int pre){    int x = lower_bound(Hash+1, Hash+1+d, a[u])-Hash;//    cout << cur << ' ' << u << ' ' << x << endl;    T[cur] = update(T[cur-1], 1, d, x);    cur++;    in[u] = ++tot1;    fa[u] = pre;    for(int i = 0; i < g[u].size(); i++)    {        int v = g[u][i];        if(v != pre)            dfs(v, u);    }    out[u] = tot1;}int main(void){    int _;    cin >> _;    while(_--)    {        tot1 = tot2 = 0;        memset(cb, 0, sizeof(cb));        for(int i = 0; i < maxn; i++)            g[i].clear();        scanf("%d%lld", &n, &k);        for(int i = 1; i <= n; i++)            scanf("%d", &a[i]), Hash[i] = a[i];        sort(Hash+1, Hash+1+n);        d = unique(Hash+1, Hash+1+n)-Hash-1;        T[0] = build(1, d);//        for(int i = 1; i <= n; i++)//        {//            int x = lower_bound(Hash+1, Hash+1+d, a[i])//        }        for(int i = 1; i < n; i++)        {            int u, v;            scanf("%d%d", &u, &v);            g[u].push_back(v);            cb[v]++;        }        int tRoot;        for(int i = 1; i <= n; i++)            if(!cb[i])            {                tRoot = i;                break;            }        cur = 1;        dfs(tRoot, 0);//        printf("%d\n", query(T[0], T[3], 1, d, 3));//        for(int i = 1; i <= n; i++)//            cout << i << ' ' << in[i] << ' ' << out[i] << endl;        ll ans = 0;        for(int i = 1; i <= n; i++)        {            ll des = k/a[i];            int x = upper_bound(Hash+1, Hash+1+d, des)-Hash-1;            ans += query(T[in[i]], T[out[i]], 1, d, x);        }        printf("%lld\n", ans);    }    return 0;}

原创粉丝点击