codeforces 895D

来源:互联网 发布:windows日志服务器搭建 编辑:程序博客网 时间:2024/05/21 10:49

(组合数学)
题意:给定两个由全英文字母构成的字符串s1,s2,保证字典序s1<s2,求有多少个s3,使得s1<s3<s2并且s3s1的字母所组成的一个排列。

思路:
  首先统计s1串有哪些字母,每个字母出现了多少次,并用一个num数组保存下来。这样我们起码知道s3由哪些字母构成。然后想到可以先算字典序比s1小的字符串的数目,再算出字典序比s2小的字符串的数目,那么这两个结果一减,就是答案了。
  第一种数目比较好求,直接求出s1是它本身排列的第几位即可。而s3由于不是由s1转换而成的就有点不好办,于是想到可以将s3转换成s1字母构成的某种形式,再利用相同的方法求。
  在求本身排列是第几位的过程中,用i遍历0len,依次求出:从当前位比较结果来看,有多少串字典序小于s1。然后对于每个i,求出当前位的结果(这里用到了排列组合知识+一个小优化)。最后相加得到答案。

(公式懒得写了..这个编辑器有点不太会用。话说,还是要多敲题,这题思路虽然出来了,敲得贼慢…)

代码:

#include <cstdio>#include <cstring>#include <algorithm>#define LL long longusing namespace std;const int maxn = 1000010;const LL mod = 1e9 + 7;bool s2_smaller, s2_bigger;int num[30], sum[30];char s1[maxn], s2[maxn];LL fac[maxn], inv_fac[maxn];LL pow(LL a, LL b) {    LL ret = 1;    while(b) {        if(b&1)            ret = ret * a % mod;        a = a * a % mod;        b >>= 1;    }    return ret;}void pre_treat(int len) {    fac[0] = inv_fac[0] = 1;    for(int i=1; i<=len; i++) {        fac[i] = (LL)i * fac[i-1] % mod;        inv_fac[i] = pow((LL)fac[i], mod-2);    }}void init(int len) {    for(int i=0; i<len; i++)        num[s1[i]-'a'] ++;    sum[0] = num[0];    for(int i=1; i<26; i++)        sum[i] = sum[i-1] + num[i];}void transform_s2(int len) {    init(len);    for(int i=0; i<len; i++) {        int alpha = s2[i] - 'a';        if(!s2_smaller && !s2_bigger) {            if(num[alpha])                num[alpha] --;            else {                for(int j=alpha-1; j>=0; j--) if(num[j]) {                    s2_smaller = true; //puts("small");                    s2[i] = j + 'a'; num[j] --;                    break ;                }                if(s2_smaller) continue ;                for(int j=alpha+1; j<26; j++) if(num[j]) {                    s2_bigger = true; //puts("big");                    s2[i] = j + 'a'; num[j] --;                    break ;                }                if(s2_bigger) continue ;            }        }        else if(s2_smaller) {            for(int j=25; j>=0; j--) if(num[j]) {                s2[i] = j + 'a'; num[j] --;                break ;            }        }        else {//s2_bigger            for(int j=0; j<26; j++) if(num[j]) {                s2[i] = j + 'a'; num[j] --;                break ;            }        }    }}LL solve(char *s, int len) {    init(len);    LL ret = 0;    for(int i=0; i<len; i++) {        int alpha = s[i] - 'a';        if(alpha > 0 && sum[alpha-1] > 0) {            LL t1 = fac[len-i-1], t2 = 0;            for(int j=0; j<26; j++) if(num[j]) {                t1 = t1 * inv_fac[num[j]] % mod;                if(j < alpha)                    t2 = (t2 + num[j]) % mod;            }            ret = (ret + t1 * t2 % mod) % mod;        }        //printf("------i:%d ret:%I64d\n",i,ret);        num[alpha] --;        for(int j=alpha; j<26; j++)            sum[j] --;    }    return ret;}int main() {    //freopen("test.txt","r",stdin);    s2_smaller = s2_bigger = false;    scanf("%s%s",s1,s2);    int len = strlen(s1);    pre_treat(len);    transform_s2(len);    LL l = solve(s1, len), r = solve(s2, len);    //printf("l:%I64d r:%I64d\n",l,r);    LL ans = (r - l + mod) % mod;    if(!s2_smaller && !s2_bigger)        ans = (ans - 1 + mod) % mod;    else if(s2_smaller)        ans = ans;    else        ans = (ans - 1 + mod) % mod;    printf("%I64d\n",ans);    return 0;}
原创粉丝点击