您的位置:首页 > 其它

BZOJ 1036: [ZJOI2008]树的统计Count

2015-09-22 22:07 399 查看

题意:

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 III. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

题解:

这是树链剖分的模板题,不过我树链剖分写挂了T_T,只有抄网上的版。。。

代码:

来源:http://coraon.com/zjoi-2008/

#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
#define MAXN 30001
#define INF 0x3f3f3f3f
#define lchild rt << 1, l, m
#define rchild rt << 1 | 1, m + 1, r
using namespace std;
int n, w[MAXN], mw[MAXN];
vector<int>e[MAXN];

class Segment_Tree{
private:
int sum[MAXN << 2], upper[MAXN << 2];
void push_up(int rt){
sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
upper[rt] = max(upper[rt << 1], upper[rt << 1 | 1]);
}
public:
void build(int rt = 1, int l = 1, int r = n){
if(l == r){ sum[rt] = upper[rt] = mw[l]; return; }
sum[rt] = 0; upper[rt] = -INF;
int m = (l + r) >> 1;
build(lchild); build(rchild);
push_up(rt);
}
void update(int P, int val, int rt = 1, int l = 1, int r = n){
if(l == r) { sum[rt] = upper[rt] = val; return; }
int m = (l + r) >> 1;
if(P <= m) update(P, val, lchild);
else update(P, val, rchild);
push_up(rt);
}
int query(int L, int R, bool opt, int rt = 1, int l = 1, int r = n){
if(L <= l && r <= R){
if(opt) return upper[rt];
else return sum[rt];
}
int m = (l + r) >> 1;
if(opt){
int lans = -INF, rans = -INF;
if(L <= m) lans = query(L, R, opt, lchild);
if(R > m) rans = query(L, R, opt, rchild);
return max(lans, rans);
}
else{
if(L > m) return query(L, R, opt, rchild);
else if(R <= m) return query(L, R, opt, lchild);
else return query(L, m, opt, lchild) + query(m + 1, R, opt, rchild);
}
}
};

class HLD: public Segment_Tree{
public:
int dep[MAXN], fa[MAXN], sz[MAXN];
int son[MAXN], top[MAXN], dfn[MAXN], dfs_clock;

void init(){
memset(dep, 0, sizeof(dep));
memset(son, 0, sizeof(son));
dep[1] = 1;
dfs_clock = 0;
}

void dfs1(int u){
sz[u] = 1;
for(int i = 0; i < e[u].size(); i++){
int v = e[u][i];
if(dep[v]) continue;
dep[v] = dep[u] + 1;
fa[v] = u;
dfs1(v);
sz[u] += sz[v];
if(sz[son[u]] < sz[v])
son[u] = v;
}
}

void dfs2(int u, int tp){
top[u] = tp;
dfn[u] = ++dfs_clock;
mw[dfn[u]] = w[u];
if(son[u]) dfs2(son[u], tp); //拉链
for(int i = 0; i < e[u].size(); i++){
int v = e[u][i];
if(v == fa[u] || v == son[u]) continue;
dfs2(v, v); //建链
}
}

int getsum(int u, int v){
int ans = 0;
while(top[u] != top[v]){ //一直爬直到在u, v同一条重链
if(dep[top[u]] > dep[top[v]]) swap(u, v);
ans += query(dfn[top[v]], dfn[v], 0);
v = fa[top[v]];
}
if(dep[u] > dep[v]) swap(u, v);
ans += query(dfn[u], dfn[v], 0); //属于同一条重链的时候直接区间询问
return ans;
}

int getmax(int u, int v){
int ans = -INF;
while(top[u] != top[v]){
if(dep[top[u]] > dep[top[v]]) swap(u, v);
ans = max(ans, query(dfn[top[v]], dfn[v], 1));
v = fa[top[v]];
}
if(dep[u] > dep[v]) swap(u, v);
ans = max(ans, query(dfn[u], dfn[v], 1));
return ans;
}
}hld;

int main(){
#ifdef _DEBUG
freopen("d:\\2008.txt", "r", stdin);
#endif
char opt[10];
int u, v, m;
while(scanf("%d", &n) != EOF){
for(int i = 1; i <= n; i++)
e[i].clear();
hld.init();
for(int i = 1; i < n; i++){
scanf("%d %d", &u, &v);
e[u].push_back(v);
e[v].push_back(u);
}
for(int i = 1; i <= n; i++)
scanf("%d", w + i);
hld.dfs1(1);
hld.dfs2(1, 1);
hld.build();
scanf("%d", &m);
for(int i = 0; i < m; i++){
scanf("%s %d %d", opt, &u, &v);
if(opt[0] == 'C')
hld.update(hld.dfn[u], v);
else if(opt[1] == 'M')
printf("%d\n", hld.getmax(u, v));
else
printf("%d\n", hld.getsum(u, v));
}
}
return 0;
}


内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: