您的位置:首页 > 运维架构

bzoj4182 shopping [树形dp+点分治]

2018-03-16 08:00 218 查看
Description:

树上每个点有容量,花费,价值。你有mm元钱,问选出一个连通块的最大价值。

Solution:

问题在于必须选出一个连通快。每次树形dp时改动一下,递归子树时去掉一份下一个节点的花费,强制选择下一个节点一次,这样就是选择一个联通块了,每次背包二进制分解一下,再用点分治优化即可。

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int N = 505;
int n, m, root, ans;
vector<int> G
;
int vis
, mx
, sz
, w
, c
, d
, dp
[4005];
int getsize(int u, int last) {
int ret = 1;
for(int i = 0; i < G[u].size(); ++i) {
if(G[u][i] != last && !vis[G[u][i]]) {
ret += getsize(G[u][i], u);
}
}
return ret;
}
void findroot(int u, int last, int S) {
mx[u] = 0;
sz[u] = 1;
for(int i = 0; i < G[u].size(); ++i) {
int v = G[u][i];
if(v != last && !vis[v]) {
findroot(v, u, S);
sz[u] += sz[v];
mx[u] = max(mx[u], mx[v]);
}
}
mx[u] = max(mx[u], S - sz[u]);
if(mx[u] < mx[root]) {
root = u;
}
}
void dfs(int u, int last, int m) {
if(m <= 0) {
return;
}
for(int i = 0, j = d[u]; j; ++i) {
if(j >= (1 << i)) {
for(int W = c[u] << i, V = w[u] << i, k = m; k >= W; --k) {
dp[u][k] = max(dp[u][k], dp[u][k - W] + V);
}
j -= 1 << i;
} else {
for(int W = c[u] * j, V = w[u] * j, k = m; k >= W; --k) {
dp[u][k] = max(dp[u][k], dp[u][k - W] + V);
}
j = 0;
}
}
for(int i = 0; i < G[u].size(); ++i) {
int v = G[u][i];
if(v != last && !vis[v]) {
for(int j = 0; j <= m - c[v]; ++j) {
dp[v][j] = dp[u][j];
}
dfs(v, u, m - c[v]);
for(int j = c[v]; j <= m; ++j) {
dp[u][j] = max(dp[u][j], dp[v][j - c[v]] + w[v]);
}
}
}
}
void solve(int u) {
root = 0;
findroot(u, 0, getsize(u, 0));
vis[root] = 1;
memset(dp[root], 0, sizeof(dp[root]));
dfs(root, 0, m - c[root]);
for(int i = 0; i <= m - c[root]; ++i) {
ans = max(ans, dp[root][i] + w[root]);
}
for(int i = 0; i < G[root].size(); ++i) {
if(!vis[G[root][i]]) {
solve(G[root][i]);
}
}
}
int main() {
int T;
scanf("%d", &T);
mx[0] = 0x3f3f3f3f;
while(T--) {
ans = 0;
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; ++i) {
G[i].clear();
vis[i] = 0;
scanf("%d", &w[i]);
}
for(int i = 1; i <= n; ++i) {
scanf("%d", &c[i]);
}
for(int i = 1; i <= n; ++i) {
scanf("%d", &d[i]);
--d[i];
}
for(int i = 1; i < n; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
solve(1);
printf("%d\n", ans);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: