您的位置:首页 > 其它

一个矩阵乘法优化期望dp的题

2017-12-26 23:26 381 查看
小Y是一个喜欢玩游戏的OIer。一天,她正在玩一款游戏,要打一个Boss。

虽然这个Boss有 1010010^{100}10​100​​ 点生命值,但它只带了一个随从——一个只有 mmm 点生命值的“恐怖的奴隶主”。

这个“恐怖的奴隶主”有一个特殊的技能:每当它被扣减生命值但没有死亡(死亡即生命值 ≤0\leq
0≤0),且Boss的随从数量小于上限 kkk,便会召唤一个新的具有 mmm 点生命值的“恐怖的奴隶主”。

现在小Y可以进行 nnn 次攻击,每次攻击时,会从Boss以及Boss的所有随从中的等概率随机选择一个,并扣减 111 点生命值,她想知道进行 nnn 次攻击后扣减Boss的生命值点数的期望。为了避免精度误差,你的答案需要对 998244353998244353998244353 取模。

其中m只有3

按照我的思路,设计状态f[i][a][b][c]代码已经打了i次,1血的奴隶主数量为a,2血为b,3血为c的情况下期望扣BOSS多少血,然后枚举下一次打谁,转移

可是这时候出现一个问题,答案怎么算?很显然当前状态的期望次数还要乘上一个奇怪的类似概率的东西。

如果分开处理就显得非常复杂

这时候换个思路,f[i][a][b][c]代表剩下多少次攻击次数(abc的意义同前面)的情况下期望打几次BOSS。

举个例子,如果你剩下0步,那么显然你一次都打不了。答案就是f
[][][],比如说m=3就是f
[0][0][1],m=2就是f
[0][1][0];

那么f[i][a][b][c]由f[i-1][a'][b'][c']转移其中状态abc能推到状态a'b'c'。(比如说abc=4 1 2,a'b'c'=5 0 3)

以上几句话是这道题的核心,期望dp倒着转移在此题得到很深刻的体现

所以说设计转移矩阵p[i][j]代表状态j能推到状态i的概率,但是有一种转移f[i][a][b][c]=(f[i-1][a][b][c]+1)*(1/(a+b+c+1))无法处理,很简单,我们只需要在转移矩阵的最后一列多开一行,这行的第j列表示的是第j个状态的1/(a[j]+b[j]+c[j]+1),那么在答案矩阵的最后多开一列,这一列的第一个数为1,那么在转移的时候就会把1/(a+b+c+1)多算一次,注意卡常

#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int MOD = 998244353;
int p[205][205], gay[10][10][10];
int x[205], y[205], z[205], len;
int T, m, K, i, j, k, inv[21];
long long n;
inline int ksm(int x, int y, int z)
{
int b = 1;
while (y)
{
if (y & 1) b = 1ll * b * x % z;
x = 1ll * x * x % z;
y >>= 1;
}
return b;
}
struct sb{
int c[205][205];
inline friend sb operator * (const sb &a, const sb &b)
{
sb c;
for(int i = 1; i <= len; i ++)
for(int j = 1; j <= len; j ++)
{
long long res = 0;
for(int k = 1; k <= len; k ++)
{
res += (long long)a.c[i][k] * b.c[k][j];
if (res >= (long long)8e18) res %= MOD;
}
c.c[i][j] = res % MOD;
}
return c;
}
inline friend sb operator + (const sb &a, const sb &b)
{
sb c;
for(int i = 1; i <= 1; i ++)
for(int j = 1; j <= len; j ++)
{
long long res = 0;
for(int k = 1; k <= len; k ++)
{
res += (long long)a.c[i][k] * b.c[k][j];
if (res >= (long long)8e18) res %= MOD;
}
c.c[i][j] = res % MOD;
}
return c;
}
};
sb f[70], a;
int main()
{
freopen("patron.in", "r", stdin);
freopen("patron.out", "w", stdout);
cin >> T >> m >> K;
for(i = 1; i <= 20; i ++)
inv[i] = ksm(i, MOD - 2, MOD);
for(i = 0; i <= K; i ++)
for(j = 0; j <= K - i; j ++)
for(k = 0; k <= K - i - j; k ++)
{
if (m == 1 && (j || k)) continue;
if (m == 2 && k) continue;
x[++len] = i; y[len] = j; z[len] = k;
gay[i][j][k] = len;
}
len ++;
for(i = 1; i <= len; i ++)
p[len][i] = inv[x[i] + y[i] + z[i] + 1];
for(i = 1; i < len; i ++)
for(j = 1; j < len; j ++)
if (i == j) p[i][i] = inv[x[i] + y[i] + z[i] + 1];
else {
if (z[i])
{
int nexx = x[i], nexy = y[i] + 1, nexz = z[i] - 1;
if (x[i] + y[i] + z[i] < K)
{
if (m == 1) nexx ++;
if (m == 2) nexy ++;
if (m == 3) nexz ++;
}
if (nexx == x[j] && nexy == y[j] && nexz == z[j]) p[j][i] = 1ll * inv[x[i] + y[i] + z[i] + 1] * z[i] % MOD;
}
if (y[i])
{
int nexx = x[i] + 1, nexy = y[i] - 1, nexz = z[i];
if (x[i] + y[i] + z[i] < K)
{
if (m == 1) nexx ++;
if (m == 2) nexy ++;
if (m == 3) nexz ++;
}
if (nexx == x[j] && nexy == y[j] && nexz == z[j]) p[j][i] = 1ll * inv[x[i] + y[i] + z[i] + 1] * y[i] % MOD;
}
if (x[i])
{
int nexx = x[i] - 1, nexy = y[i], nexz = z[i];
if (nexx == x[j] && nexy == y[j] && nexz == z[j]) p[j][i] = 1ll * inv[x[i] + y[i] + z[i] + 1] * x[i] % MOD;
}
}
for(i = 1; i <= len; i ++)
for(j = 1; j <= len; j ++)
f[0].c[i][j] = p[i][j];
for(i = 1; i <= 60; i ++)
f[i] = f[i - 1] * f[i - 1];
while (T --)
{
cin >> n;
memset(a.c, 0, sizeof(a.c));
a.c[1][len] = 1;
for(i = 60; i >= 0; i --)
if ((1ll << i) <= n)
{
n -= 1ll << i;
a = a + f[i];
}
cout << a.c[1][gay[(m == 1) ? 1 : 0][(m == 2) ? 1 : 0][(m == 3) ? 1 : 0]] << endl;
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: