您的位置:首页 > 大数据 > 人工智能

hdu 5293 Tree chain problem【树状dp+dfs序+树状数组】

2017-10-16 14:49 435 查看

题目大意:

在一棵树中,给出若干条链和链的权值,求选取不相交的链使得权值和最大。

解题思路:

树形DP。

设dp[i]表示i的子树下的最优权值和,sum[i]表示不考虑i点时子树的最优权值和,即(j是i的儿子),显然dp[i]>=sum[i]。那么问题是考虑i点时dp[i]的值是多少,假设有一条链通过i,且端点a和b都在i的子树里,即LCA(a,b)=i,如果考虑加上这条链的权值,那么a->i, b->i的路上的点v都不能有链经过它们(题目要求链不相交),那么-dp[v],但至少有sum[v],即dp[i]=max(sum[i],max(sum[i]+∑(-dp[v]+sum[v])+c),其中v是某条链上的点。

那么怎么快速求出sigma的值呢,想到树状数组维护前缀和。先算出每个点的入出时间戳,记为l[i],r[i],每维护一个点,就在树状数组l[i]上加sum[i]-dp[i],在r[i]+1上减去sum[i]-dp[i]。那维护i时,对于以i为根的一条链的任意一个端点a,在树状数组还未加入i的信息时,query(1~l[a])即为a->i路径上除i点sum-dp的和,即dp[i]=max(sum[i],max(sum[i]+∑(qeury(1~l[a])+query(1~l[b]))+c);

#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<queue>
#include<vector>
#include<ctime>
#define ll long long
using namespace std;

int getint()
{
int i=0,f=1;char c;
for(c=getchar();(c<'0'||c>'9')&&c!='-';c=getchar());
if(c=='-')c=getchar(),f=-1;
for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
return i*f;
}

const int N=200005;
int T,n,m;
int tot,first
,nxt
,to
;
int idx,l
,r
,dp
,sum
,dep
,fa
[20];
int c
;
struct node
{
int x,y,w,lca;
}p
;
vector<int>chain
;

void add(int x,int y)
{
nxt[++tot]=first[x],first[x]=tot,to[tot]=y;
}

void dfs(int u)
{
l[u]=++idx;
for(int i=1;i<20;i++)fa[u][i]=fa[fa[u][i-1]][i-1];
for(int e=first[u];e;e=nxt[e])
{
int v=to[e];
if(v!=fa[u][0])
{
dep[v]=dep[u]+1;
fa[v][0]=u;
dfs(v);
}
}
r[u]=idx+1;
}

int lca(int x,int y)
{
if(dep[x]<dep[y])swap(x,y);
int delta=dep[x]-dep[y];
for(int i=19;i>=0;i--)
if(delta&(1<<i))x=fa[x][i];
for(int i=19;i>=0;i--)
if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return x==y?x:fa[x][0];
}

void Insert(int x,int val)
{
for(int i=x;i<=n+1;i+=i&(-i))
c[i]+=val;
}

int query(int x)
{
int res=0;
for(int i=x;i;i-=i&(-i))
res+=c[i];
return res;
}

void solve(int u)
{
dp[u]=sum[u]=0;
for(int e=first[u];e;e=nxt[e])
{
int v=to[e];
if(v!=fa[u][0])
{
solve(v);
sum[u]+=dp[v];
}
}
dp[u]=sum[u];
for(int i=0;i<chain[u].size();i++)
{
int x=p[chain[u][i]].x;
int y=p[chain[u][i]].y;
int tmp=query(l[x])+query(l[y]);
dp[u]=max(dp[u],tmp+sum[u]+p[chain[u][i]].w);
}
Insert(l[u],sum[u]-dp[u]);
Insert(r[u],dp[u]-sum[u]);
}

int main()
{
//freopen("lx.in","r",stdin);
//freopen("chain.out","w",stdout);
int x,y,z;
T=getint();
while(T--)
{
tot=idx=0;
memset(first,0,sizeof(first));
memset(fa,0,sizeof(fa));
memset(c,0,sizeof(c));
for(int i=1;i<=n;i++)chain[i].clear();
n=getint(),m=getint();
for(int i=1;i<n;i++)
{
x=getint(),y=getint();
add(x,y),add(y,x);
}
dfs(1);
for(int i=1;i<=m;i++)
{
p[i].x=getint(),p[i].y=getint(),p[i].w=getint();
p[i].lca=lca(p[i].x,p[i].y);
chain[p[i].lca].push_back(i);
}
solve(1);
cout<<dp[1]<<'\n';
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: