您的位置:首页 > 其它

[poj1741]Tree(点分治+容斥原理)

2017-09-06 17:51 295 查看

题意:求树中点对距离<=k的无序点对个数。

解题关键:树上点分治,这个分治并没有传统分治的合并过程,只是分成各个小问题,并将各个小问题的答案相加即可,也就是每层的复杂度并不在合并的过程,是在每层的处理过程。

此题维护的是树上路径,考虑点分治。

点分治的模板题,首先设点x到当前子树跟root的距离为,则满足${d_x} + {d_y} \le k$可以加进答案,但是注意如果x,y在同一棵子树中,就要删去对答案的贡献,因为x,y会在其所在的子树中在计算一次。同一棵子树中不必考虑是否在其同一棵子树中的问题,因为无论是否在他的同一棵子树,都会对他的父节点产生影响。而这些影响都是无意义的。

注意无根树转有根树的过程,需要选取树的重心防止复杂度从$O(n{\log ^2}n)$退化为$O({n^2})$

复杂度:$O(n{\log ^2}n)$

 

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
#include<iostream>
#include<cmath>
#define inf 0x3f3f3f3f
#define maxn 10004
using namespace std;
typedef long long ll;
int head[maxn],cnt,n,k,ans,size,s[maxn],f[maxn],root,depth[maxn],num;//vis代表整体的访问情况,每个dfs不应该只用vis来存储
bool vis[maxn];
struct edge{
int to,w,nxt;
}e[maxn<<1];
void add_edge(int u,int v,int w){
e[cnt].to=v;
e[cnt].w=w;
e[cnt].nxt=head[u];
head[u]=cnt++;
}

inline int read(){
char k=0;char ls;ls=getchar();for(;ls<'0'||ls>'9';k=ls,ls=getchar());
int x=0;for(;ls>='0'&&ls<='9';ls=getchar())x=(x<<3)+(x<<1)+ls-'0';
if(k=='-')x=0-x;return x;
}

void get_root(int u,int fa){//get_root会用到size
s[u]=1;f[u]=0;//f是dp数组
for(int i=head[u];i!=-1;i=e[i].nxt){
int v=e[i].to;
if(v==fa||vis[v]) continue;
get_root(v,u);
s[u]+=s[v];
f[u]=max(f[u],s[v]);
}
f[u]=max(f[u],size-s[u]);
root=f[root]>f[u]?u:root;
}

void get_depth_size(int u,int fa,int dis){//同时获取size和depth
depth[num++]=dis;
s[u]=1;
for(int i=head[u];i!=-1;i=e[i].nxt){
int v=e[i].to;
if(v==fa||vis[v]) continue;
get_depth_size(v,u,dis+e[i].w);
s[u]+=s[v];
}
}

int calc(int u,int fa,int w){
num=0;
get_depth_size(u,fa,w);
sort(depth,depth+num);
int ret=0;
for(int l=0,r=num-1;l<r;){
if(depth[l]+depth[r]<=k) ret+=r-l++;
else r--;
}
return ret;
}

void work(int u){
vis[u]=true;
ans+=calc(u,-1,0);
for(int i=head[u];i!=-1;i=e[i].nxt){
int v=e[i].to;
if(vis[v]) continue;
ans-=calc(v,u,e[i].w);
size=s[v],root=0;
get_root(v,u);
work(root);
}
}

void init(){
memset(vis,false, sizeof vis);
memset(head,-1,sizeof head);
ans=cnt=0;
}

int main(){
int a,b,c;
f[0]=inf;
while(scanf("%d%d",&n,&k)&&(n||k)){
init();
for(int i=0;i<n-1;i++){
a=read(),b=read(),c=read();
add_edge(a,b,c);
add_edge(b,a,c);
}
size=n,root=0;
get_root(1,-1);
work(root);
printf("%d\n",ans);
}
return 0;

}

 

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: