HDU 6078 Wavel Sequence 计数dp(思维)

来源:互联网 发布:淘宝店铺手机怎么激活 编辑:程序博客网 时间:2024/06/05 20:49

传送门:HDU6078

题意:给出两个序列A和B,让你找出两组等长下标序列f和g,使得对于每个i,Afi == Bgi ,并且Afi序列为波浪序列。  问能找出多少种这样的下标序列。

波浪序列定义: a1<a2>a3<a4>a5<a6


 a1 < a2 > a3 < a4 > a5 < a6...

思路:先贴上官方题解:


f_{i,j,k}fi,j,k表示仅考虑a[1..i]a[1..i]b[1..j]b[1..j],选择的两个子序列结尾分别是a_iaib_jbj,且上升下降状态是kk 时的方案数,则f_{i,j,k}=\sum f_{x,y,1-k}fi,j,k=fx,y,1k,其中x<i,y<jx<i,y<j。暴力转移的时间复杂度为O(n^4)O(n4),不能接受。

考虑将枚举决策点x,yx,y的过程也DP掉。设g_{i,y,k}gi,y,k表示从某个f_{x,y,k}fx,y,k作为决策点出发,当前要更新的是ii的方案数,h_{i,j,k}hi,j,k表示从某个f_{x,y,k}fx,y,k作为决策点出发,已经经历了gg的枚举,当前要更新的是jj的方案数。转移则是要么开始更新,要么将ii或者jj继续枚举到i+1i+1以及j+1j+1。因为每次只有一个变量在动,因此另一个变量恰好可以表示上一个位置的值,可以很方便地判断是否满足上升和下降。

官方题解只看懂了暴力之前的部分。。于是百度了一发dalao们的做法,发现还是挺简洁易懂的:

dp[i][j][0]表示以a[i]和b[j]为公共序列结尾且为波谷的情况总和。 
dp[i][j][1]则表示波峰的情况总和。 
sum[i][j][0]表示∑(dp[k][j][0] | 1<=k<=j-1)。 //目前以b[j]为波谷结尾的‘总’匹配数
sum[i][j][1]则表示∑(dp[k][j][1] | 1<=k<=j-1)。 //目前以b[j]为波峰结尾的‘总’匹配数
那么对于每个a[i],只有存在j使得b[j]==a[i]时,

dp[i][j][0]等于∑(sum[i-1][k][1] | 1<=k<=j-1&&b[k]>a[i])+1,//代码中用cnt1动态的求这部分和

dp[i][j][1]等于∑(sum[i-1][k][0] | 1<=k<=j-1&&b[k]<=a[i]-1). //代码中用cnt0动态的求这部分和

以上转自:点击打开链接


总的来说就是用一个类似前缀和的sum数组优化掉了一个n的复杂度(内层枚举1...i的过程),又用两个变量动态求sum数组的和又优化掉了一个n的复杂度(内层枚举1...j的过程),让总体复杂度从n^4变成了n^2.


因为总的dp过程只与i和i-1相关,因此所有数组都可以优化掉一维。

代码:

#include<bits/stdc++.h>#define ll long long#define pb push_back#define fi first#define se second#define pi acos(-1)#define inf 0x3f3f3f3f#define lson l,mid,rt<<1#define rson mid+1,r,rt<<1|1#define rep(i,x,n) for(int i=x;i<n;i++)#define per(i,n,x) for(int i=n;i>=x;i--)using namespace std;typedef pair<int,int>P;const int MAXN = 2010;const int mod = 998244353;int gcd(int a,int b){return b?gcd(b,a%b):a;}int a[MAXN], b[MAXN];int dp[MAXN][2];int sum[MAXN][2];int main(){    int T;    cin >> T;    while(T--)    {        int n, m;        cin >> n >> m;        for(int i = 1; i <= n; i++)        scanf("%d", a + i);        for(int i = 1; i <= m; i++)        scanf("%d", b + i);        memset(dp, 0, sizeof(dp));        memset(sum, 0, sizeof(sum));        int ans = 0;        for(int i = 1; i <= n; i++)        {            int cnt0 = 0, cnt1 = 1;//第一个数字只能为波谷,说明第‘0’个数字为波峰             for(int j = 1; j <= m; j++)            {                dp[j][0] = dp[j][1] = 0;                if(a[i] == b[j])                {                    dp[j][0] = cnt1;                    dp[j][1] = cnt0;                }                else if(a[i] > b[j])                cnt0 = (cnt0 + sum[j][0]) % mod;                else                cnt1 = (cnt1 + sum[j][1]) % mod;                sum[j][0] = (sum[j][0] + dp[j][0]) % mod;                sum[j][1] = (sum[j][1] + dp[j][1]) % mod;            }        }        for(int j = 1; j <= m; j++)        {            ans = (ans + sum[j][0]) % mod;            ans = (ans + sum[j][1]) % mod;        }        cout << ans << endl;    }     return 0;}