题目链接

注意到两个排列本质不同,当且仅当它们的析合树不同。所以我们就是要对所有有 $n$ 个叶子的析合树计数。

考虑析合树的性质:首先析合树是区分儿子顺序的,所以我们会思考能否用生成函数去表示一棵树。设 $F(x)$ 表示析合树的生成函数,其中 $[x^m]F(x)$ 表示有 $m$ 个叶子的析合树的个数,我们先不考虑 $[x^0]F(x)$ 是啥,之后要用的时候再特殊定义。

那么下一步我们要考虑析合树有几种节点,发现只有三种:

  • 合点,要求至少有 $2$ 个子节点
  • 析点,要求至少有 $4$ 个子节点
  • 叶子,也就是要求下面不能接节点

那么可以得到 $F(x) = \frac{F(x)^2}{1-F(x)} + \frac{F(x)^4}{1-F(x)} + x$。接下来为了方便,我们省略 $(x)$ 不写。

由于要求所有位置的值,所以我们会尝试牛顿迭代解这个多项式。做一些移项可以得到:

$$ F^4 + 2F^2 - (x+1)F + x = 0 $$

(在这里你可以发现,带入显然有 $[x^0]F(x)=0$)

接下来尝试牛顿迭代。构造 $G(F) = F^4 + 2F^2 - (x+1)F + x$,显然 $G'(F) = 4F^3 + 4F - (x+1)$,所以可以得到:

$$ F_{n} = F_{n-1} - \frac{F^4+2F^2-(x+1)F+x}{4F^3 + 4F - (x+1)} $$

这个具体计算的时候,就只需要先求一遍 $F$ 的 DFT,然后就可以求出 $F^2,F^3,F^4$ 的点值,$x$ 的点值自然也是单位根,就可以求出分数上面和下面的多项式是啥了。NTT 转回去之后再做一下多项式求逆,就可以求出来了。复杂度 $O(n \log n)$。

#include <bits/stdc++.h>

#define fi first
#define se second
#define DB double
#define U unsigned
#define P std::pair
#define LL long long
#define LD long double
#define pb emplace_back
#define MP std::make_pair
#define SZ(x) ((int)x.size())
#define all(x) x.begin(),x.end()
#define CLR(i,a) memset(i,a,sizeof(i))
#define FOR(i,a,b) for(int i = a;i <= b;++i)
#define ROF(i,a,b) for(int i = a;i >= b;--i)
#define DEBUG(x) std::cerr << #x << '=' << x << std::endl

const int MAXN = 262144 + 5;

int r[MAXN],N,ha,g;
using Poly = std::vector<int>;

inline int qpow(int a,int n=ha-2){
    int res = 1;
    while(n){
        if(n & 1) res = 1ll*res*a%ha;
        a = 1ll*a*a%ha;
        n >>= 1;
    }
    return res;
}

std::vector<int> dec;

inline void fj(int n){
    int q = std::sqrt(1.0*n);
    FOR(i,2,q){
        if(!(n%i)){
            dec.pb(i);
            while(!(n%i)) n /= i;
        }
    }
}

inline void getg(){
    fj(ha-1);
    FOR(i,2,ha-1){
        bool flag = 1;
        for(auto p:dec){
            flag &= qpow(i,(ha-1)/p) != 1;
        }
        if(flag){g = i;break;}
    }
}

inline void add(int &x,int y){
    x += y-ha;x += x>>31&ha;
}

inline void init(int n){
    N = 1;int len = 0;while(N <= n) N <<= 1,++len;
    FOR(i,0,N-1) r[i] = (r[i>>1]>>1)|((i&1)<<(len-1));
}

inline void NTT(Poly &A,int opt){
    A.resize(N);
    FOR(i,0,N-1) if(i < r[i]) std::swap(A[i],A[r[i]]);
    for(int mid = 1;mid < N;mid <<= 1){
        int W = qpow(g,(ha-1)/(mid<<1));
        for(int i = 0;i < N;i += (mid<<1)){
            for(int j = 0,w = 1;j < mid;++j,w = 1ll*w*W%ha){
                int x = A[i+j],y = 1ll*w*A[i+mid+j]%ha;
                A[i+j] = (x+y)%ha;A[i+mid+j] = (x+ha-y)%ha;
            }
        }
    }
    if(opt == -1){
        int inv = qpow(N);
        std::reverse(A.begin()+1,A.end());
        FOR(i,0,N-1) A[i] = 1ll*A[i]*inv%ha;
    }
}

inline Poly operator * (Poly A,Poly B){
    int len = SZ(A)+SZ(B)-2;
    init(len);
    NTT(A,1);NTT(B,1);
    FOR(i,0,N-1) A[i] = 1ll*A[i]*B[i]%ha;
    NTT(A,-1);A.resize(len+1);
    return A;
}

inline Poly Inv(Poly A,int len){// length=n
    assert(A[0] != 0);
    Poly a,b,B;B.resize(2);B[0] = qpow(A[0]);
    for(int n = 2;n < len*2;n <<= 1){
        // 正在算 n/2 -> n
        a.resize(n);b.resize(n);
        FOR(i,0,n-1) a[i] = A[i],b[i] = B[i];
        init(n);NTT(a,1);NTT(b,1);
        B.resize(N);
        FOR(i,0,N-1) B[i] = 1ll*b[i]*(2+ha-1ll*a[i]*b[i]%ha)%ha;
        NTT(B,-1);FOR(i,n,N-1) B[i] = 0;
    }
    B.resize(len);return B;
}

inline Poly work(int len){
    Poly F,a,b;F.resize(2);F[0] = 0;
    for(int n = 2;n < len*2;n <<= 1){
        // 正在算 n/2 -> n
        init(n);Poly tmp = F;NTT(tmp,1);
        a.resize(N);b.resize(N);
        int W = qpow(g,(ha-1)/N),w = 1;
        FOR(i,0,N-1){
            int f1 = tmp[i],f2 = 1ll*f1*f1%ha,f3 = 1ll*f2*f1%ha,f4 = 1ll*f3*f1%ha;
            a[i] = (1ll*f4 + 2ll*f2 + ha-1ll*(w+1)*f1%ha + w)%ha;
            b[i] = (4ll*f3 + 4ll*f1 + ha-(w+1)%ha)%ha;
            w = 1ll*w*W%ha;
        }
        NTT(a,-1);NTT(b,-1);a.resize(n);b.resize(n);
        a = a*Inv(b,n);
        F.resize(N);
        FOR(i,0,n-1) F[i] = (F[i]+ha-a[i])%ha;
        FOR(i,n,N-1) F[i] = 0;
    }
    F.resize(len);
    return F;
}

int main(){
    int n;scanf("%d%d",&n,&ha);getg();
    Poly res = work(n+1);
    FOR(i,1,n) printf("%d\n",res[i]);
    return 0;
}
Last modification:June 19th, 2021 at 07:50 pm
如果觉得我的文章对你有用,请随意赞赏