HDU-6078 Wavel Sequence(dp+树状数组)

来源:互联网 发布:单片机时序图 编辑:程序博客网 时间:2024/06/07 16:16

传送门:HDU-6078

题意:有2个序列A和B,要从A,B中选子序列出来组成“小-大-小”这样的序列,且A,B对应的位置要相等,问有多少种选取方法

题解:dp+树状数组

设f[x][y][k]为当前A数组枚举到第x个,B数组枚举到第y个,起伏状态为k(0/1)时的方案数,考虑到用普通的dp转移会达到O(n^4),可以用二维树状数组进行维护,由于第一维具有递增的特性,因此只要维护第二维的下标和值,设C[i][j]为下标小于等于i,值小于等于j的前缀和

f[x][y][0]=t[1].sum(y-1,b[y]-1)

f[x][y][1]=t[0].sum(y-1,mx)-t[0].sum(y-1,b[y])

#include<stdio.h>#include<iostream>#include<algorithm>#include<stdlib.h>#include<math.h>#include<string.h>#include<set>#include<vector>#define lson l,m,rt<<1#define rson m+1,r,rt<<1|1#define first x#define second y#define eps 1e-5using namespace std;typedef long long LL;typedef pair<int, int> PII;const int inf = 0x3f3f3f3f;const int MX = 2e3 + 5;const LL mod = 998244353;int a[MX], b[MX], vis[MX], n, m;struct BTree {    LL A[MX][MX];    int mx;    void init() {        mx = 0;        memset(A, 0, sizeof(A));    }    inline int lowbit(int x) {        return x & (-x);    }    void add(int x, int y, int d) {        for (int i = x; i <= mx; i += lowbit(i))            for (int j = y; j <= mx; j += lowbit(j))                A[i][j] += d;    }    LL sum(int x, int y) {        LL ret = 0;        for (int i = x; i; i -= lowbit(i))            for (int j = y; j; j -= lowbit(j))                ret = (ret + A[i][j]) % mod;        return ret;    }} t[2];void pre_solve() {    int sz = 0;    memset(vis, 0, sizeof(vis));    for (int i = 1; i <= n; i++) vis[a[i]] = 1;    for (int i = 1; i <= m; i++) if (vis[b[i]]) b[++sz] = b[i];    m = sz;    sz = 0;    memset(vis, 0, sizeof(vis));    for (int i = 1; i <= m; i++) vis[b[i]] = 1;    for (int i = 1; i <= n; i++) if (vis[a[i]]) a[++sz] = a[i];    n = sz;    t[0].init();    t[1].init();    t[0].mx = t[1].mx = m + 1;    for (int i = 1; i <= n; i++) t[0].mx = max(t[0].mx, a[i] + 1);    t[1].mx = t[0].mx;    for (int i = n; i > 0; i--) a[i + 1] = a[i];    for (int i = m; i > 0; i--) b[i + 1] = b[i];    n++; m++;}LL f[MX][MX][2];int main() {    //freopen("in.txt", "r", stdin);    int T;    scanf("%d", &T);    while (T--) {        scanf("%d%d", &n, &m);        for (int i = 1; i <= n; i++) scanf("%d", &a[i]);        for (int i = 1; i <= m; i++) scanf("%d", &b[i]);        pre_solve();        t[1].add(1, t[1].mx, 1);        LL ans = 0;        for (int i = 2; i <= n; i++) {            for (int j = 2; j <= m; j++) {                if (a[i] != b[j]) continue;                f[i][j][0] = (t[1].sum(j - 1, t[1].mx) - t[1].sum(j - 1, b[j]) + mod) % mod;                f[i][j][1] = t[0].sum(j - 1, b[j] - 1);                ans = (ans + f[i][j][0] + f[i][j][1]) % mod;                t[0].add(j, b[j], f[i][j][0]);                t[1].add(j, b[j], f[i][j][1]);            }        }        printf("%lld\n", ans);    }    return 0;}


原创粉丝点击