2017 ACM-ICPC 亚洲区(西安赛区)网络赛 A Tree 树分治 矩阵 没有逆元和交换律的树链统计

来源:互联网 发布:软件测试自学网站 编辑:程序博客网 时间:2024/05/21 09:30

链接

https://nanti.jisuanke.com/t/17114

题意

给一棵大小为n(n<=3000)的树,树上每个节点有一个64 * 64的01矩阵,q(q=30000)次询问,每次询问u到v路径上的矩阵之积。

思路

比赛的时候自己想用类似高斯消元的方法求逆矩阵,然后答案就是u到根节点的乘积乘以lca到根节点的逆元再乘上根节点到v的乘积乘上根节点到lca的逆元。可惜很多矩阵是没有逆元的 : (

后来cls说他是用树分治过的这题,因为一条链在树分治中只会经过一个根节点,自己想想的确是这样,然后码了下码出来了 QAQ

具体做法是对一个询问(u, v),在节点u处存一个v的询问,v处存一个u的询问,给树分治中的每块子树一个编号,当一个点u发现其对应的某个询问v的编号和自己的编号一样时,说明两者在一个子树中,且此时两点到根节点的矩阵之积都已经算出来了,此时算出答案,存下来就好。

树分治的复杂度是O(6464nlogn),每个询问会判断logn次,其中一次会记录答案,则总复杂度是O(6464nlogn+qlogn+6464q)

ps:因为是01矩阵,所以矩阵乘法可以用bitset加速。

代码

#include <cstdio>#include <cstring>#include <iostream>#include <algorithm>#include <bitset>#include <vector>using namespace std;#define PB push_backtypedef long long LL;typedef unsigned long long ULL;typedef pair<int, int> P;const int N = 3e3 + 5;const int LIM = 64;const int MOD = 19260817;LL p19[70], p26[70];struct Mat {  bitset<LIM> h[LIM], v[LIM];  void init() {    for (int i = 0; i < LIM; ++i) {      h[i].reset(); v[i].reset();    }    for (int i = 0; i < LIM; ++i) {      h[i].set(i); v[i].set(i);    }  }  void init(int m[LIM + 5][LIM + 5]) {    for (int i = 0; i < LIM; ++i) for (int j = 0; j < LIM; ++j) {      if (m[i + 1][j + 1]) {        h[i].set(j); v[j].set(i);      }      else {        h[i].reset(j); v[j].reset(i);      }    }  }  Mat operator*(const Mat& r) const {    Mat ret;    int t;    for (int i = 0; i < LIM; ++i) for (int j = 0; j < LIM; ++j) {      t = ((h[i] & r.v[j]).count()) & 1;      if (t) {        ret.h[i].set(j); ret.v[j].set(i);      }      else {        ret.h[i].reset(j); ret.v[j].reset(i);      }    }    return ret;  }  LL val() const  {    LL ret = 0;    for (int i = 0; i < LIM; ++i) for (int j = 0; j < LIM; ++j) {      if (h[i].test(j)) {        ret = (ret + p19[i + 1] * p26[j + 1] % MOD) % MOD;      }    }    return ret;  }};ULL seed;int n, q;int m[N][LIM + 5][LIM + 5];int vis[N], tim, siz[N], root, max_sub;int fst[N * 10];bool used[N];LL ans[N * 10];Mat mat[2][N], val[N];vector<int> edges[N];vector<P> querys[N];inline void init() {  for (int i = 1; i <= n; ++i) {    for (int p = 1; p <= 64; ++p) {      seed ^= seed * seed + 15;      for (int q = 1; q <= 64; ++q) {        m[i][p][q] = (seed >> (q - 1)) & 1;      }    }  }}void findroot(int u, int fa, int num) {  siz[u] = 1;  int K = 0;  for (int v: edges[u]) {    if (used[v] || v == fa) continue;    findroot(v, u, num);    K = max(K, siz[v]);    siz[u] += siz[v];  }  K = max(K, num - siz[u]);  if (K < max_sub) max_sub = K, root = u;}void dfs(int u, int fa) {  vis[u] = tim;  mat[0][u] = mat[0][fa] * val[u];  mat[1][u] = val[u] * mat[1][fa];  for (int v: edges[u]) {    if (used[v] || v == fa) continue;    dfs(v, u);  }  int id;  for (P q: querys[u]) if (vis[q.first] == tim) {    id = q.second;    if (fst[id] == u) {      ans[id] = (mat[1][u] * val[root] * mat[0][q.first]).val();    }    else {      ans[id] = (mat[1][q.first] * val[root] * mat[0][u]).val();    }  }}inline void work2(int u) {  vis[u] = ++tim;  mat[0][u].init(); mat[1][u].init();  for (int v: edges[u]) {    if (used[v]) continue;    dfs(v, u);  }  int id;  for (P q: querys[u]) if (vis[q.first] == tim) {    id = q.second;    if (fst[id] == u) {      ans[id] = (val[u] * mat[0][q.first]).val();    }    else {      ans[id] = (mat[1][q.first] * val[u]).val();    }  }}void work1(int u, int num) {  max_sub = num;  root = u;  findroot(u, 0, num);  u = root;  used[u] = true;  work2(u);  for (int v: edges[u]) {    if (used[v]) continue;    work1(v, siz[v]);  }}int main() {  for (int i = p19[0] = 1; i < 70; ++i) p19[i] = p19[i - 1] * 19 % MOD;  for (int i = p26[0] = 1; i < 70; ++i) p26[i] = p26[i - 1] * 26 % MOD;//  while (!IOerror) {  while (~scanf("%d%d", &n, &q)) {//    read(n); read(q);    for (int i = 1; i <= n; ++i) {      edges[i].clear();      querys[i].clear();      used[i] = false;      vis[i] = 0;    }    int u, v;    for (int i = 1; i < n; ++i) {      scanf("%d%d", &u, &v);//      read(u); read(v);      edges[u].PB(v);      edges[v].PB(u);    }    cin >> seed;//    read(seed);    init();    for (int i = 1; i <= n; ++i) {      val[i].init(m[i]);    }    for (int i = 1; i <= q; ++i) {      scanf("%d%d", &u, &v);//      read(u); read(v);      fst[i] = u;      querys[u].PB(P(v, i));      querys[v].PB(P(u, i));    }    tim = 0;    work1(1, n);    for (int i = 1; i <= q; ++i) printf("%lld\n", ans[i]);  }}
阅读全文
0 0
原创粉丝点击