# Atcoder ABC300

# E - Dice Product 3

# 题面

你有一个整数 1 和一个骰子,骰子以相等的概率显示出介于 1166 之间的整数。当你的整数严格小于 NN 时,你重复下面的操作掷骰子。如果骰子显示xx ,则将你的整数乘以 xx 。求你的整数最终是 NN 的概率 (模为 998244353998244353 )

如何求模数为 998244353 的概率?

我们可以证明所求的概率总是有理数。此外,在本题的限制条件下,当该值表示为 PQ\frac{P}{Q} 与两个共质整数 PPQQ 时。我们可以证明有一个唯一的整数 R。使得R×QP(mod998244353)R \times Q \equiv P\pmod{998244353}0R<9982443530 \leq R \lt 998244353 。求这个 R。

解法

p(x)p(x) 为骰子转到 xx 的概率,那么考虑 p(x)p(x) 是怎么由更小的状态转移而来。

  • 首先 p1=1p_1=1 ,骰子显示 11 对于概率会有变化。

  • 如果 x0(mod2)x \equiv 0(mod\ 2),那么骰子有可能上一次 (或之前) 投到 22 过,那么就是由 15\frac{1}{5} 的可能使 x2\frac{x}{2} 转化为 xx

    \ldots

dp(n)=16(dp(n)+dp(2n)+dp(3n)+dp(4n)+dp(5n)+dp(6n)).\mathrm{dp}(n) = \frac{1}{6} (\mathrm{dp}(n) + \mathrm{dp}(2n) + \mathrm{dp}(3n) + \mathrm{dp}(4n) + \mathrm{dp}(5n) + \mathrm{dp}(6n) ).

这个等式包含了 dp(n)dp(n) ,因此我们将它们去掉:

dp(n)16dp(n)=16(dp(2n)+dp(3n)+dp(4n)+dp(5n)+dp(6n))\mathrm{dp}(n) - \frac{1}{6} \mathrm{dp}(n) = \frac{1}{6} (\mathrm{dp}(2n) + \mathrm{dp}(3n) + \mathrm{dp}(4n) + \mathrm{dp}(5n) + \mathrm{dp}(6n) )

56dp(n)=16(dp(2n)+dp(3n)+dp(4n)+dp(5n)+dp(6n))\frac{5}{6} \mathrm{dp}(n) = \frac{1}{6} (\mathrm{dp}(2n) + \mathrm{dp}(3n) + \mathrm{dp}(4n) + \mathrm{dp}(5n) + \mathrm{dp}(6n) )

dp(n)=15(dp(2n)+dp(3n)+dp(4n)+dp(5n)+dp(6n))\mathrm{dp}(n) = \frac{1}{5} (\mathrm{dp}(2n) + \mathrm{dp}(3n) + \mathrm{dp}(4n) + \mathrm{dp}(5n) + \mathrm{dp}(6n) )

所以,我们有:

p(x)={1x=1i=16p(xi)[x0modi]5otherwisep(x) = \begin{cases} 1 & x = 1 \\ \dfrac{\sum_{i=1}^6 p(\frac{x}{i})[x \equiv 0 \mod i]}{5} & \text{otherwise} \end{cases}

转移过程用一个 map 记录一下即可。

Code
代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <map>
#include <queue>
#include <stack>
#include <set>
#include <iomanip>
using namespace std;
#define int long long
const int mod=998244353;
int qmi(int a,int b){
    int res=1;
    while(b){
        if(b&1){
            res=(res%mod*a%mod)%mod;
        }
        a=(a%mod*a%mod)%mod;
        b>>=1;
    }
    return res;
}
int inv(int x){
    return qmi(x,mod-2);
}
map<int,int> m;
int solve(int x){
    if(m[x]){
        return m[x];
    }
    for(int i=2;i<=6;i++){
        if(!(x%i)){
            (m[x]+=solve(x/i)*inv(5)%mod)%=mod;
        }
    }
    return m[x];
}
inline int read(){
    int x=0,w=1;
    char ch=getchar();
    for(;ch>'9'||ch<'0';ch=getchar()) if(ch=='-') w=-1;
    for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
    return x*w;
}
inline void write(int x){
    if(x<0) putchar('-'),x=-x;
    if(x>9) write(x/10);
    putchar(x%10+'0');
}
signed main(){
    int n;
    cin>>n;
    m[1]=1;
    cout<<solve(n);
}
更新于 阅读次数