HDU 6061 RXD and functions(NTT+卷积)

来源:互联网 发布:非诚勿扰杨宇航淘宝店 编辑:程序博客网 时间:2024/05/18 00:46

传送门

RXD and functions

Time Limit: 4000/2000 MS (Java/Others)    Memory Limit: 524288/524288 K (Java/Others)
Total Submission(s): 532    Accepted Submission(s): 211


Problem Description
RXD has a polynomial function $f(x)$, $f(x)=\sum_{i=0}^{n} c_i x^i$.
RXD has a transformation of function $Tr(f,a)$, it returns another function $g$, which has a property that $g(x)=f(x-a)$.
Given $a_1,a_2,a_3,\ldots,a_m$, RXD generates a polynomial function sequence $g_i$, in which $g_0=f$ and $g_i=Tr(g_{i-1},a_i)$.
RXD wants you to find $g_m$, in the form of $\sum_{i=0}^{n} b_i x^i$.
You need to output $b_i$ modulo 998244353.
$n \le 10^5$
 

Input
There are several test cases, please keep reading until EOF.
For each test case, the first line consists of 1 integer $n$, which means $\deg f$.
The next line consists of $n+1$ integers $c_i$ ($0 \le c_i \le 998244353$), which are the coefficients of the polynomial.
The next line contains an integer $m$, which means the length of $a$.
The next line contains $m$ integers; the $i$-th integer is $a_i$.
There are 11 test cases.
$0 \le a_i \le 998244353$

$m \le 10^5$
 

Output
For each test case, output a polynomial of degree $n$ (its coefficients $b_0, b_1, \ldots, b_n$), which is the answer.
 

Sample Input
2
0 0 1
1
1
 

Sample Output
1 998244351 1

Hint

$(x-1)^2=x^2-2x+1$

题目大意:

已知多项式 $f(x)=\sum_{i=0}^{n}c_ix^i$,求 $f\left(x-\sum_{i=1}^{m}a_i\right)$ 的各项系数。

解题思路:

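设 $t=-\sum_{i=1}^{m}a_i$,则

$f(x)=\sum_{i=0}^{n}c_i(x+t)^i$

$(x+t)^i$ 用二项式定理展开后得到:

$$
\begin{aligned}
f(x) &= \sum_{i=0}^{n} c_i \sum_{j=0}^{i} x^j C_i^j t^{i-j} \\
&= \sum_{i=0}^{n} x^i \sum_{j=i}^{n} c_j C_j^i t^{j-i} \\
&= \sum_{i=0}^{n} x^i \sum_{j=i}^{n} c_j \frac{j!}{i!\,(j-i)!} t^{j-i} \\
&= \sum_{i=0}^{n} \frac{x^i}{i!} \sum_{j=i}^{n} c_j\, j!\, \frac{t^{j-i}}{(j-i)!}
\end{aligned}
$$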

我们将其系数提出来:
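$f'(i)=\sum_{j=i}^{n} c_j\, j!\,\frac{t^{j-i}}{(j-i)!}$(即 $f(x)$ 中 $x^i$ 的系数乘以 $i!$)

令 $b_i=c_{n-i}\,(n-i)!$,则有:
$f'(i)=\sum_{j=i}^{n} b_{n-j}\,\frac{t^{j-i}}{(j-i)!}$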

通过观察发现:$(n-j)+(j-i)=n-i$

现在设 $h=n-i$,并把 $j-i$ 重新记作 $j$,则有:

$f'(n-h)=\sum_{j=0}^{h} b_{h-j}\,\frac{t^{j}}{j!}$

这是序列 $b_j$ 与 $\frac{t^j}{j!}$ 的卷积形式。因为是在模 998244353 意义下计算的,我们可以用 NTT 求出卷积结果;但要注意在真正的 $f(x)$ 里,$x^i$ 的系数等于 $\frac{f'(i)}{i!}$,即卷积结果的第 $n-i$ 项再除以 $i!$。
我们需要预处理阶乘逆元,以及阶乘,然后通过 NTT 计算 f(x),最后计算 f(x) 系数就OK了。
代码:

// HDU 6061 "RXD and functions".
// Given f(x) = sum_{i=0}^{n} c_i x^i and a_1..a_m, print the coefficients of
// f(x - (a_1 + ... + a_m)) modulo 998244353.
//
// With t = -(a_1 + ... + a_m), the coefficient of x^i is
//   (1 / i!) * sum_{j=i}^{n} (c_j * j!) * t^(j-i) / (j-i)!,
// which is an ordinary convolution once the sequence c_j * j! is reversed.
// The convolution is evaluated with a number-theoretic transform (NTT).
#include <algorithm>
#include <cstdio>
#include <cstring>

using namespace std;

typedef long long LL;

const LL MOD = 998244353;   // 119 * 2^23 + 1
const int MAXC = 100005;    // coefficient table size; assumes deg f <= 1e5
const int MAXN = 3e5 + 5;   // >= NTT length used for MAXC coefficients (2^18)

LL c[MAXC], fac[MAXC], Inv[MAXC];

// Fills fac[i] = i! and Inv[i] = (i!)^{-1} (mod MOD) for 0 <= i < MAXC.
void Init() {
    fac[0] = Inv[0] = fac[1] = Inv[1] = 1;
    for (int i = 2; i < MAXC; i++) fac[i] = fac[i - 1] * i % MOD;
    // Linear-time inverses of integers: inv(i) = -(MOD / i) * inv(MOD % i).
    for (int i = 2; i < MAXC; i++) Inv[i] = (MOD - MOD / i) * Inv[MOD % i] % MOD;
    // Prefix products turn the inverse of i into the inverse of i!.
    for (int i = 2; i < MAXC; i++) Inv[i] = Inv[i] * Inv[i - 1] % MOD;
}

const LL P = MOD;    // NTT modulus
const LL G = 3;      // primitive root modulo P
const int NUM = 24;  // P - 1 = 119 * 2^23, so wn[0..23] are all well defined
LL wn[NUM];          // wn[i] = G^((P - 1) / 2^i)

// Returns a^b mod m by binary exponentiation.
LL Pow(LL a, LL b, LL m) {
    LL ans = 1;
    a %= m;
    while (b) {
        if (b & 1) ans = ans * a % m;
        b >>= 1;
        a = a * a % m;
    }
    return ans;
}

// Precomputes wn[i] = G^((P - 1) / 2^i) for 0 <= i < NUM.
void GetWn() {
    for (int i = 0; i < NUM; i++) {
        LL t = 1LL << i;
        wn[i] = Pow(G, (P - 1) / t, P);
    }
}

// Reorders y[0..len) into bit-reversed index order; len must be a power of two.
void change(LL *y, LL len) {
    for (LL i = 1, j = len / 2; i < len - 1; i++) {
        if (i < j) swap(y[i], y[j]);
        LL k = len / 2;
        while (j >= k) {
            j -= k;
            k /= 2;
        }
        if (j < k) j += k;
    }
}

// In-place iterative NTT of y[0..len) modulo P.
// on == 1: forward transform; on == -1: inverse transform including the 1/len
// scaling. len must be a power of two no larger than 2^(NUM - 1).
void ntt(LL *y, LL len, LL on) {
    change(y, len);
    int id = 0;
    for (LL h = 2; h <= len; h <<= 1) {
        id++;  // h == 2^id, so wn[id] is the twiddle step for this stage
        for (LL j = 0; j < len; j += h) {
            LL w = 1;
            for (LL k = j; k < j + h / 2; k++) {
                LL u = y[k] % P;
                LL t = w * y[k + h / 2] % P;
                y[k] = (u + t) % P;
                y[k + h / 2] = (u - t + P) % P;
                w = w * wn[id] % P;
            }
        }
    }
    if (on == -1) {
        // Inverse = forward transform with entries 1..len-1 reversed, divided by len.
        reverse(y + 1, y + len);
        LL inv = Pow(len, P - 2, P);
        for (LL i = 0; i < len; i++) y[i] = y[i] * inv % P;
    }
}

LL a[MAXN], b[MAXN];

int main() {
    Init();
    GetWn();
    int n;
    // Read test cases until EOF (or until the degree can no longer be parsed).
    while (scanf("%d", &n) == 1) {
        n++;  // from here on n is the number of coefficients, deg f + 1
        for (int i = 0; i < n; i++) {
            scanf("%lld", &c[i]);
            c[i] = (c[i] % MOD + MOD) % MOD;  // normalize so both output paths agree
        }
        int m;
        scanf("%d", &m);
        LL sum = 0;
        while (m--) {
            LL x;
            scanf("%lld", &x);
            x %= MOD;
            sum = ((sum + x) % MOD + MOD) % MOD;
        }
        LL t = (MOD - sum) % MOD;  // t = -(a_1 + ... + a_m) mod MOD
        if (t == 0) {
            // f(x + 0) == f(x): print the (reduced) input coefficients unchanged.
            for (int i = 0; i < n; i++) printf("%lld%c", c[i], i + 1 == n ? '\n' : ' ');
            continue;
        }
        LL len = 1;
        while (len < 2 * n) len <<= 1;  // large enough that the product does not wrap
        // Only the first len entries take part in the transforms.
        memset(a, 0, sizeof(LL) * len);
        memset(b, 0, sizeof(LL) * len);
        LL pw = 1;  // t^i
        for (int i = 0; i < n; i++) {
            a[i] = c[n - 1 - i] * fac[n - 1 - i] % MOD;  // c_j * j!, reversed
            b[i] = pw * Inv[i] % MOD;                  // t^i / i!
            pw = pw * t % MOD;
        }
        ntt(a, len, 1);
        ntt(b, len, 1);
        for (LL i = 0; i < len; i++) a[i] = a[i] * b[i] % MOD;
        ntt(a, len, -1);
        // Entry n-1-i of the product equals i! times the coefficient of x^i.
        for (int i = 0; i < n; i++)
            printf("%lld%c", a[n - 1 - i] * Inv[i] % MOD, i + 1 == n ? '\n' : ' ');
    }
    return 0;
}