题目大意

一开始序列为空。每次你有 $P_a$ 的概率往序列末尾添加一个a,有 $P_b$ 的概率往序列末尾添加一个b。当子序列ab的出现次数 $\geq K$ 时停止,问子序列ab的期望出现次数。
保证有 $P_a+P_b = 1$(显然)

题解

考虑 dp,设 $f_{i,j,k}$ 表示拿了 $i$ 个 a,$j$ 个 b,子序列个数为 k 的期望长度。转移平凡:
$$f_{i,j,k} = P_a* f_{i+1,j,k}+P_b* f_{i,j+1,k+i}$$
我们发现 $j$ 在转移中是毫无意义的,不如去掉,状态变成了 $f_{i,j}$ 表示拿了 $i$ 个 a,子序列个数为 $j$。转移方程:
$$f_{i,j} = P_a* f_{i+1,j}+P_b* f_{i,i+j}$$
考虑如果 $i=0$ (即前面一堆 b) 的时候会怎么样:不难发现这样对答案的贡献是无意义的,因为每一个合法的方案前面都可以加上这样多的 b,取完平均值后就没有贡献了,所以我们需要求的是 $f_{1,0}$。
现在考虑如果状态 $i+j \geq k$ 的时候我们如何求(因为数组已经开不下了),考虑这种情况下再加入一个 $b$ 就结束了。设答案为 $T$ ,令 $q = i+j$ 不妨枚举 $a$ 出现的次数:
\begin{align*}
T &= \sum_{i=0}^{\infty} P_a^i P_b(q+i) \\
&= P_b\sum_{i=0}^{\infty}P_ a^i * i + P_ b* q\sum_ {i=0}^{\infty}P_a^i \tag{1} \\
&= \frac{P_a}{P_b} + \frac{q}{1-P_a} \tag{2} \\
&= \frac{P_a}{P_b} + q
\end{align*}
接下来我们证明从 $(1)\to (2)$ 的加号前面的推导:
令 $t = P_b\sum_{i=0}^{\infty} P_a^i * i$
那么有:
\begin{align*}
t &= P_b\sum_{i=0}^{\infty} P_a^i * i \\
&= P_b\sum_{i=1}^{\infty} P^i_a * i \\
&= P_b \sum_{i=0}^{\infty} P^{i+1}_ a * (i+1) \\
&= P_aP_b \sum_{i=0}^{\infty} P_{a}^i * (i+1) \\
&= P_aP_b(\sum_{i=0}^{\infty}P^i_a* i+\sum_{i=0}^{\infty}P_a^i)\\
&= P_aP_b(\frac{t}{P_b}+\frac{1}{1-P_a})
\end{align*}
解得
$t = \frac{P_a}{P_b}$
所以对于 $i+j \geq k$ 的状态期望都是 $i+j+\frac{P_a}{P_b}$ 了。于是就做完了。(建议写记忆化搜索)

代码

/*
 * Author: RainAir
 * Time: 2019-10-09 10:35:18
 */
#include <algorithm>
#include <iostream>
#include <cstring>
#include <climits>
#include <cstdlib>
#include <cstdio>
#include <bitset>
#include <vector>
#include <cmath>
#include <ctime>
#include <queue>
#include <stack>
#include <map>
#include <set>

#define fi first
#define se second
#define U unsigned
#define P std::pair
#define LL long long
#define pb push_back
#define MP std::make_pair
#define all(x) x.begin(),x.end()
#define CLR(i,a) memset(i,a,sizeof(i))
#define FOR(i,a,b) for(int i = a;i <= b;++i)
#define ROF(i,a,b) for(int i = a;i >= b;--i)
#define DEBUG(x) std::cerr << #x << '=' << x << std::endl

const int MAXN = 2000+5;
const int ha = 1e9 + 7;

inline int qpow(int a,int n=ha-2){
    int res = 1;
    while(n){
        if(n & 1) res = 1ll*res*a%ha;
        a = 1ll*a*a%ha;
        n >>= 1;
    }
    return res;
}

int k,pa,pb;
int f[MAXN][MAXN];

inline int dfs(int i,int j){//   i 个 a,j 个 ab
    if(i+j >= k) return ((i+j)%ha+1ll*pa*qpow(pb)%ha)%ha;
    if(f[i][j] != -1) return f[i][j];
    return f[i][j] = 1ll*(1ll*dfs(i+1,j)*pa%ha+1ll*dfs(i,i+j)*pb%ha)%ha;
}

int main(){
    scanf("%d%d%d",&k,&pa,&pb);CLR(f,-1);
    int t = pa+pb;t = qpow(t);
    pa = 1ll*pa*t%ha;pb = 1ll*pb*t%ha;
    printf("%d\n",dfs(1,0));
    return 0;
}
Last modification:April 7th, 2020 at 10:14 pm
如果觉得我的文章对你有用,请随意赞赏