您的位置:首页 > 其它

bzoj4584 [Apio2016]赛艇

2017-04-25 18:34 501 查看
传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=4584

【题解】

令f[i,j,k]表示前i个学校,赛艇最远放在j区间,且j这个区间放了k个赛艇。

那么显然区间可以离散(这里用左闭右开方便),那么就是一个大概O(n^3)的做法了。

好像就行了?据说还要常数优化qwq

丢个方程

令s[j]表示Σf[i-1,j',*],其中j'<=j。

令len[j]表示第j个区间的实际长度。

f[i,j,1] = f[i-1,j,k] + s[j]*len[j]

(前i-1个就放到这么多了,前i-1个还没放到这,现在放在这里,有len[j]种方案,因为放哪都行)

f[i,j,k] = f[i-1,j,k] + f[i-1,j,k-1]*(len[j]-k+1)/k

(前i-1个就放到这么多了,前i-1个这个区间还差一个,把他塞进来)

后面那坨是怎么回事呢,因为原来是C(len[j],k-1),现在变成C(len[j],k),改变的量。

# include <vector>
# include <stdio.h>
# include <string.h>
# include <algorithm>
// # include <bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
const int M = 5e2 + 10, N = 1000 + 10;
const int mod = 1e9+7;

# define RG register
# define ST static

int n, m;
struct intervals {
// [l, r)
int l, r;
intervals() {}
intervals(int l, int r) : l(l), r(r) {}
}p[M];

vector<int> ps;
int len
;
int f
[M], s
;
int cnt[M]
, inv
;

inline int pwr(int a, int b) {
int ret = 1;
while(b) {
if(b&1) ret = 1ll * ret * a % mod;
a = 1ll * a * a % mod;
b >>= 1;
}
return ret;
}

int main() {
scanf("%d", &n);
for (int i=1; i<=n; ++i) {
scanf("%d%d", &p[i].l, &p[i].r); ++p[i].r;
ps.push_back(p[i].l), ps.push_back(p[i].r);
}
sort(ps.begin(), ps.end());
ps.erase(unique(ps.begin(), ps.end()), ps.end());
m = ps.size();
for (int i=1; i<=n; ++i) {
int L = lower_bound(ps.begin(), ps.end(), p[i].l)-ps.begin()+1;
int R = lower_bound(ps.begin(), ps.end(), p[i].r)-ps.begin()+1;
p[i] = intervals(L, R);
}
for (int i=1; i<m; ++i) len[i] = ps[i] - ps[i-1];

for (int i=0; i<=1000; ++i) inv[i] = pwr(i, mod-2);

for (int i=0; i<m; ++i) s[i] = 1;

for (int i=1; i<=n; ++i) {
for (int j=1; j<m; ++j) cnt[i][j] = cnt[i-1][j];
for (int j=p[i].l; j<p[i].r; ++j) ++cnt[i][j];
}

for (int i=1; i<=n; ++i) {
for (int j=p[i].l; j<p[i].r; ++j) {
for (int k=cnt[i][j]; k>=2; --k) {
f[j][k] = f[j][k] + 1ll * f[j][k-1] * (len[j]-k+1) % mod * inv[k] % mod;
if(f[j][k] >= mod) f[j][k] -= mod;
}
f[j][1] = f[j][1] + 1ll * s[j-1] * len[j] % mod;
if(f[j][1] >= mod) f[j][1] -= mod;
}
s[0] = 1;
for (int j=1; j<m; ++j) {
s[j] = s[j-1];
for (int k=1; k<=cnt[i][j]; ++k) {
s[j] = s[j] + f[j][k];
if(s[j] >= mod) s[j] -= mod;
}
}
}

int ans = 0;
for (int i=1; i<m; ++i)
for (int j=1; j<=n; ++j) {
ans = ans + f[i][j];
if(ans >= mod) ans -= mod;
}
printf("%d\n", ans);

return 0;
}


View Code
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: