您的位置:首页 > 其它

组合数取模之逆元方法+模板

2017-10-08 22:15 411 查看
参自:
http://www.cnblogs.com/liziran/p/6804803.html https://baike.baidu.com/item/%E8%B4%B9%E9%A9%AC%E5%B0%8F%E5%AE%9A%E7%90%86/4776158?fr=aladdin
现在目标是求









Cnm%p,p为素数(经典p=1e9+7)

虽然有



























Cnm=n!m!(n−m)!,但由于取模的性质对于除法不适用,所以









Cnm%p≠









































(n!%pm!%p∗(n−m)!%p)%p

所以需要把“除法”转换成“乘法”,才能借助取模的性质在不爆long long的情况下计算组合数。这时候就需要用到“逆元”!
逆元:对于a和p,若a*b%p≡1,则称b为a%p的逆元。


那这个逆元有什么用呢?试想一下求







(ab)%p,如果你知道b%p的逆元是c,那么就可以转变成







(ab)%p
= a*c%p = (a%p)(c%p)%p

那怎么求逆元呢?这时候就要引入强大的费马小定理!
费马小定理(Fermat's little theorem)是数论中的一个重要定理,在1636年提出,其内容为: 假如p是质数,且gcd(a,p)=1,那么 a(p-1)≡1(mod p),即:假如a是整数,p是质数,且a,p互质(即两者只有一个公约数1),那么a的(p-1)次方除以p的余数恒等于1。


接着因为







ap−1 = 











ap−2∗a,所以有











ap−2∗a%p≡1!对比逆元的定义可得,







ap−2是a的逆元!

所以问题就转换成求解







ap−2,即变成求快速幂的问题了(当然这需要满足p为素数)。

现在总结一下求解









Cnm%p的步骤:

通过循环,预先算好所有小于max_number的阶乘(%p)的结果,存到fac[max_number]里 (fac[i] = i!%p)

求m!%p的逆元(即求fac[m]的逆元):根据费马小定理,x%p的逆元为







xp−2,因此通过快速幂,求解

















fac[m]p−2%p,记为M

求(n-m)!%p的逆元:同理为求解





















fac[n−m]p−2%p,记为NM










Cnm%p =
((fac
*M)%p*NM)%p

模板:

#include <bits/stdc++.h>///codeforces 869C代码 主函数三个循环可以合并成一个循环+三个if
using namespace std;
const int MAXN = 5050;
const int mod = 998244353;
typedef unsigned long long LL;

LL a,b,c,aa,bb,cc;

LL inv[MAXN],fac[MAXN];

inline int Inv(int x){///x^(mod-2)
int res = 1;
int p = mod - 2;
while (p) {
if (p & 1) res = LL(res) * x % mod;
p >>= 1;
x = LL(x) * x % mod;
}
return res;
}

inline int C(int n, int k){
if (n < 0 || k < 0 || k > n) return 0;
return LL(fac
) * inv[k] % mod * inv[n - k] % mod;
}

void init(){
fac[0] = inv[0] = 1;
for (int i = 1; i < MAXN; i++) {
fac[i] = LL(fac[i - 1]) * i % mod;
inv[i] = Inv(fac[i]);///预处理fac[i]^(p-2)
}
}

int main(){
init();
cin>>a>>b>>c;
LL ans=0LL,ans1=0LL,ans2=0LL,ans3=0LL;
LL tmp;
aa=min(a,b);
for(LL i=0;i<=aa;++i){
tmp=(LL)C(a,i);
tmp=tmp*(LL)C(b,i)%mod;
tmp=tmp*(LL)fac[i]%mod;
ans1=(ans1+tmp)%mod;
}
bb=min(c,b);
for(LL i=0;i<=bb;++i){
tmp=(LL)C(c,i);
tmp=tmp*(LL)C(b,i)%mod;
tmp=tmp*(LL)fac[i]%mod;
ans2=(ans2+tmp)%mod;
}
cc=min(a,c);
for(LL i=0;i<=cc;++i){
tmp=(LL)C(a,i);
tmp=tmp*(LL)C(c,i)%mod;
tmp=tmp*(LL)fac[i]%mod;
ans3=(ans3+tmp)%mod;
}
ans=(ans1*ans2)%mod*ans3%mod;
printf("%d\n",ans%mod);
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: