您的位置:首页 > 其它

洛谷P3384【模板】树链剖分

2017-12-17 21:20 435 查看
这题是树链剖分模板……

还是考得比较全面

树链剖分解决的是什么问题呢?

我们都知道维护一个带修一维序列,可以用线段数或树状数组解决,将效率从n^2降到nlogn,但是对于树形结构的数据呢?然后树链剖分就出现了,它把树形结构剖成一条条链,在链上用数据结构维护

树链剖分的方法有:轻重链剖分,长短链剖分,血统剖分……最后一个欧洲人专属

在解决实际问题中更常见的是前者,但是长短链也有应用哦,比如这边走bzoj3252

然而轻重链剖分到底是什么呢,我们需要知道几个概念:

1.重儿子

2.重边

3.重链

4.LCA

对于树上除叶子节点以外的节点,都有至少一个儿子,在这些儿子中,我们要选出一个最‘重’的,重的定义,定义size(X)为以X为根的子树的节点个数,有点像陈启峰大佬的SBT中的size数组,令V为U的儿子节点中size值最大的节点,V就是重儿子,那么边(U,V)被称为重边,树中重边之外的边被称为轻边,很多重边连成的链叫重链,那么,一棵树均摊log层,总共有不超过log条重链,我们就把一棵树划分成了log条一维序列,再用上dfs序,就可以用一棵线段树维护了

剖分前:



剖分后



对于最近公共祖先LCA,就先展示一下连重边,重链,求LCA的代码吧:

void dfs1(int x)
{
dep[x] = dep[fa[x]] + 1;siz[x] = 1;
for(int i = head[x]; i; i = E[i].next)
{
int to = E[i].to;
if(fa[x] != to && !fa[to])
{
val[to] = E[i].len;
fa[to] = x;
dfs1(to);
siz[x] += siz[to];
if(siz[son[x]] < siz[to])son[x] = to;
}
}
}
inline void dfs2(int x)
{
if(x == son[fa[x]])top[x] = top
b406
[fa[x]];
else top[x] = x;
if(son[x]) dfs2(son[x]);
for(int i = head[x]; i; i = E[i].next)if(fa[E[i].to] == x && E[i].to != son[x])dfs2(E[i].to);
}
int query(int x,int y)
{
for(; top[x] != top[y]; dep[top[x]] > dep[top[y]] ? x=fa[top[x]]:y=fa[top[y]]);
return dep[x] < dep[y]?x:y;
}


然后,对于这道题的修改,我们只需要在求LCA的过程中将对应的链在dfs序中的对应序列维护一下就好啦

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 500005;
inline int read() {
int ch,  x = 0,  f = 1;ch = getchar();
while((ch < '0' || ch > '9') && ch != '-') ch = getchar();
ch == '-' ? f = -1,  ch = getchar() : 0;
while(ch >= '0' && ch <= '9') {
x = x * 10 + ch - '0';
ch = getchar();
}
return f * x;
}
int A[maxn], B[maxn];
int n, m, sta, mod;
#define Ls(x) (x << 1)
#define Rs(x) (x << 1 | 1)
struct SegmentTree{
int l, r, len;
long long tag;
long long mult;
long long val;
SegmentTree()
{
tag = 0;val = 0;mult = 1;
l = r = len = 0;
}
}tre[maxn << 2];
void build(int rt, int l, int r)
{
tre[rt].l = l, tre[rt].r = r, tre[rt].len = r - l + 1;
if(l == r) {
tre[rt].val = B[l];
return ;
}
int mid = (l + r) >> 1;
build(Ls(rt), l, mid);
build(Rs(rt), mid + 1, r);
tre[rt].val = (tre[Ls(rt)].val + tre[Rs(rt)].val) % mod;
}
void push_down(int rt) {
//tre[Ls(rt)].mult =  tre[Ls(rt)].mult * tre[rt].mult % mod;
tre[Ls(rt)].tag = (tre[Ls(rt)].tag + tre[rt].tag) % mod;
tre[Ls(rt)].val  = (tre[Ls(rt)].val + tre[rt].tag * (tre[Ls(rt)].len)) % mod;
//tre[Rs(rt)].mult =  tre[Rs(rt)].mult * tre[rt].mult % mod;
tre[Rs(rt)].tag = (tre[Rs(rt)].tag + tre[rt].tag) % mod;
tre[Rs(rt)].val  = (tre[Rs(rt)].val + tre[rt].tag * (tre[Rs(rt)].len)) % mod;
tre[rt].tag=0;
}
void update(int rt, int L, int R, long long delta) {
int l = tre[rt].l, r = tre[rt].r;
int mid = (l + r) >> 1;
if(l >= L && r <= R) {

tre[rt].tag = (tre[rt].tag + delta) % mod;
tre[rt].val = (tre[rt].val + delta * tre[rt].len) % mod;

return ;
}
if(tre[rt].mult != 1 || tre[rt].tag ) push_down(rt);
if(L <= mid) update(Ls(rt), L, R, delta);
if(R > mid)  update(Rs(rt), L, R, delta);
tre[rt].val = (tre[Ls(rt)].val + tre[Rs(rt)].val) % mod;
return ;
}
long long query(int rt, int L, int R) {
int l = tre[rt].l, r = tre[rt].r;int mid = (l + r) >> 1;
if(l >= L && r <= R) {
return tre[rt].val;
}
if(tre[rt].mult != 1 || tre[rt].tag ) push_down(rt);
long long res = 0;
if(L <= mid) res += query(Ls(rt), L, R);
if(mid < R) res += query(Rs(rt), L, R);
return res % mod;

}

struct Edge {
int to, len, nxt;
Edge() {}
Edge(int _to, int _len, int _nxt):to(_to), len(_len), nxt(_nxt) {}
}E[maxn << 1];
int h[maxn], cnte, tot, cnt;
int val[maxn];
int a[maxn];
int Lr[maxn], Rr[maxn];
int dep[maxn], fa[maxn], sz[maxn], top[maxn], son[maxn],ID[maxn];
inline void add_edge(int u, int v, int w) {
E[++cnte] = Edge(v, w, h[u]), h[u] = cnte;
E[++cnte] = Edge(u, w, h[v]), h[v] = cnte;
}
void dfs1(int x) {

sz[x] = 1; dep[x] = dep[fa[x]] + 1;
for(int i = h[x]; i; i = E[i].nxt) {
int to = E[i].to;
if(to == fa[x]) continue;
fa[to] = x;val[x] = E[i].len;
dfs1(to);
sz[x] += sz[to];
if(sz[to] > sz[son[x]]) son[x] = to;
}

}
void dfs2(int x) {
Lr[x] = ++tot;
B[Lr[x]] = a[x];

if(x == son[fa[x]]) top[x] = top[fa[x]];
else top[x] = x;
if(son[x]) dfs2(son[x]);
for(int i = h[x]; i; i = E[i].nxt) {
int to = E[i].to;
if(to == fa[x] || to == son[x]) continue;
dfs2(to);
}
Rr[x] = tot;
}

void up(int a, int b, int c) {
int f1 = top[a], f2 = top[b];
while(f1 != f2) {
if(dep[f1] < dep[f2]) { swap(a, b); swap(f1, f2); }
update(1, Lr[f1], Lr[a], c);
a = fa[f1];
f1 = top[a];
}
if(dep[a] > dep[b]) swap(a, b);
update(1, Lr[a], Lr[b], c);
}

int qsum(int x, int y) {
int ans=0;
while(top[x] != top[y])
{
if(dep[top[x]] < dep[top[y]])swap(x, y);
ans += query(1,Lr[top[x]], Lr[x]);
x = fa[top[x]];
}
if(dep[x]<dep[y])swap(x, y);
ans+=query(1,Lr[y],Lr[x]);
return ans;
}

signed main() {
n = read(), m = read(), sta = read(), mod = read();
for(int i = 1; i <= n; i++) a[i] = read();
for(int i = 1; i < n; i++) add_edge(read(), read(), 0);
dfs1(sta);
dfs2(sta);
int x, y, z, opt;
build(1, 1, tot);
while(m--) {
opt = read();
if(opt == 1) {
x = read(), y = read(), z = read();
up(x, y, z);
}
else if(opt == 2) {
x = read(); y = read();
printf("%d\n", qsum(x, y) % mod);
}
else if(opt == 3) {
x = read(); z = read();
update(1, Lr[x], Rr[x], z);
}
else {
x = read();
printf("%d\n", query(1, Lr[x], Rr[x]) % mod);
}
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: