您的位置:首页 > 其它

树形背包DP的两种优化方式——vijos1676、codeforces815c

2018-03-08 16:14 573 查看

1.O(nm)——vijos1676陶陶吃苹果

背景

陶陶很喜欢吃苹果。

描述

curimit知道陶陶很喜欢吃苹果。于是curimit准备在陶陶生日的时候送给他一棵苹果树。

curimit准备了一棵这样的苹果树作为生日礼物:这棵苹果树有n个节点,每个节点上有c[i]个苹果,这棵树高度为h。

可是,当curimit把这棵树给陶陶看的时候,陶陶却说:“今年生日不收礼,收礼只收节点数减高度不超过k的苹果树。”这下curimit犯难了,curimit送来的树枝繁叶茂,不满足节点数-高度≤k。于是curimit决定剪掉一些枝条,使得修剪过后的树满足节点数-高度≤k,但是curimit又想保留尽量多的苹果数目。curimit想请你帮他算算经过修剪后的树最多能保留多少个苹果。

注:

一, 节点1为树根,不能把它剪掉。

二, 1个节点的树高度为1。

对于一课剪枝完满足条件的树,树的size最大就是m+树的最长链,换种说法就是取m个结点+一条链求怎么样最大化收益,同时还要满足树形背包dp的选了父节点才能选子节点的条件。

易知,选的链越长越好,所以答案所包含的那条最长链一定是跟到叶子结点的。

使用论文里的状态设置方式,dp[u][i]表示考虑了u结点左边的所有节点和所有子节点,花费i所能得到的最大收益。那么对于任意叶节点v,dp[v][i]就表示v结点左边结点中花费i的最大收益,从左往右dfs一遍再从右往左dfs一遍,得到两个dp数组,就能枚举叶节点和左边的花费得到答案。

#include<bits/stdc++.h>

using namespace std;
typedef long long ll;
const int maxn=4005;
const int maxm=505;
const int maxe=maxn*2;
const ll mod=1e9+7;
const int inf=0x3f3f3f3f;
int n,m,k;
int dpl[maxn][maxm];
int dpr[maxn][maxm];
int c[maxn];
vector<int> e[maxn];
int sum[maxn];
int a,b;
int ans;

void dfsl(in
4000
t u,int fa){
for(int i=0;i<e[u].size();i++){
int v=e[u][i];
if(v==fa)continue;
sum[v]=sum[u]+c[v];
for(int i=0;i<=k;i++)dpl[v][i]=dpl[u][i];
dfsl(v,u);
for(int i=1;i<=k;i++){
dpl[u][i]=max(dpl[u][i],dpl[v][i-1]+c[v]);
}
}
}

void dfsr(int u,int fa){
bool leaf=1;
for(int i=e[u].size()-1;i>=0;i--){
int v=e[u][i];
if(v==fa)continue;
leaf=0;
for(int i=0;i<=k;i++)dpr[v][i]=dpr[u][i];
dfsr(v,u);
for(int i=1;i<=k;i++){
dpr[u][i]=max(dpr[u][i],dpr[v][i-1]+c[v]);
}
}
if(leaf){
for(int i=0;i<=k;i++){
ans=max(ans,dpl[u][i]+dpr[u][k-i]+sum[u]);
}

}
}

int main(){
scanf("%d%d",&n,&k);
scanf("%d%d",&a,&c[1]);
for(int i=2;i<=n;i++){
scanf("%d%d",&a,&c[i]);
e[a].push_back(i);
e[i].push_back(a);
}
sum[1]=c[1];
dfsl(1,0);
dfsr(1,0);
printf("%d\n",ans);
return 0;
}


2.O(n^2)——codeforces 815C - Karen and Supermarket

超市买东西,每个东西有个原价和使用优惠券能减少的费用。不过优惠券有前置使用条件,就是要使用优惠券i必须要使用优惠券xi(就是树形背包dp的条件),问最多能买多少东西能不超过预算b。

b的范围是1e9,肯定不能拿来设置状态,dp[i][j][k]表示以i为根节点,使不使用,买j个物品最少需要的费用。

优化的方式就是转移的时候利用size数组来减少枚举次数,使得枚举次数变成合法点对数n^2。

#include<bits/stdc++.h>

using namespace std;
typedef long long ll;
const int maxn=5005;
const int maxm=505;
const int maxe=maxn*2;
const ll mod=1e9+7;
const int inf=0x3f3f3f3f;
int n,m,k,a;
int dp[maxn][maxn][2];
vector<int> e[maxn];
int c[maxn],d[maxn];
int siz[maxn];

int dfs(int u,int fa){
dp[u][0][0]=0;
dp[u][1][0]=c[u];
dp[u][1][1]=c[u]-d[u];
siz[u]=1;
for(int i=0;i<e[u].size();i++){
int v=e[u][i];
if(v==fa)continue;
dfs(v,u);
for(int i=siz[u];i>=0;i--){
//for(int i=0;i<=siz[u];i++){
for(int j=1;j<=siz[v];j++){
dp[u][i+j][0]=min(dp[u][i+j][0],dp[u][i][0]+dp[v][j][0]);
}
}
for(int i=siz[u];i>=0;i--){
//for(int i=1;i<=siz[u];i++){
for(int j=1;j<=siz[v];j++){
dp[u][i+j][1]=min(dp[u][i+j][1],dp[u][i][1]+min(dp[v][j][1],dp[v][j][0]));
}
}
siz[u]+=siz[v];
}
}

int main(){
scanf("%d%d",&n,&m);
scanf("%d%d",&c[1],&d[1]);
for(int i=2;i<=n;i++){
scanf("%d%d%d",&c[i],&d[i],&a);
e[a].push_back(i);
e[i].push_back(a);
}
memset(dp,0x3f,sizeof(dp));
dfs(1,0);
for(int i=n;i>=0;i--){
if(dp[1][i][0]<=m||dp[1][i][1]<=m){
printf("%d",i);
break;
}
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: