hdu2227 Find the nondecreasing subsequences(dp+线段树or树状数组优化)

来源:互联网 发布:自动驾驶算法工程师 编辑:程序博客网 时间:2024/05/12 08:12

Problem Description
How many nondecreasing subsequences can you find in the sequence S = {s1, s2, s3, …., sn} ? For example, we assume that S = {1, 2, 3}, and you can find seven nondecreasing subsequences, {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}.

Input
The input consists of multiple test cases. Each case begins with a line containing a positive integer n that is the length of the sequence S, the next line contains n integers {s1, s2, s3, …., sn}, 1 <= n <= 100000, 0 <= si <= 2^31.

Output
For each test case, output one line containing the number of nondecreasing subsequences you can find from the sequence S, the answer should % 1000000007.

Sample Input
3
1 2 3

Sample Output
7

大致题意:告诉你n个数,让你求出这个数列不递减的子序列的个数,对1e9+7取模

思路:很容易可以想到dp方程,假设dp[i]表示到第i个位置的所有不递减子序列个数.
状态转移方程即:dp[i]=sum(dp[j])+1(j < i &&num[j]<=num[i]),但这样子的话时间复杂度就达到了n^2.考虑到有个sum,所以我们可以先将这n个数离散化,然后再用线段树(或者树状数组)来优化,那么查询sum的时间就降到了logn。

代码如下

代码1.

/*线段树*/#include<cstring> #include<cstdio> #include<iostream>   #include <algorithm>  #define ll long long int  #define lson l,m,rt<<1  #define rson m+1,r,rt<<1|1  using namespace std;const int M=1e5+5; const int mod=1e9+7;int dp[M<<2];  struct node{    int value;    int ii;}num[M];int cmp(node a,node b){    return a.value<b.value;}inline void PushPlus(int rt)  {       dp[rt]= (dp[rt<<1]+dp[rt<<1|1])%mod; }  void Updata(int p,ll add, int l, int r, int rt)//单点更新,p位置上的数值增加add {      if( l == r )      {           dp[rt]=(dp[rt]+add)%mod;        return ;      }      int m = ( l + r ) >> 1;      if(p <= m)          Updata(p, add, lson);      else          Updata(p, add, rson);      PushPlus(rt);  }  ll Query(int L,int R,int l,int r,int rt)  {      if(L>R)    return 0;    if( L <= l && r <= R )      {          return dp[rt];      }      int m = ( l + r ) >> 1;      ll ans=0;      if(L<=m )          ans=(ans+Query(L,R,lson))%mod;      if(R>m)          ans=(ans+Query(L,R,rson))%mod;      return ans;  }  int main()  {        int mark[M];    int n,k;     while(scanf("%d",&n)!=EOF)    {        memset(dp,0,sizeof(dp));        for(int i=1;i<=n;i++)        {            scanf("%d",&num[i].value);            num[i].ii=i;        }        sort(num+1,num+1+n,cmp);//离散化         int k=-1,t=0;        for(int i=1;i<=n;i++)        {            if(num[i].value!=k)            {                ++t;                k=num[i].value;            }            mark[num[i].ii]=t;        }        for(int i=1;i<=n;i++)        {            int sum=(Query(1,mark[i],1,t,1)+1)%mod;//查询sum(dp[j]),j<=mark[i]             Updata(mark[i],sum,1,t,1);//修改mark[i]位置上的值         }        printf("%d\n",Query(1,t,1,t,1));    }    return 0;  }  

代码2.

#include<cstring> #include<cstdio> #include<iostream>   #include <algorithm>  #define ll long long int   using namespace std;const int N=1e5+5; const int mod=1e9+7;ll dp[N];int n;struct node{    int value;    int ii;}num[N];int cmp(node a,node b){    return a.value<b.value;}int lowbit(int x){    return x&-x;}int sum(int x){    int s=0;    while(x>0)    {        s=(s+dp[x])%mod;        x=x-lowbit(x);    }    return s;}void add(int x,int date){    while(x<=n)    {        dp[x]=(dp[x]+date)%mod;        x=x+lowbit(x);    }}int main()  {    int mark[N];    while(scanf("%d",&n)!=EOF)    {        memset(dp,0,sizeof(dp));        for(int i=1;i<=n;i++)        {            scanf("%d",&num[i].value);            num[i].ii=i;        }        sort(num+1,num+1+n,cmp);//离散化         int k=-1,t=0;        for(int i=1;i<=n;i++)        {            if(num[i].value!=k)            {                ++t;                k=num[i].value;            }            mark[num[i].ii]=t;        }        for(int i=1;i<=n;i++)        {            int ans=(sum(mark[i])+1)%mod;            add(mark[i],ans);        }        printf("%d\n",sum(t));    }    return 0;  }   
阅读全文
0 0