题意
给一个正整数 $M$ 和长度为 $N$ 的序列 $A=(A_1,A_2,\ldots,A_n)$。找到以下这个序列 $X=(X_1,X_2,\ldots,X_{N+1})$ 的方案数,满足:
- $1 \leq X_i \leq M(1 \leq i \leq N+1)$
- $A_iX_i \leq X_{i+1}(1 \leq i \leq N)$
答案对 $998244353$ 取模。$1 \leq N \leq 1000,1 \leq M \leq 10^{18},1 \leq A_i \leq 10^9$。
题解
这个题一开始想的时候没想对地方,一直在想这其实是个求多项式逆在某一位置的值,但是应该迅速意识到多项式指数特别大的情况一般不太可做,应该思考能否把 $M$ 这种东西扔到定义域里。
写一下暴力 dp:设 $f_{i,j}$ 表示填完 $[i,N+1]$ 的 $X$ 了,并且 $A_i = j$ 的方案数,转移枚举上一个数字是什么转移,有:
$$ f_{i,j} = \sum_{k \geq j \times A_i} f_{i+1,k}f_{i,j} = \sum_{k \geq j \times A_i} f_{i+1,k} $$
但是第二维太大了,不会算。对于某一维度太大的 dp,我们考虑是否在这一维上是一个多项式。经过打表发现它确实是的。
证明一下,设 $s_i = \sum_j f_{i,j}$,可以把 dp 式子改写为:
$$ f_{i,j} = s_{i+1}-\sum_{k < j \times A_i} f_{i+1,k} $$
第一项是一个定值,第二项是第 $jA_i$ 项前缀和,由于这里的 $j$ 是线性的,所以这个多项式应该就是 $n-i+1$ 次多项式。
考虑如何转移:转移我们需要知道 $s_{i+1}$,还需要多次询问 $f_{i+1}$ 的前缀和。多项式求前缀和让我们联想到拉格朗日插值。我们维护出 $f_i$ 这个多项式的点值,然后每次做一个前缀和,当作 $n-i+2$ 次多项式插值就可以求出前缀和了。
这里有一点小细节:我们对于 $i$ 需要计算出 $lim_i$ 表示 $f_{i}$ 最大理论可能有值的位置,也就是 $lim_i = \lfloor \frac{lim_{i+1}}{A_i} \rfloor$。那么我们每次转移的时候求 $s_{i+1}$ 要代入 $lim_{i+1}$ 去插值,后面转移求 $f_i$ 的时候不需要取担心这个边界细节,因为超过边界的值下次转移不会用到,不超过边界的值转移一定不会跨过 $lim_{i+1}$。
但是这样每次转移需要插值 $n$ 次,总复杂度是 $O(n^3)$ 的,无法接受。
目前我知道的做法有两种优化方法,这两种方法都要注意到,$A_i > 1$ 的部分只有 $O(\log M)$ 个。
优化 1
我们考虑优化掉 $A_i = 1$ 的转移,当 $A_i=1$ 的时候,转移形如:
$$ f_{i,j} = s_{i+1}-\sum_{k<j} f_{i+1,k} $$
发现我们如果维护点值的话后面那个前缀和是可以直接预处理的,这一部分单次转移复杂度就可以变成 $O(n)$ 了。总复杂度 $O(n^2 + n^2\log M)$。
优化 2
考虑我们把所有 $A_i=1$ 的放到一起转移!由于点值可以任意取 $lim_i$ 内的值,我们考虑取 $lim_i,lim_{i}-1,\ldots$ ,那么这一段长度为 $r$ 的 $A_i=1$ 的转移就是求 $r$ 次后缀和,由于 $lim_i$ 一直不变,可以用组合数算贡献。
代码
我的代码是优化 1。
#include <bits/stdc++.h>
#define fi first
#define se second
#define DB double
#define U unsigned
#define P std::pair
#define LL long long
#define LD long double
#define pb emplace_back
#define MP std::make_pair
#define SZ(x) ((int)x.size())
#define all(x) x.begin(),x.end()
#define CLR(i,a) memset(i,a,sizeof(i))
#define FOR(i,a,b) for(int i = a;i <= b;++i)
#define ROF(i,a,b) for(int i = a;i >= b;--i)
#define DEBUG(x) std::cerr << #x << '=' << x << std::endl
const int MAXN = 1e5+5;
const int ha = 998244353;
inline int qpow(int a,int n=ha-2){
int res = 1;
while(n){
if(n&1) res = 1ll*res*a%ha;
a = 1ll*a*a%ha;
n >>= 1;
}
return res;
}
int f[MAXN],g[MAXN],fac[MAXN],inv[MAXN];
LL lim[MAXN];
int n,A[MAXN];LL M;
int pre[MAXN],suf[MAXN];
inline void add(int &x,int y){
x += y-ha;x += x>>31&ha;
}
inline int cha(int m,int x){
if(x <= m) return f[x];
int ans = 0;
pre[0] = 1;FOR(i,1,m) pre[i] = 1ll*pre[i-1]*(x-i+ha)%ha;
suf[m+1] = 1;ROF(i,m,1) suf[i] = 1ll*suf[i+1]*(x-i+ha)%ha;
FOR(i,1,m){
int c = 1ll*pre[i-1]*suf[i+1]%ha*f[i]%ha*inv[i-1]%ha*inv[m-i]%ha;
if((m-i)&1) c = (ha-c)%ha;
add(ans,c);
}
return ans;
}
int main(){
fac[0] = 1;FOR(i,1,MAXN-1) fac[i] = 1ll*fac[i-1]*i%ha;
inv[MAXN-1] = qpow(fac[MAXN-1]);ROF(i,MAXN-2,0) inv[i] = 1ll*inv[i+1]*(i+1)%ha;
f[1] = 1;f[2] = 4;f[3] = 9;
scanf("%d%lld",&n,&M);
FOR(i,1,n) scanf("%d",A+i);
lim[n+1] = M;
f[1] = 1;f[2] = 1;
ROF(i,n,1){// deg=n-i+1
FOR(j,2,n-i+2) add(f[j],f[j-1]);
lim[i] = lim[i+1]/A[i];
int sm = cha(n-i+2,lim[i+1]%ha);
if(A[i] == 1){
FOR(j,1,n-i+3) g[j] = (sm-f[j-1]+ha)%ha;
}
else{
FOR(j,1,n-i+3) g[j] = (sm-cha(n-i+2,(1ll*j*A[i]-1)%ha)+ha)%ha;
}
FOR(j,1,n-i+3) f[j] = g[j];
}
FOR(i,1,n+2) add(f[i],f[i-1]);
printf("%d\n",cha(n+2,lim[1]%ha));
return 0;
}