[LOJ #150] 挑战多项式

LOJ #150

终于过了这个鬼题了QQQAQ

#include <iostream>
#include <cstdio>

template <typename T> void in(T &_){
    char _c=getchar(); int fl=1; _=0;
    while(!isdigit(_c)) fl=_c=='-'?-1:fl, _c=getchar();
    while(isdigit(_c)) _=_*10+_c-'0', _c=getchar(); _*=fl;
}

const int MOD=998244353, G=3, iG=332748118, inv2=499122177;

inline int ksm(int x, int y){
    int z=1;
    for(; y; y>>=1, x=1LL*x*x%MOD) if(y&1) z=1LL*z*x%MOD; return z;
}

namespace Quadratic{
    struct CipollaNumber{
        int A,B,om;
        CipollaNumber(){};
        CipollaNumber(int _A,int _B,int _O){A=_A,B=_B,om=_O;}
        friend CipollaNumber operator * (const CipollaNumber &A,const CipollaNumber &B){return CipollaNumber((A.A*static_cast<long long>(B.A)+static_cast<long long>(A.om)*A.B%MOD*B.B)%MOD,(static_cast<long long>(A.A)*B.B+static_cast<long long>(A.B)*B.A)%MOD,A.om);}
    };
    int Pow(CipollaNumber A,int B){
        CipollaNumber ret(1,0,A.om);
        for (B<<=1;B>>=1;A=A*A) if (B&1) ret=ret*A;
        return ret.A;
    }
    int Sqrt(int a){
        if (a==0) return 0;
        int d=1;
        for (;ksm(d*d-a+MOD,(MOD-1)>>1)<=1;d++);
        int ans=Pow(CipollaNumber(d,1,(d*d-a+MOD)%MOD),(MOD+1)>>1);
        return std::min(ans,MOD-ans);
    }
}

namespace Poly{
    int rev[400010], C[400010], D[400010], E[400010];

    void get_rev(int l){
        for(int i=0; i<l; i++) 
            rev[i]=(rev[i>>1]>>1)|((i&1) ? l>>1 : 0);
    }

    void ntt(int *P, int l, int opt){
        for(int i=0; i<l; i++) if(i>rev[i]) 
            std::swap(P[i], P[rev[i]]);
        for(int i=1; i<l; i<<=1){
            int W=opt>0 ? G : iG;
            W=ksm(W, (MOD-1)/(2*i));
            for(int j=0, p=2*i; j<l; j+=p)
                for(int k=0, w=1; k<i; k++, w=1LL*w*W%MOD){
                    int X=P[j+k], Y=1LL*P[i+j+k]*w%MOD;
                    P[j+k]=(X+Y)%MOD, P[i+j+k]=(X-Y)%MOD;
                }
        }
        if(opt>0) return; int inv=ksm(l, MOD-2);
        for(int i=0; i<l; i++) P[i]=1LL*P[i]*inv%MOD;
    }

    void Inv(int *A, int *B, int n){
        if(n==1){B[0]=ksm(A[0], MOD-2); return;}
        Inv(A, B, (n+1)/2);
        int l=1; while(l<(2*n)) l<<=1;
        for(int i=0; i<n; i++) C[i]=A[i];
        for(int i=n; i<l; i++) C[i]=0;
        get_rev(l);
        ntt(B, l, 1); ntt(C, l, 1);
        for(int i=0; i<l; i++) B[i]=1LL*B[i]*(2LL-1LL*C[i]*B[i]%MOD)%MOD;
        ntt(B, l, -1);
        for(int i=n; i<l; i++) B[i]=0;
    }

    void Dev(int *P, int l){
        for(int i=1; i<l; i++) P[i-1]=1LL*i*P[i]%MOD; P[l]=0;
    }

    void Int(int *P, int l){
        for(int i=l-1; i; i--) P[i]=1LL*ksm(i, MOD-2)*P[i-1]%MOD; P[0]=0;
    }

    void Ln(int *A, int *B, int n){
        for(int i=0; i<n; i++) D[i]=A[i];
        int l=1; while(l<2*n) l<<=1;
        for(int i=0; i<l; i++) B[i]=0;
        Inv(A, B, n); Dev(D, n);
        for(int i=n; i<l; i++) D[i]=0, B[i]=0;
        get_rev(l); ntt(D, l, 1); ntt(B, l, 1);
        for(int i=0; i<l; i++) B[i]=1LL*B[i]*D[i]%MOD;
        ntt(B, l, -1); for(int i=n; i<l; i++) B[i]=0;
        Int(B, n);
    }

    void Exp(int *A, int *B, int n){
        if(n==1){B[0]=1; return;}
        Exp(A, B, (n+1)/2);
        Ln(B, E, n);
        int l=1; while(l<2*n) l<<=1;
        get_rev(l);
        E[0]=(1-E[0]+A[0])%MOD;
        for(int i=1; i<n; i++) E[i]=(A[i]-E[i])%MOD;
        for(int i=n; i<l; i++) E[i]=0;
        ntt(B, l, 1); ntt(E, l, 1);
        for(int i=0; i<l; i++) B[i]=1LL*B[i]*E[i]%MOD;
        ntt(B, l, -1);
        for(int i=n; i<l; i++) B[i]=0;
    }

    void ksm(int *A, int *B, int n, int k){
        int l=1; while(l<2*n) l<<=1; get_rev(l);
        for(int i=0; i<l; i++) B[i]=0; B[0]=1;
        while(k){
            ntt(A, l, 1);
            if(k&1){
                ntt(B, l, 1);
                for(int i=0; i<l; i++) B[i]=1LL*B[i]*A[i]%MOD;
                ntt(B, l, -1);
                for(int i=n; i<l; i++) B[i]=0;
            }
            for(int i=0; i<l; i++) A[i]=1LL*A[i]*A[i]%MOD;
            ntt(A, l, -1); k>>=1;
            for(int i=n; i<l; i++) A[i]=0;
        }
    }

    void Sqrt(int *A, int *B, int n){        
        if(n==1){B[0]=Quadratic::Sqrt(A[0]); return;}
        Sqrt(A, B, (n+1)/2);
        int l=1; while(l<2*n) l<<=1; get_rev(l);
        for(int i=0; i<l; i++) D[i]=0;
        Inv(B, D, n); 
        for(int i=0; i<n; i++) D[i]=1LL*D[i]*inv2%MOD;
        ntt(B, l, 1);
        for(int i=0; i<l; i++) B[i]=1LL*B[i]*B[i]%MOD;
        ntt(B, l, -1);
        for(int i=0; i<n; i++) B[i]=(B[i]+A[i])%MOD;
        ntt(B, l, 1); ntt(D, l, 1);
        for(int i=0; i<l; i++) B[i]=1LL*B[i]*D[i]%MOD;
        ntt(B, l, -1);
        for(int i=n; i<l; i++) B[i]=0;
    }
}

int n, k, F[400010], A[400010], B[400010];

int main(){
    in(n); n++; in(k);
    for(int i=0; i<n; i++) in(F[i]), A[i]=F[i];
    Poly::Sqrt(A, B, n);
    for(int i=0; i<n; i++) A[i]=B[i], B[i]=0;
    Poly::Inv(A, B, n);
    for(int i=0; i<n; i++) A[i]=B[i], B[i]=0;
    Poly::Int(A, n);
    Poly::Exp(A, B, n);
    for(int i=0; i<n; i++) A[i]=B[i], B[i]=0;
    for(int i=0; i<n; i++) A[i]=(F[i]-A[i])%MOD;
    A[0]=(A[0]+2-F[0])%MOD;
    Poly::Ln(A, B, n);
    for(int i=0; i<n; i++) A[i]=B[i], B[i]=0;
    A[0]=(A[0]+1)%MOD;
    Poly::ksm(A, B, n, k);
    Poly::Dev(B, n);
    for(int i=0; i<n-1; i++) printf("%d ", (B[i]+MOD)%MOD);
    return 0;
}

发表评论

电子邮件地址不会被公开。 必填项已用*标注