【树分治】 ZOJ Travel

来源:互联网 发布:淘宝店卖家信誉等级表 编辑:程序博客网 时间:2024/06/06 01:04

将所有询问离线存下，然后做树分治（点分治）。

// ZOJ "Travel" -- centroid decomposition, all queries answered offline.
//
// What the program computes (inferred from the code): for each query "x d",
// the number of nodes y with dist(x, y) <= d such that the node values a[]
// along the path x -> y are monotone, i.e. entirely non-decreasing or
// entirely non-increasing. x itself is counted when d >= 0.
//
// At a centroid u every node reachable by a monotone path is classified by
// the path u -> y:
//   FLAT (0): every value equals a[u];
//   DOWN (1): non-increasing away from u, with at least one strict drop;
//   UP   (2): non-decreasing away from u, with at least one strict rise.
// A path x -> u -> y is monotone iff x and y are not both DOWN and not both
// UP (FLAT combines with anything). One Fenwick tree per class, indexed by
// depth, turns "compatible nodes within distance k of u" into prefix sums.

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>

constexpr int maxn = 100005;
constexpr int maxm = 200005;

using pii = std::pair<int, int>;

// Path classes relative to the current centroid (see header comment).
constexpr int NONE = -1;
constexpr int FLAT = 0;
constexpr int DOWN = 1;
constexpr int UP = 2;
constexpr int NUM_TYPES = 3;

// Adjacency lists stored in a fixed edge pool (2 directed edges per tree edge).
struct Edge {
    int v;
    Edge *next;
} *H[maxn], *edges, E[maxm];

std::vector<pii> q[maxn];            // q[x]: (d, query id) for every query at x
std::vector<pii> dis[NUM_TYPES];     // dis[t]: (depth, node) collected with class t
bool done[maxn];                     // node has already been used as a centroid
int sz[maxn];                        // subtree sizes from the latest getroot walk
int res[maxn];                       // res[id]: answer of query id
int mx[maxn];                        // largest piece left if the node were removed
int a[maxn];                         // node values
int fenwick[NUM_TYPES][maxn + 1];    // per-class Fenwick tree, depth d at index d + 1
int n, m, root, nsize;

void addedges(int u, int v) {
    edges->v = v;
    edges->next = H[u];
    H[u] = edges++;
}

// Resets per-test-case state. The Fenwick trees need no reset: solve()
// removes everything it inserts.
void init() {
    edges = E;
    std::fill(H, H + maxn, nullptr);
    std::memset(res, 0, sizeof res);
    std::memset(done, 0, sizeof done);
}

// Computes sz[]/mx[] over the live component containing u (of size nsize)
// and stores its centroid in `root`. The caller must set root = 0 and
// mx[0] = nsize first, so node 0 acts as the sentinel "worst" candidate.
void getroot(int u, int fa) {
    mx[u] = 0;
    sz[u] = 1;
    for (Edge *e = H[u]; e; e = e->next) {
        const int v = e->v;
        if (done[v] || v == fa) continue;
        getroot(v, u);
        sz[u] += sz[v];
        mx[u] = std::max(mx[u], sz[v]);
    }
    mx[u] = std::max(mx[u], nsize - sz[u]);
    if (mx[u] < mx[root]) root = u;
}

inline int lowbit(int x) { return x & -x; }

// Adds v to the count of depth x in Fenwick tree `bit` (indices 1..n+1).
void add(int x, int v, int bit[]) {
    for (int i = x + 1; i <= n + 1; i += lowbit(i)) bit[i] += v;
}

// Number of entries with depth <= x in Fenwick tree `bit`; 0 when x < 0.
int sum(int x, int bit[]) {
    if (x < 0) return 0;
    // Depths never reach n, so clamping loses nothing; it keeps the walk
    // inside the array and avoids overflowing x + 1 for huge query distances.
    if (x > n) x = n;
    int ans = 0;
    for (int i = x + 1; i > 0; i -= lowbit(i)) ans += bit[i];
    return ans;
}

// Class of the centroid path after stepping from -> to, given the class of
// the path that ends at `from`; NONE if the path stops being monotone.
int extend(int type, int from, int to) {
    const int x = a[from];
    const int y = a[to];
    if (type == FLAT) return x == y ? FLAT : (x > y ? DOWN : UP);
    if (type == DOWN) return x >= y ? DOWN : NONE;
    return x <= y ? UP : NONE;
}

// Appends (dep, u) to dis[type], then continues into every live neighbour
// whose extended path from the centroid is still monotone.
void dfs(int u, int fa, int dep, int type) {
    dis[type].push_back(pii(dep, u));
    for (Edge *e = H[u]; e; e = e->next) {
        const int v = e->v;
        if (v == fa || done[v]) continue;
        const int nt = extend(type, u, v);
        if (nt != NONE) dfs(v, u, dep + 1, nt);
    }
}

// Clears every class list and refills them with a walk from `start`.
void collect(int start, int fa, int dep, int type) {
    for (auto &bucket : dis) bucket.clear();
    dfs(start, fa, dep, type);
}

// Adds delta for every collected (depth, node) in its class's Fenwick tree.
void update_all(int delta) {
    for (int t = 0; t < NUM_TYPES; t++)
        for (const pii &p : dis[t]) add(p.first, delta, fenwick[t]);
}

// Nodes currently in the Fenwick trees at depth <= limit whose class can be
// joined with a node of class `type` into one monotone path through u.
int count_compatible(int type, int limit) {
    int total = 0;
    for (int p = 0; p < NUM_TYPES; p++)
        if (type == FLAT || p != type) total += sum(limit, fenwick[p]);
    return total;
}

// Counts every monotone path through centroid u for the queries of nodes in
// u's component, then recurses into the remaining pieces.
void solve(int u) {
    done[u] = true;
    const int total = nsize;  // component size; nsize is overwritten by recursion

    // Insert the whole component as seen from u and answer u's own queries.
    collect(u, u, 0, FLAT);
    update_all(+1);
    for (const pii &qu : q[u]) res[qu.second] += count_compatible(FLAT, qu.first);

    // For each branch: take it out so a path never stays inside one branch,
    // answer its nodes' queries against u plus the other branches, put it back.
    for (Edge *e = H[u]; e; e = e->next) {
        const int v = e->v;
        if (done[v]) continue;
        collect(v, u, 1, extend(FLAT, u, v));
        update_all(-1);
        for (int t = 0; t < NUM_TYPES; t++) {
            for (const pii &node : dis[t]) {
                const int dist = node.first;
                const int x = node.second;
                for (const pii &qu : q[x])
                    if (qu.first >= dist) res[qu.second] += count_compatible(t, qu.first - dist);
            }
        }
        update_all(+1);
    }

    // Return the Fenwick trees to all-zero before recursing.
    collect(u, u, 0, FLAT);
    update_all(-1);

    for (Edge *e = H[u]; e; e = e->next) {
        const int v = e->v;
        if (done[v]) continue;
        // sz[v] is v's branch size only if v lay below u in the getroot walk
        // that picked u; otherwise v's branch is everything outside u's subtree.
        mx[0] = nsize = (sz[v] < sz[u]) ? sz[v] : total - sz[u];
        root = 0;
        getroot(v, 0);
        solve(root);
    }
}

// Reads one test case, answers all of its queries and prints them in order.
void work() {
    std::scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) std::scanf("%d", &a[i]);
    for (int i = 1; i < n; i++) {
        int u, v;
        std::scanf("%d%d", &u, &v);
        addedges(u, v);
        addedges(v, u);
    }
    for (int i = 1; i <= n; i++) q[i].clear();
    for (int i = 1; i <= m; i++) {
        int x, d;
        std::scanf("%d%d", &x, &d);
        q[x].push_back(pii(d, i));
    }
    mx[0] = nsize = n;
    root = 0;
    getroot(1, 0);
    solve(root);
    for (int i = 1; i <= m; i++) std::printf("%d\n", res[i]);
}

int main() {
    int cases;
    if (std::scanf("%d", &cases) != 1) return 0;
    while (cases--) {
        init();
        work();
    }
    return 0;
}


0 0