hdu5632 Rikka with Array 数位dp

Rikka with Array

Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 65536/65536 K (Java/Others)
Problem Description
As we know, Rikka is poor at math. Yuta is worrying about this situation, so he gives Rikka some math tasks to practice. There is one of them:

Yuta has an array $A$ of length $n$, and the $i$-th element of $A$ is equal to the sum of all digits of $i$ in binary representation. For example, $A[1]=1$, $A[3]=2$, $A[10]=2$.

Now, Yuta wants to know the number of pairs $(i,j)\ (1 \le i < j \le n)$ which satisfy $A[i] > A[j]$.

It is too difficult for Rikka. Can you help her?

The first line contains a number $T\ (T \le 10)$ — the number of test cases.

For each test case, the first line contains a number $n\ (n \le 10^{300})$.

For each testcase, print a single number. The answer may be very large, so you only need to print the answer modulo 998244353.

Sample Input

1
10

Sample Output

7

Hint
When $n=10$, $A$ is equal to $1,1,2,1,2,2,3,1,2,2$. So the answer is $7$.

BestCoder Round #73 (div.2)


题目描述:给定一个数n(1 <= n <= 10^300),问有多少个数对(i , j),满足1 <= i < j <= n且A[i] > A[j],其中A[x]是x的二进制表示中1的个数

思路:定义dp[limit][len][sum]表示当前枚举到第len位,已经枚举出的两个数前缀中1的个数之差(较大数j中1的个数减去较小数i中1的个数)加上1000后为sum(加1000是因为差可能为负),且i和j的状态为limit时的合法数对个数(合法指最终i中1的个数多于j),其中
               limit == 0      表示i<j < n
               limit == 1      表示i<j = n
               limit == 2      表示i= j < n
               limit == 3      表示i =j = n

#pragma warning(disable:4786)
#pragma comment(linker, "/STACK:102400000,102400000")
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<stack>
#include<queue>
#include<map>
#include<set>
#include<vector>
#include<cmath>
#include<string>
#include<sstream>
#include<bitset>
using namespace std;

typedef long long LL;

const int mod = 998244353;

// Input bound: n <= 10^300, i.e. at most 301 decimal digits. Any 301-digit
// number is below 10^301 < 2^1000, so it has at most 1000 binary digits.
const int maxDigits = 301;
const int maxBits = 1000;
// Offset added to the popcount difference so it can be used as an index.
// Must be >= maxBits: the difference never leaves [-maxBits, maxBits].
const int base = 1000;

// Flags combined into the `limit` state of dfs().
const int TIGHT = 1;   // the larger number's prefix equals n's prefix
const int EQUAL = 2;   // the smaller number's prefix equals the larger one's

// dp[limit][len][sum]: memoized dfs() results, -1 = not computed yet.
// Stored as int (every stored value is already reduced modulo `mod`), which
// keeps the table at ~32 MiB instead of the ~64 MiB a long long table needs
// against a 64 MiB memory limit.
int dp[4][maxBits + 2][2 * base + 1];
char s[maxDigits + 2];
int num[maxDigits + 1];
int bits[maxBits + 2];

// Digit DP over the binary digits of n, from the most significant one down.
// Two numbers are built in parallel: the smaller one (the pair's index i) and
// the larger one (index j, with i < j <= n).
//   len   - bits[len-1], ..., bits[1] are still to be chosen.
//   sum   - base + popcount(larger prefix) - popcount(smaller prefix).
//   limit - combination of TIGHT and EQUAL describing the prefixes.
// Returns, modulo `mod`, how many completions end with
// popcount(smaller) > popcount(larger), i.e. A[i] > A[j].
// A smaller number of 0 never needs to be excluded explicitly: it would give
// popcount(smaller) = 0 <= popcount(larger), so sum < base cannot hold. The
// same argument covers smaller == larger (sum == base), so the memo key does
// not need any extra flag.
LL dfs(int len, int sum, int limit)
{
    // Each remaining position lowers the difference by at most one, so a
    // difference above len - 1 can never become negative.
    if (sum > base && sum - base > len - 1)
        return 0;
    if (len == 1)
        return sum < base ? 1 : 0;
    int &memo = dp[limit][len][sum];
    if (memo != -1)
        return memo;

    const bool isTight = (limit & TIGHT) != 0;
    const bool isEqual = (limit & EQUAL) != 0;
    const int nBit = bits[len - 1];
    LL res = 0;
    for (int h = 0; h < 2; ++h) {           // next bit of the larger number
        if (isTight && h > nBit)
            continue;                        // would make it exceed n
        for (int l = 0; l < 2; ++l) {       // next bit of the smaller number
            if (isEqual && l > h)
                continue;                    // would make it exceed the larger one
            int nextLimit = 0;
            if (isTight && h == nBit)
                nextLimit |= TIGHT;
            if (isEqual && l == h)
                nextLimit |= EQUAL;
            res += dfs(len - 1, sum + h - l, nextLimit);
        }
    }
    // At most four terms, each below mod, so res cannot overflow.
    memo = (int)(res % mod);
    return memo;
}

// Converts the decimal number stored little-endian in num[1..len] (num[1] is
// the units digit, num[0] must be 0) to binary by repeated halving.
// Writes the bits least-significant first into bits[1..m-1], returns m, and
// leaves num[] all zero. Expects no leading zeros (main() strips them);
// otherwise extra high zero bits are produced. Runs in O(log n * len).
int convert(int len)
{
    int m = 1;
    while (len) {
        // Halve from the most significant digit down; an odd digit passes
        // 10 to the digit below it, so num[0] ends up 10 iff the number was odd.
        for (int i = len; i; i--) {
            num[i - 1] += (num[i] & 1) * 10;
            num[i] >>= 1;
        }
        bits[m++] = (num[0] != 0);
        num[0] = 0;
        // Halving shortens the number by at most one digit.
        if (!num[len])
            --len;
    }
    return m;
}

// Number of pairs (i, j) with 1 <= i < j <= n and A[i] > A[j], modulo `mod`,
// where n is the decimal number held in num[1..len].
LL solve(int len)
{
    int cnt = convert(len);
    // Both prefixes start empty: equal to each other and to n's prefix.
    return dfs(cnt, base, TIGHT | EQUAL);
}

int main()
{
    int T;
    if (scanf("%d", &T) != 1)
        return 0;
    memset(dp, -1, sizeof(dp));
    while (T--) {
        // States without TIGHT never read bits[], so their memoized values
        // are independent of n and are kept across test cases.
        memset(dp[TIGHT], -1, sizeof(dp[TIGHT]));
        memset(dp[TIGHT | EQUAL], -1, sizeof(dp[TIGHT | EQUAL]));
        // Width limit (== maxDigits) keeps the read inside s[].
        if (scanf("%301s", s + 1) != 1)
            break;
        int len = (int)strlen(s + 1);
        // Strip leading zeros so convert() gets a normalized number.
        int first = 1;
        while (first <= len && s[first] == '0')
            ++first;
        int digits = len - first + 1;
        for (int i = 1; i <= digits; i++)
            num[i] = s[len - i + 1] - '0';
        LL ans = solve(digits);
        printf("%lld\n", ans);
    }
    return 0;
}

