您的位置:首页 > 其它

bzoj 1036: [ZJOI2008]树的统计Count(树链剖分)

2017-10-03 13:07 495 查看
参考:http://hzwer.com/2543.html

初学树链剖分,借鉴大佬代码……

这应该是树链剖分裸题……

#include <bits/stdc++.h>
using namespace std;

// Largest int value used as a sentinel; solvemx() seeds its running maximum with -inf.
const int inf = 2147483647;
// Capacity of every per-node table. Node ids start at 1 (see init()), and
// son[u] == 0 is used to mean "no heavy child", so index 0 is never a node.
const int MAXN = 30000 + 5;
// Capacity of the arc pool: init() stores each tree edge as two directed arcs.
const int MAXM = 2 * 30000 + 5;

// One directed arc in the array-backed adjacency lists.
struct Edge
{
    int to;   // node this arc points at
    int next; // index in edge[] of the next arc leaving the same node, -1 when there is none
};
// Arc pool; addedge() fills slots in increasing index order.
Edge edge[MAXM];
int head[MAXN]; // head[u]: index in edge[] of the most recently added arc leaving u, -1 if none
int tot;        // number of arcs stored in edge[] so far
int v[MAXN];    // v[i]: current weight of node i
int siz[MAXN];  // siz[u]: number of nodes in the subtree of u (filled by dfs1)
int fa[MAXN];   // fa[u]: parent of u (filled by dfs1; stays 0 for the root)
int dep[MAXN];  // dep[u]: depth of u, counted from the root at depth 0
int pos[MAXN];  // pos[u]: position of u in the segment tree (assigned by dfs2)
int bl[MAXN];   // bl[u]: topmost node of the chain that u belongs to
int son[MAXN];  // son[u]: heavy child of u, 0 when u has no child
int n;          // number of nodes, read by init()
int q;          // not used at file scope by the visible code; solve() reads the query count into its own local
int sz;         // number of segment-tree positions handed out so far by dfs2

// Segment-tree node covering the closed position range [l, r].
struct Seg
{
    int l;   // left end of the covered range
    int r;   // right end of the covered range
    int mx;  // maximum of the values stored at positions l..r
    int sum; // sum of the values stored at positions l..r
};
// Tree storage: the root is t[1] and node k has children 2k and 2k+1.
Seg t[MAXN * 4];

void addedge(int u, int v)
{
edge[tot].to = v;
edge[tot].next = head[u];
head[u] = tot++;
}

// Resets the adjacency lists and reads the tree from stdin:
// the node count n, then n-1 undirected edges, then the n node weights.
void init()
{
    tot = 0;
    memset(head, -1, sizeof(head)); // -1 marks an empty adjacency list

    scanf("%d",&n);

    int a, b;
    for (int edgeNo = 1; edgeNo < n; ++edgeNo)
    {
        scanf("%d %d",&a,&b);
        // Store both directions so the tree can be walked from either endpoint.
        addedge(a, b);
        addedge(b, a);
    }

    for (int node = 1; node <= n; ++node)
    {
        scanf("%d",&v[node]);
    }
}

// First decomposition pass over the subtree of u.
// Sets siz[] and son[] for every node visited, and dep[] / fa[] for every
// node below u (u's own dep and fa are left as the caller set them).
// u   - node being visited
// pre - neighbour the walk arrived from; it is skipped so the walk never climbs back
void dfs1(int u, int pre)
{
    siz[u] = 1;
    for (int e = head[u]; e != -1; e = edge[e].next)
    {
        const int child = edge[e].to;
        if (child != pre)
        {
            fa[child] = u;
            dep[child] = dep[u] + 1;
            dfs1(child, u);
            siz[u] += siz[child];

            // Remember the child with the largest subtree; 0 means none chosen yet.
            const int heavy = son[u];
            if (heavy == 0 || siz[heavy] < siz[child])
            {
                son[u] = child;
            }
        }
    }
}

// Second decomposition pass: numbers the nodes and records their chain tops.
// The heavy child is visited first, so the nodes of one chain receive
// consecutive positions.
// u   - node being visited
// top - topmost node of the chain that u belongs to
void dfs2(int u, int top)
{
    pos[u] = ++sz; // position of u in the segment tree
    bl[u] = top;   // remember the head of u's chain

    const int heavy = son[u];
    if (heavy != 0)
    {
        // Continue the current chain down the heavy edge.
        dfs2(heavy, top);

        // Every other child starts a new chain headed by itself.
        for (int e = head[u]; e != -1; e = edge[e].next)
        {
            const int w = edge[e].to;
            if (dep[w] <= dep[u] || w == heavy)
            {
                continue; // not a child of u, or the heavy child already handled
            }
            dfs2(w, w);
        }
    }
}

void build(int k, int l, int r)
{
t[k].l = l;
t[k].r = r;
if(l == r) return;
int mid = (l+r) >> 1;
build(k<<1, l, mid);
build(k<<1|1, mid+1, r);
}

void change(int k, int x, int y)
{
int l = t[k].l;
int r = t[k].r;
int mid = (l+r) >> 1;
if(l == r)
{
t[k].sum = t[k].mx = y;
return;
}
if(x <= mid) change(k<<1,x,y);
else change(k<<1|1,x,y);
t[k].sum = t[k<<1].sum + t[k<<1|1].sum;
t[k].mx = max(t[k<<1].mx,t[k<<1|1].mx);
}

int querymx(int k, int x, int y)
{
int l = t[k].l;
int r = t[k].r;
if(l == x && r == y) return t[k].mx;
int mid = (l+r) >> 1;
if(y <= mid) return querymx(k<<1,x,y);
else if(x > mid) return querymx(k<<1|1,x,y);
else return max(querymx(k<<1,x,mid),querymx(k<<1|1,mid+1,y));
}

// Returns the maximum node weight on the tree path between x and y.
int solvemx(int x, int y)
{
    int best = -inf;

    // Lift whichever endpoint has the deeper chain top until both share a chain.
    while (bl[x] != bl[y])
    {
        if (dep[bl[y]] > dep[bl[x]])
        {
            swap(x, y);
        }
        const int chainTop = bl[x];
        const int chainBest = querymx(1, pos[chainTop], pos[x]);
        if (best < chainBest)
        {
            best = chainBest;
        }
        x = fa[chainTop];
    }

    // Same chain: the remaining nodes occupy one contiguous position range.
    const bool ordered = pos[x] <= pos[y];
    const int lo = ordered ? pos[x] : pos[y];
    const int hi = ordered ? pos[y] : pos[x];
    const int lastBest = querymx(1, lo, hi);
    return best < lastBest ? lastBest : best;
}

int querysum(int k, int x, int y)
{
int l = t[k].l;
int r = t[k].r;
int mid = (l+r) >> 1;
if(l == x && y == r) return t[k].sum;
if(y <= mid) return querysum(k<<1,x,y);
else if(x > mid)return querysum(k<<1|1,x,y);
else return querysum(k<<1,x,mid)+querysum(k<<1|1,mid+1,y);
}

// Returns the sum of the node weights on the tree path between x and y.
int solvesum(int x, int y)
{
    int total = 0;

    // Lift whichever endpoint has the deeper chain top until both share a chain.
    while (bl[x] != bl[y])
    {
        if (dep[bl[y]] > dep[bl[x]])
        {
            swap(x, y);
        }
        const int chainTop = bl[x];
        total += querysum(1, pos[chainTop], pos[x]);
        x = fa[chainTop];
    }

    // Same chain: the remaining nodes occupy one contiguous position range.
    const bool ordered = pos[x] <= pos[y];
    const int lo = ordered ? pos[x] : pos[y];
    const int hi = ordered ? pos[y] : pos[x];
    return total + querysum(1, lo, hi);
}

// Builds the segment tree over positions 1..n, loads the initial node weights,
// then reads the query count and answers each query from stdin.
// Each query is "<command> <x> <y>":
//   command starts with 'C'      : set the weight of node x to y
//   second command letter is 'M' : print the maximum weight on the path x..y
//   anything else                : print the sum of the weights on the path x..y
// Processing stops early if the input ends or a query cannot be parsed.
void solve()
{
    build(1,1,n);
    for (int node = 1; node <= n; ++node)
        change(1, pos[node], v[node]);

    // Named differently from the file-scope q so nothing is shadowed.
    int queries = 0;
    if (scanf("%d", &queries) != 1)
        return;

    char cmd[10];
    int x = 0;
    int y = 0;
    while (queries-- > 0)
    {
        // %9s caps the token at cmd's capacity (9 characters plus the NUL),
        // so an over-long command can no longer overflow the buffer.
        if (scanf(" %9s %d %d", cmd, &x, &y) != 3)
            break;

        if (cmd[0] == 'C')
        {
            v[x] = y;
            change(1, pos[x], y);
        }
        else if (cmd[1] == 'M')
            printf("%d\n", solvemx(x, y));
        else
            printf("%d\n", solvesum(x, y));
    }
}

// Entry point: read the tree, run both decomposition passes from node 1,
// then answer the queries.
int main()
{
    const int root = 1;

    init();
    dfs1(root, -1);   // -1 matches no node id, so no neighbour of the root is skipped
    dfs2(root, root); // the root heads its own chain
    solve();

    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: