[树状数组+上升子序列] HDU 3030 Increasing Speed Limits

来源:互联网 发布:.space是什么域名 编辑:程序博客网 时间:2024/06/05 04:11

题意:求一组数有多少严格上升子序列。以1 2 1 2 3为例,他的严格上升子序列有1 2 1 2 3 12 12 13 123 123 23 12 13 123 23

话说这个输入真的好坑,看了半天才看懂,算出来的数还会爆int ...- -|||

思路

因为要求严格上升,所以同样大小的数只能出现一次。因此可以先对数组进行排序、去重操作。得到的新数组即为离散化的数组,每个数对应的下标即大小(第i大)。

树状数组中保存的是当前状态下,以第i大的数结尾的严格上升子序列的个数。

对于原数组的每一个数,找到他的下标后求出在他之前的上升序列个数并+1,更新树状数组。若s[i]=3,loc=4,那Dp[i]存储的是以当前的3结尾的上升序列的个数,Sum(loc)表示在0-i一共累计有Sum(loc)个上升序列,且最大值小于s[i]。

#include <cmath>#include <algorithm>#include <cstdio>#include <iostream>#include <cstring>#include <vector>#include <queue>#include <string>#include <map>#include <cstdlib>#include <cmath>using namespace std;typedef long long ll;typedef unsigned int ui;#define mp(a,b) make_pair(a,b)#define mem(a,b) memset(a,b,sizeof(a))#define debug(a) printf("Debug: %d\n",a)const int Mod = 1000000007;const int maxn =5e5+50;const int INF = 0x3f3f3f3f;int T,n,m;ll x,y,z;ll a[maxn];ll s[maxn];ll id[maxn];ll dp[maxn],sum[maxn];// 1 2 1 2 3// 1 2 1 2 3 12 12 13 123 123 23 12 13 123 23ll lowbit(ll x){return x&(-x);}void add(int i,ll x){for (;i<=n;i+=lowbit(i))sum[i]=(sum[i]+x)%Mod;}ll Sum(int i){ll res=0;for (;i>0;i-=lowbit(i))res=(res+sum[i])%Mod;return res;}ll solve(){mem(sum,0);ll ans=0;memcpy(id,s,sizeof(s));sort(id,id+n);int cnt=unique(id,id+n)-id;for (int i=0;i<n;i++){int loc=lower_bound(id,id+cnt,s[i])-id;dp[i]=Sum(loc)+1;/*若s[i]=3,loc=4,那Dp[i]存储的是以当前的3结尾的上升序列的个数,Sum(loc)表示在0-i一共累计有Sum(loc)个上升序列,且最大值小于s[i]。(加1是因为1个也算) *///printf("I: %d Loc: %d Dp[i]: %lld Sum(loc): %lld\n",i,loc,dp[i],Sum(loc));add(loc+1,dp[i]);}for (int i=0;i<n;i++){ans=(ans+dp[i])%Mod;}return ans;}int main(){    #ifndef ONLINE_JUDGE        freopen("in.txt","r",stdin);    #endif    scanf("%d",&T);    for (int cs=1;cs<=T;cs++)    {        int ans=0;        scanf("%d %d %lld %lld %lld",&n,&m,&x,&y,&z);        for (int i=0;i<m;i++)        {            scanf("%lld",a+i);                  }        for (int i=0;i<n;i++)        {        s[i]=a[i%m];        a[i%m]=(x*a[i%m]+y*(i+1))%z;          }      //  for (int i=1;i<=n;i++) printf("%d %lld %d\n",i,s[i],id[i]);      //  puts("");        printf("Case #%d: %lld\n",cs,solve());    }}



0 0