您的位置:首页 > 其它

线段树模板

2016-05-21 09:44 357 查看
基本用法:

区间求和,求最大值,求最小值

#include<algorithm>
#include<iostream>
#include<stdio.h>
#include<cstring>

using namespace std;

const int inf = 0x3f3f3f3f;
const int length = 10010;
int MAX[length << 2];
int MIN[length << 2];
int SUM[length << 2];

void up(int p)  //由左右子树向上调整
{
MAX[p] = max(MAX[p << 1], MAX[p << 1 | 1]);
MIN[p] = min(MIN[p << 1], MIN[p << 1 | 1]);
SUM[p] = SUM[p << 1] + SUM[p << 1 | 1];
}
void build(int l, int r, int p) //构建树
{
if (l == r)
{
scanf("%d", &MAX[p]);
MIN[p] = MAX[p];
SUM[p] = MAX[p];
}
else
{
int mid = (l + r) >> 1;
build(l, mid, p << 1);
build(mid + 1, r, p << 1 | 1);
up(p);
}
}

void Replace(int pos, int val, int l, int r, int tr)//在 pos 位置将值改为val ,【l,r】表示子树边界,tr表示子树根节点
{
if (l == r)
{
MAX= val;
MIN= val;
SUM= val;
return;
}
int mid = (l + r) >> 1;
if (mid >= pos)
{
Replace(pos, val, l, mid, tr << 1);
}
else
{
Replace(pos, val, mid + 1, r, tr << 1 | 1);
}
up(tr);
}
void add(int pos, int val, int l, int r, int tr)//在 pos 位置加上 val ,,【l,r】表示子树边界,tr表示子树根节点
{
if (l == r)
{
MAX+= val;
MIN+= val;
SUM+= val;
return;
}
int mid = (l + r) >> 1;
if (mid >= pos)
{
add(pos, val, l, mid, tr << 1);
}
else
{
add(pos, val, mid + 1, r, tr << 1 | 1);
}
up(tr);
}
int quemax(int L, int R, int l, int r, int tr)//在区间[L,R]上找出最大值,【l,r】表示子树边界,tr表示子树根节点
{
if (R<l || L>r)
return -inf;
if (L <= l&&r <= R)
{
return MAX;
}
int res =  quemax(L, R, l, (l+r)>>1, tr << 1);
int ans =  quemax(L, R, 1+((l+r)>>1), r, tr << 1 | 1);
return res>ans?res:ans;
}
int quemin(int L, int R, int l, int r, int tr)//在区间[L,R]上找出最小值,【l,r】表示子树边界,tr表示子树根节点
{
if (R<l || L>r)
return inf;
if (L <= l&&r <= R)
{
return MIN;
}
int res = quemin(L, R, l, (l+r)>>1, tr << 1);
int ans = quemin(L, R, 1+((l+r)>>1), r, tr << 1 | 1);
return res<ans?res:ans;
}
int quesum(int L, int R, int l, int r, int tr)//在区间[L,R]上求和,【l,r】表示子树边界,tr表示子树根节点
{
if (R<l || L>r)
return 0;
if (L <= l&&r <= R)
{
return SUM;
}
int res = quesum(L, R, l, (l+r)>>1, tr << 1);
int ans = quesum(L, R, 1+((l+r)>>1), r, tr << 1 | 1);
return res+ans;
}
//void print(int l,int r,int rt)
//{
//	if (l == r)
//	{
//		cout << MAX[rt] << ' ';
//	}
//	else
//	{
//		int mid = (l + r) >> 1;
//		print(l, mid, rt << 1);
//		print(mid + 1, r, rt << 1 | 1);
//	}
//}
int main()
{
int n, m;
while (scanf("%d%d", &n, &m) != EOF)
{
build(1, n, 1);
//		print(1,n,1);
//		cout << endl;
while (m--)
{
char op[2];
int a, b;
scanf("%s%d%d", op, &a, &b);
if (op[0] == 'Q') //区间求最大
{
/* for(int i = 1;i<=10;i++)
printf("%d ",MAX[i]);
puts("");*/
printf("%d\n", quemax(a, b, 1, n, 1));
//				print(1,n,1);cout<<endl;
}
else if (op[0] == 'U') //单点替换
{
Replace(a, b, 1, n, 1);
//				print(1,n,1);cout<<endl;
}
else if (op[0] == 'M')//区间求最小
{
/*for(int i = 1;i<=10;i++)
printf("%d ",MIN[i]);
puts("");*/
printf("%d\n", quemin(a, b, 1, n, 1));
//				print(1,n,1);cout<<endl;
}
else if (op[0] == 'H')//区间求和
{
printf("%d\n", quesum(a, b, 1, n, 1));
//				print(1,n,1);cout<<endl;
}
else if (op[0] == 'S') //单点增加
{
add(a, b, 1, n, 1);
//				print(1,n,1);cout<<endl;
}
else if (op[0] == 'E')//单点减少
{
add(a, -b, 1, n, 1);
//				print(1,n,1);cout<<endl;
}
}
}
return 0;
}


区间替换:

这里有一点需要注意:向下更新 lazy 数组有两种情况,1.update 一个区间没有一次性更新完。2.query 求和时,一个区间不会正好等于所求区间。

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>

#define LL long long
const int maxn = 100100;

using namespace std;

int lazy[maxn << 2];
int sum[maxn << 2];
void PushUp(int rt)//由左孩子、右孩子向上更新父节点
{
sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
}
void PushDown(int rt, int m) //向下更新
{
if (lazy[rt]) //懒惰标记
{
lazy[rt << 1] = lazy[rt << 1 | 1] = lazy[rt];
sum[rt << 1] = (m - (m >> 1)) * lazy[rt];
sum[rt << 1 | 1] = ((m >> 1)) * lazy[rt];
lazy[rt] = 0;
}
}
void build(int l, int r, int rt)//建树
{
lazy[rt] = 0;

if (l == r)
{
scanf("%d", &sum[rt]);
return;
}
int m = (l + r) >> 1;
build(l , m , rt << 1);
build(m+1, r , rt<<1|1);
PushUp(rt);
}
void update(int L, int R, int c, int l, int r, int rt)//更新
{
//if(L>l||R>r) return;
if (L <= l && r <= R)
{
lazy[rt] = c;
sum[rt] = c * (r - l + 1);
//printf("%d %d %d %d %d\n", rt, sum[rt], c, l, r);
return;
}
PushDown(rt, r - l + 1);
int m = (l + r) >> 1;
if (L <= m) update(L, R, c, l , m , rt << 1);
if (R > m) update(L, R, c, m+1, r , rt<<1|1);
PushUp(rt);
}

LL query(int L, int R, int l, int r, int rt)//在【L,R】区间求和,l 和 r 为左右边界,rt 为根
{
if (L <= l && r <= R)
{
//printf("%d\n", sum[rt]);
return sum[rt];
}
PushDown(rt, r - l + 1);
int m = (l + r) >> 1;
LL ret = 0;
if (L <= m) ret += query(L, R, l , m , rt << 1);
if (m < R) ret += query(L, R, m+1, r , rt<<1|1);
return ret;
}
int main()
{
int  n, m;
char str[5];
while (scanf("%d%d", &n, &m))
{
build(1, n, 1);
while (m--)
{
scanf("%s", str);
int a, b, c;
if (str[0] == 'T')
{
scanf("%d%d%d", &a, &b, &c);
update(a, b, c, 1, n, 1);
}
else if (str[0] == 'Q')
{
scanf("%d%d", &a, &b);
cout << query(a, b, 1, n, 1) << endl;
}
}
}
return 0;
}


区间增减:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>

#define max(a,b) (a>b)?a:b
#define min(a,b) (a>b)?b:a
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
#define LL __int64
const int maxn = 100100;
using namespace std;
LL lazy[maxn<<2];
LL sum[maxn<<2];

void putup(int rt)
{
sum[rt] = sum[rt<<1] + sum[rt<<1|1];
}
void putdown(int rt,int m)
{
if (lazy[rt])
{
lazy[rt<<1] += lazy[rt];
lazy[rt<<1|1] += lazy[rt];
sum[rt<<1] += lazy[rt] * (m - (m >> 1));
sum[rt<<1|1] += lazy[rt] * (m >> 1);
lazy[rt] = 0;
}
}
void build(int l,int r,int rt) {
lazy[rt] = 0;
if (l == r)
{
scanf("%I64d",&sum[rt]);
return ;
}
int m = (l + r) >> 1;
build(lson);
build(rson);
putup(rt);
}
void update(int L,int R,int c,int l,int r,int rt)
{
if (L <= l && r <= R)
{
lazy[rt] += c;
sum[rt] += (LL)c * (r - l + 1);
return ;
}
putdown(rt , r - l + 1);
int m = (l + r) >> 1;
if (L <= m) update(L , R , c , lson);
if (m < R) update(L , R , c , rson);
putup(rt);
}
LL query(int L,int R,int l,int r,int rt)
{
if (L <= l && r <= R)
{
return sum[rt];
}
putdown(rt , r - l + 1);
int m = (l + r) >> 1;
LL ret = 0;
if (L <= m) ret += query(L , R , lson);
if (m < R) ret += query(L , R , rson);
return ret;
}
int main()
{
int n , m;int a , b , c;
char str[5];
scanf("%d%d",&n,&m);
build(1 , n , 1);
while (m--)
{

scanf("%s",str);
if (str[0] == 'Q')
{
scanf("%d%d",&a,&b);
printf("%I64d\n",query(a , b , 1 , n , 1));
}
else if(str[0]=='C')
{
scanf("%d%d%d",&a,&b,&c);
update(a , b , c , 1 , n , 1);
}
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: