A simple Blog for wyx I've been down the bottle hoping.
Handle 【弱省胡策】Round #5 FFT 多项式求逆
发表于: | 分类: Oi | 评论:0 | 阅读:251

给定序列${B}$,然后求一个序列${A}$,满足$B_i = \sum_{j=i}^n C_j^i A_j$

说出来你们可能不信我上来就想直接上反演qwq,算了我们还是说点正常的做法吧qwq

把组合数展开然后得到

$B_i * i! = \sum_{j=i}^n \frac{A_j j!}{(j-i)!} \pmod {998244353}$

然后令 $C_j = A_{n-j}\times (n-j)!,\ D_j = B_{n-j}\times (n-j)!$

所以就有 $D_i = \sum_{j=0}^i \frac{1}{(i-j)!} C_j$

令$F(x) = \sum_{i=0}^n D_i x^i$

$G(x) = \sum_{i=0}^n \frac{x^i}{i!}$

$H(x) = \sum_{i=0}^n C_i x^i$

然后显然就有 $F(x) \equiv G(x) H(x) \pmod {x^{n+1}}$

然后移项得 $H(x) \equiv F(x)\,G^{-1}(x) \pmod {x^{n+1}}$,直接多项式求逆,复杂度 $O(n \log n)$。据说 $G(x)$ 其实就是 $e^x$ 的截断,不管了反正我求逆过了

#include <vector>
#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
using namespace std;
// Capacity of the global work arrays declared in this file (A, B, fac, ifac, temp1, temp2).
const int N = 1e6+5;
// Largest index for which main() precomputes factorials and inverse factorials.
const int Max = 1e6;
// 64-bit integer used for intermediate products before reduction by `mod`.
typedef long long LL;
// Modulus for all coefficient arithmetic.
const int mod = 998244353;
// Polynomial as a coefficient vector: element i is the coefficient of x^i.
typedef vector <int> pol;

/**
 * Reads the next decimal integer from stdin.
 *
 * Every non-digit character before the number is skipped; if any skipped
 * character is '-', the result is negated. Digits are then accumulated until
 * the first non-digit, which is consumed.
 *
 * @return the parsed value, or 0 if end of input is reached before any digit.
 */
inline int read() {
    int sign = 1;
    // getchar() returns int: keeping it as int lets EOF be told apart from data.
    int ch = getchar();
    // Stop on EOF as well as on a digit, otherwise exhausted input loops forever.
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    int value = 0;
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}

// Current transform length: written by init() and read by NTT() and operator*(pol, pol).
int n ;
// Scratch buffers that operator*(pol, pol) fills, transforms in place and multiplies pointwise.
int A[N], B[N];

inline int pow(int a,int b) {
    LL res = 1;
    for(;b;b>>=1,a=(LL)a*a%mod)
        if(b&1)
         res = res * a % mod;
    return res;
}

/**
 * In-place number-theoretic transform of a[0 .. n-1] modulo `mod`.
 *
 * The length is the global `n`; the doubling butterfly loop and the
 * bit-reversal step only make sense when it is a power of two (init() sets it
 * that way). Roots of unity are derived as pow(3, (mod-1)/len).
 *
 * @param a    buffer of at least n coefficients, overwritten with the result
 * @param type -1 selects the inverse transform (forward transform, then the
 *             index reversal a[i] <-> a[n-i] and a scaling by n^{-1});
 *             any other value performs the forward transform only
 */
inline void NTT(int *a,int type) {
    // Bit-reversal permutation: j tracks the bit-reversed counterpart of i.
    for (int i = 0, j = 0; i < n; ++i) {
        if (i > j) swap(a[i], a[j]);
        // Increment j as a reversed binary counter (carry moves from the top bit down).
        for (int k = n >> 1; (j ^= k) < k; k >>= 1) {}
    }
    // Butterfly passes over blocks of length 2, 4, ..., n.
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        const int wn = pow(3, (mod - 1) / len);
        for (int start = 0; start < n; start += len) {
            int w = 1;
            for (int k = start; k < start + half; ++k, w = (LL) w * wn % mod) {
                const int x = a[k];
                const int y = (LL) a[k + half] * w % mod;
                a[k] = (x + y) % mod;
                a[k + half] = (x - y + mod) % mod;
            }
        }
    }
    if (type == -1) {
        // Reversing indices 1..n-1 turns the forward result into the inverse one.
        for (int i = 1; i < (n >> 1); ++i) swap(a[i], a[n - i]);
        const int inv_n = pow(n, mod - 2);
        for (int i = 0; i < n; ++i) a[i] = (LL) a[i] * inv_n % mod;
    }
}

/**
 * Prepares the global NTT state for a product of `len` coefficients.
 *
 * Sets the global `n` to the smallest power of two that is not less than
 * 2 * len, then zeroes the first n entries of the scratch buffers A and B.
 *
 * @param len number of coefficients in the product about to be computed
 */
inline void init(int len) {
    n = 1;
    while (n < (len << 1)) {
        n <<= 1;
    }
    memset(A, 0, n * sizeof(A[0]));
    memset(B, 0, n * sizeof(B[0]));
}

/**
 * Strips trailing zero coefficients from z in place.
 * A polynomial whose coefficients are all zero ends up empty.
 */
inline void check(pol &z) {
    while (!z.empty() && z.back() == 0) {
        z.pop_back();
    }
}

// Scratch accumulator for the schoolbook branch of operator*(pol, pol).
LL _c[2000+5];

/**
 * Polynomial product modulo `mod`.
 *
 * Products of at most 600 coefficients are computed by direct convolution;
 * larger ones go through the global NTT buffers A and B (init() / NTT()).
 *
 * @param a left factor, coefficients in ascending degree
 * @param b right factor, coefficients in ascending degree
 * @return the product with a.size() + b.size() - 1 coefficients, or an empty
 *         polynomial when either factor is empty
 */
pol operator *(const pol &a, const pol &b) {
    // Guard: with two empty operands a.size() + b.size() - 1 would wrap around.
    if (a.empty() || b.empty()) return pol();

    const size_t len = a.size() + b.size() - 1;
    pol c(len);
    if (len <= 600) {
        for (size_t i = 0; i < len; ++i) _c[i] = 0;
        for (size_t i = 0; i < a.size(); ++i) {
            for (size_t j = 0; j < b.size(); ++j) {
                _c[i + j] = (_c[i + j] + (LL) a[i] * b[j] % mod) % mod;
            }
        }
        for (size_t i = 0; i < len; ++i) c[i] = _c[i];
    }
    else {
        // NOTE(review): nothing checks that the transform length chosen by init() fits in A/B (capacity N).
        init(static_cast<int>(len));
        for (size_t i = 0; i < a.size(); ++i) A[i] = a[i];
        for (size_t i = 0; i < b.size(); ++i) B[i] = b[i];
        NTT(A, 1);
        NTT(B, 1);
        for (int i = 0; i < n; ++i) A[i] = (LL) A[i] * B[i] % mod;
        NTT(A, -1);
        for (size_t i = 0; i < len; ++i) c[i] = A[i];
    }
    return c;
}

/**
 * Scales every coefficient of w by x, reducing each product modulo `mod`.
 *
 * @param w polynomial to scale
 * @param x scalar multiplier
 * @return a new polynomial with the same number of coefficients as w
 */
pol operator * (const pol &w, int x) {
    pol scaled;
    scaled.reserve(w.size());
    for (int coef : w) {
        scaled.push_back((LL) coef * x % mod);
    }
    return scaled;
}

/**
 * Coefficient-wise difference a - b modulo `mod`.
 *
 * The result has max(a.size(), b.size()) coefficients; the shorter operand is
 * treated as zero-padded.
 */
pol operator - (const pol &a,const pol &b) {
    pol diff(a);
    if (diff.size() < b.size()) {
        diff.resize(b.size());
    }
    for (size_t i = 0; i < b.size(); ++i) {
        // Adding mod before reducing keeps the intermediate value non-negative.
        diff[i] = (diff[i] - b[i] + mod) % mod;
    }
    return diff;
}

/**
 * Multiplicative inverse of a power series modulo x^{a.size()} and `mod`.
 *
 * Recursively inverts the lower half of the coefficients, then applies one
 * Newton step  c = 2*b - b*b*a  and truncates to a.size() coefficients.
 *
 * @param a series to invert, coefficients in ascending degree.
 *          NOTE(review): a[0] must be non-zero modulo `mod` for an inverse to
 *          exist; this is not checked here.
 * @return the inverse with a.size() coefficients; an empty input yields an
 *         empty result
 */
pol inv(pol a) {
    // Guard: (a.size() - 1) would wrap around for an empty vector below.
    if (a.empty()) return a;

    if (a.size() == 1) {
        // Constant term: inverse via Fermat's little theorem.
        a[0] = pow(a[0], mod-2);
        return a;
    }
    // Inverse of the first ceil(size / 2) coefficients.
    const size_t half = (a.size() - 1) / 2 + 1;
    pol b = inv(pol(a.data(), a.data() + half));
    pol c = b * b * a;
    c = b * 2 - c;
    c.resize(a.size());
    return c;
}

// fac[i] holds i! mod `mod` and ifac[i] its modular inverse; both are filled in main().
int fac[N], ifac[N];
// temp1 stores the values read from input, temp2 the values printed by main().
LL temp1[N], temp2[N];

/**
 * Reads deg followed by deg+1 input values, forms
 *   a[i] = input[deg-i] * (deg-i)!   and   b[i] = 1 / i!,
 * computes c = a * inv(b), and prints c[deg-i] / i! for i = 0..deg
 * (all arithmetic modulo `mod`), separated by spaces and ending with a newline.
 */
int main() {
    // Factorials and inverse factorials for every index up to Max.
    fac[0] = 1;
    for (int i = 1; i <= Max; ++i) fac[i] = (LL) fac[i-1] * i % mod;
    ifac[Max] = pow(fac[Max], mod-2);
    for (int i = Max - 1; i >= 0; --i) ifac[i] = (LL) ifac[i+1] * (i+1) % mod;

    // Named `deg` so it does not shadow the global NTT length `n`.
    const int deg = read();
    // Reject sizes that would index outside the factorial tables.
    // NOTE(review): the NTT scratch buffers impose a tighter limit that is not checked here.
    if (deg < 0 || deg > Max) return 1;

    pol a(deg + 1), b(deg + 1);
    for (int i = 0; i <= deg; ++i) temp1[i] = read();
    for (int i = 0; i <= deg; ++i) a[i] = (LL) temp1[deg-i] * fac[deg-i] % mod;
    for (int i = 0; i <= deg; ++i) b[i] = ifac[i];

    pol c = a * inv(b);
    // Make sure indices 0..deg are addressable whatever size the product has.
    if (c.size() < static_cast<size_t>(deg) + 6) c.resize(static_cast<size_t>(deg) + 6);

    for (int i = 0; i <= deg; ++i) temp2[i] = (LL) c[deg-i] * ifac[i] % mod;
    // temp2 is long long, so it must be printed with %lld rather than %d.
    for (int i = 0; i <= deg; ++i) printf("%lld ", temp2[i]);
    puts("");
    return 0;
}

Title - Artist
0:00

站点地图 网站地图
Copyright © 2015-2017 A simple Blog for wyx
Powered by Typecho自豪的采用Sgreen主题

TOP