您的位置:首页 > 其它

树分治(Tree,poj 1741)

2016-10-16 22:55 363 查看
参考博客 http://www.cnblogs.com/kuangbin/p/3454883.html
参考论文 http://wenku.baidu.com/link?url=DusTYd_4dgXuIS_G88sIwkAzGc5oM3CRzwx0EZcEWeiOtBh9Va2Xywzm_jhdeYkJ2E25Af9JWlB_PzLGDm0BVxVzYXyArOPJHOZ275YtFBy href="http://www.cnblogs.com/kuangbin/p/3454883.html" target=_blank>





其实感觉和树形DP很类似,都是子树处理好了,再根据处理结果来处理当前节点。

#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
#define INF 0X3F3F3F3F
#define maxn 10010

int n,k;
int tot;
bool vis[maxn];
int dp[maxn];
int temp[maxn];
int MIN;
int le,ri;

struct Edge
{
int to,w,next;
}edges[maxn<<1];
int head[maxn];

void init()
{
tot=0;
memset(head,-1,sizeof(head));
}

void addedge(int u,int v,int w)
{
edges[tot].to=v;
edges[tot].w=w;
edges[tot].next=head[u];
head[u]=tot++;
}

int dfs(int u,int f)
{
dp[u]=1;
for(int i=head[u];i!=-1;i=edges[i].next)
{
int v=edges[i].to;
if(vis[v]||v==f) continue;
dp[u]+=dfs(v,u);
}
return dp[u];
}

void get(int u,int f,int tot,int& root)
{
int MAX=tot-dp[u];
for(int i=head[u];i!=-1;i=edges[i].next)
{
int v=edges[i].to;
if(vis[v]||v==f) continue;
get(v,u,tot,root);
MAX=max(MAX,dp[v]);
}
if(MAX<MIN)
{
MIN=MAX;
root=u;
}
}

void add(int u,int f,int sum)
{
temp[ri++]=sum;
for(int i=head[u];i!=-1;i=edges[i].next)
{
int v=edges[i].to;
if(vis[v]||v==f) continue;
add(v,u,sum+edges[i].w);
}
}

int cul(int l,int r)
{
sort(temp+l,temp+r);
int ret=0;
int e=r-1;
for(int i=l;i<r;i++)
{
if(temp[i]>k) break;
while(e>=l&&temp[i]+temp[e]>k) e--;
ret+=e-l+1;
if(e>=i) ret--;
}
return ret>>1;
}

int solve(int u)
{
int tot=dfs(u,-1);
int root;
MIN=INF;
get(u,-1,tot,root);
int ret=0;
vis[root]=true;
for(int i=head[root];i!=-1;i=edges[i].next)
{
int v=edges[i].to;
if(vis[v]||v==root) continue;
ret+=solve(v);
}
le=ri=0;
for(int i=head[root];i!=-1;i=edges[i].next)
{
int v=edges[i].to;
if(vis[v]||v==root) continue;
add(v,root,edges[i].w);
ret-=cul(le,ri);
le=ri;
}
ret+=cul(0,ri);
for(int i=0;i<ri;i++)
if(temp[i]<=k) ret++;
else break;
vis[root]=false;
return ret;
}

int main()
{
while(scanf("%d %d",&n,&k)==2&&(n+k))
{
init();
int u,v,w;
for(int i=1;i<n;i++)
{
scanf("%d %d %d",&u,&v,&w);
addedge(u,v,w);
addedge(v,u,w);
}
printf("%d\n",solve(1));
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: