您的位置:首页 > 其它

【树链剖分】 HDOJ 4718 The LCIS on the Tree

2014-11-16 11:01 344 查看
树链剖分，线段树区间合并……比较难调试……

#include <iostream>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <bitset>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <climits>
#include <cstdlib>
#include <cmath>
#include <time.h>
// Problem-size limits: maxn bounds node ids (arrays are 1-indexed),
// maxm bounds the number of edge records in the E[] pool.
#define maxn 100005
#define maxm 200005
// Generic template constants; only `mod` is referenced elsewhere in this file.
#define eps 1e-10
#define mod 1000000007
#define INF 0x3f3f3f3f
#define PI (acos(-1.0))
// Lowest set bit of x. The argument and the whole body are parenthesized so
// that expressions such as lowbit(a + b) expand correctly.
#define lowbit(x) ((x)&(-(x)))
#define mp make_pair
// Segment-tree helpers. They refer to a variable named `o` (current node) and,
// for lson/rson, to `L`, `R` and `mid` at the point of use. ls/rs are
// parenthesized so they stay a single operand inside larger expressions.
#define ls (o<<1)
#define rs (o<<1 | 1)
#define lson o<<1, L, mid
#define rson o<<1 | 1, mid+1, R
//#pragma comment(linker, "/STACK:16777216")
typedef long long LL;
typedef unsigned long long ULL;
//typedef int LL;
using namespace std;
// Integer power a^b by binary exponentiation, without any modulus and
// without overflow checks. Intended for b >= 0.
LL qpow(LL a, LL b)
{
    LL acc = 1;
    for (LL sq = a; b; b /= 2, sq *= sq)
        if (b % 2) acc *= sq;
    return acc;
}
// Modular power: a^b modulo `mod`, intended for b >= 0; result in [0, mod).
// The base is first reduced into [0, mod) so that base*base and res*base stay
// below (mod-1)^2 and cannot overflow LL, even for large or negative a.
LL powmod(LL a, LL b)
{
    LL res = 1, base = ((a % mod) + mod) % mod;
    while (b) {
        if (b % 2) res = res * base % mod;
        base = base * base % mod;
        b /= 2;
    }
    return res;
}
// Lightweight reader for one int from stdin; overloads the name scanf by
// signature. NOTE(review): no call in this file uses this overload, since all
// calls pass a format string.
// Skips leading spaces, tabs, carriage returns and newlines, accepts an
// optional leading '-', then consumes decimal digits. If no digit follows,
// x is 0. The first character after the number is consumed.
void scanf(int &x)
{
    x = 0;
    int ch = getchar();  // int, not char, so EOF stays distinguishable
    while (ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t') ch = getchar();
    bool negative = false;
    if (ch == '-') {
        negative = true;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if (negative) x = -x;
}
// Greatest common divisor by Euclid's algorithm (iterative form).
// gcd(p, 0) == p; signs of negative inputs follow C++'s % semantics.
LL gcd(LL p, LL q)
{
    while (q) {
        LL rem = p % q;
        p = q;
        q = rem;
    }
    return p;
}
// head

// Singly linked adjacency-list record. E[] is a fixed pool of records,
// H[u] is the head of node u's list (null when empty, as reset by init()),
// and `edges` points at the next unused slot in E[] (rewound by init()).
struct Edge
{
int v;          // endpoint stored by addedges(u, v); read() passes the child as v
Edge *next;     // next record in the same node's list
}E[maxm], *H[maxn], *edges;

// Segment-tree arrays over positions 1..n (4 * maxn node capacity). For node o:
//   segmax/lmax/rmax: longest strictly increasing run inside the node's range,
//                     as a prefix of the range, and as a suffix of it;
//   segmin/lmin/rmin: the same for strictly decreasing runs;
//   maxnum/minnum:    start position of the run recorded in segmax/segmin
//                     (written by build/pushup; not read elsewhere in this file).
int segmax[maxn << 2];
int maxnum[maxn << 2];
int segmin[maxn << 2];
int minnum[maxn << 2];
int lmin[maxn << 2];
int rmin[maxn << 2];
int lmax[maxn << 2];
int rmax[maxn << 2];
// Heavy-light decomposition data, indexed by node id.
// NOTE(review): with `using namespace std;`, unqualified uses of `size` can be
// ambiguous with std::size under C++17 and later; qualify them as ::size.
int size[maxn];   // subtree size (set by dfs1)
int son[maxn];    // heavy child, 0 if none (set by dfs1)
int dep[maxn];    // depth below the root (the root's entry is never written)
int top[maxn];    // head node of the heavy chain containing the node (dfs2)
int fa[maxn];     // parent (dfs1; work() sets fa[1] = 1)
int w[maxn];      // position of the node in segment-tree order (dfs2)
int val[maxn], val1[maxn];   // val: values by position; val1: values by node id
// z: position counter for dfs2; n/m: node and query counts;
// ok/res/last: scratch state shared by the query_l*/query_r* walks.
int z, n, m, ok, res, last;

void addedges(int u, int v)
{
edges->v = v;
edges->next = H[u];
H[u] = edges++;
}

void dfs1(int u)
{
size[u] = 1, son[u] = 0;
for(Edge *e = H[u]; e; e = e->next) {
dep[e->v] = dep[u] + 1;
fa[e->v] = u;
dfs1(e->v);
size[u] += size[e->v];
if(size[son[u]] < size[e->v]) son[u] = e->v;
}
}

void dfs2(int u, int tp)
{
w[u] = ++z, top[u] = tp;
if(son[u]) dfs2(son[u], tp);
for(Edge *e = H[u]; e; e = e->next)
if(e->v != son[u]) dfs2(e->v, e->v);
}

// Reset per-test-case state before read(): clear every adjacency-list head
// and heavy-child entry, rewind the edge pool, and restart the position counter.
void init(void)
{
    memset(H, 0, sizeof(H));
    memset(son, 0, sizeof(son));
    edges = &E[0];
    z = 0;
}

// Read one test case: n, then the n node values into val1[1..n], then for
// each node 2..n the id of its parent, recording the edge parent -> node.
void read(void)
{
    scanf("%d", &n);
    for (int id = 1; id <= n; ++id)
        scanf("%d", &val1[id]);
    for (int child = 2; child <= n; ++child) {
        int parent;
        scanf("%d", &parent);
        addedges(parent, child);
    }
}

// Merge one run family of children 2o and 2o+1 into node o, which covers [L, R].
// `joins` says whether the boundary pair (val[mid], val[mid+1]) continues a run
// in that family's direction. pre/suf/best/start are that family's arrays:
// prefix-run length, suffix-run length, best-run length and best-run start.
static void merge_monotone_runs(int o, int L, int R, bool joins,
                                int pre[], int suf[], int best[], int start[])
{
    const int mid = (L + R) >> 1;
    const int lc = o << 1, rc = o << 1 | 1;

    // A prefix spills into the right child only if the left child is one whole run.
    pre[o] = (pre[lc] == mid - L + 1 && joins) ? pre[lc] + pre[rc] : pre[lc];
    // A suffix spills into the left child only if the right child is one whole run.
    suf[o] = (suf[rc] == R - mid && joins) ? suf[lc] + suf[rc] : suf[rc];

    // Best run: the left child's, the right child's, or one crossing mid.
    // start[o] moves only on a strict improvement, so ties keep the earlier candidate.
    int bestLen = 0;
    if (best[lc] > bestLen) bestLen = best[lc], start[o] = start[lc];
    if (best[rc] > bestLen) bestLen = best[rc], start[o] = start[rc];
    const int crossLen = suf[lc] + pre[rc];
    if (joins && crossLen > bestLen) bestLen = crossLen, start[o] = mid - suf[lc] + 1;
    best[o] = bestLen;
}

// Recompute node o (covering [L, R]) from its two children, for both the
// strictly increasing (max*) and strictly decreasing (min*) run families.
void pushup(int o, int L, int R)
{
    const int mid = (L + R) >> 1;
    merge_monotone_runs(o, L, R, val[mid] < val[mid + 1], lmax, rmax, segmax, maxnum);
    merge_monotone_runs(o, L, R, val[mid] > val[mid + 1], lmin, rmin, segmin, minnum);
}

// Build segment-tree node o over positions [L, R]. Internal nodes are
// combined by pushup. A leaf is a run of length 1 in both directions, starting at L.
void build(int o, int L, int R)
{
    if (L != R) {
        const int mid = (L + R) >> 1;
        build(o << 1, L, mid);
        build(o << 1 | 1, mid + 1, R);
        pushup(o, L, R);
    } else {
        maxnum[o] = minnum[o] = L;
        lmax[o] = rmax[o] = segmax[o] = 1;
        lmin[o] = rmin[o] = segmin[o] = 1;
    }
}

// Longest strictly increasing run of val[] within positions [ql, qr], searched
// from node o covering [L, R]. Presumably called with 1 <= ql <= qr <= n at the top level.
int query_max(int o, int L, int R, int ql, int qr)
{
    if (ql <= L && qr >= R) return segmax[o];
    const int mid = (L + R) >> 1;
    if (ql <= mid && qr > mid) {
        // The query straddles mid: best of each side, or a run crossing mid
        // clipped to [ql, qr].
        int best = max(query_max(o << 1, L, mid, ql, qr),
                       query_max(o << 1 | 1, mid + 1, R, ql, qr));
        if (val[mid] < val[mid + 1]) {
            const int from = max(ql, mid - rmax[o << 1] + 1);
            const int to = min(qr, mid + lmax[o << 1 | 1]);
            best = max(best, to - from + 1);
        }
        return best;
    }
    return ql <= mid ? query_max(o << 1, L, mid, ql, qr)
                     : query_max(o << 1 | 1, mid + 1, R, ql, qr);
}

// Longest strictly decreasing run of val[] within positions [ql, qr], searched
// from node o covering [L, R]; the mirror image of query_max.
int query_min(int o, int L, int R, int ql, int qr)
{
    if (ql <= L && qr >= R) return segmin[o];
    const int mid = (L + R) >> 1;
    if (ql <= mid && qr > mid) {
        // The query straddles mid: best of each side, or a run crossing mid
        // clipped to [ql, qr].
        int best = max(query_min(o << 1, L, mid, ql, qr),
                       query_min(o << 1 | 1, mid + 1, R, ql, qr));
        if (val[mid] > val[mid + 1]) {
            const int from = max(ql, mid - rmin[o << 1] + 1);
            const int to = min(qr, mid + lmin[o << 1 | 1]);
            best = max(best, to - from + 1);
        }
        return best;
    }
    return ql <= mid ? query_min(o << 1, L, mid, ql, qr)
                     : query_min(o << 1 | 1, mid + 1, R, ql, qr);
}

// Extend a strictly increasing run left-to-right across [ql, qr], using the
// globals:
//   last: value just before the next covered piece;
//   res:  accumulated run length;
//   ok:   set to 1 once the run breaks, after which every call returns at once.
// A fully covered node contributes its increasing prefix when it starts above last.
void query_lmax(int o, int L, int R, int ql, int qr)
{
    if (ok) return;
    if (!(ql <= L && qr >= R)) {
        const int mid = (L + R) >> 1;
        if (ql <= mid) query_lmax(o << 1, L, mid, ql, qr);
        if (qr > mid) query_lmax(o << 1 | 1, mid + 1, R, ql, qr);
        return;
    }
    if (last >= val[L]) {
        ok = 1;
        return;
    }
    res += lmax[o];
    last = val[R];
    if (lmax[o] != R - L + 1) ok = 1;
}

// Decreasing counterpart of query_lmax. It extends the run in res
// left-to-right across [ql, qr] while each fully covered node starts below
// `last`. It sets ok = 1 once the run breaks.
void query_lmin(int o, int L, int R, int ql, int qr)
{
    if (ok) return;
    if (!(ql <= L && qr >= R)) {
        const int mid = (L + R) >> 1;
        if (ql <= mid) query_lmin(o << 1, L, mid, ql, qr);
        if (qr > mid) query_lmin(o << 1 | 1, mid + 1, R, ql, qr);
        return;
    }
    if (last <= val[L]) {
        ok = 1;
        return;
    }
    res += lmin[o];
    last = val[R];
    if (lmin[o] != R - L + 1) ok = 1;
}

// Walk [ql, qr] right-to-left, accumulating in res the length of the strictly
// increasing run that ends at position qr. A fully covered node contributes its
// increasing suffix when its last value is below `last`; `last` then becomes
// the node's first value. It sets ok = 1 once the run breaks.
void query_rmax(int o, int L, int R, int ql, int qr)
{
    if (ok) return;
    if (ql <= L && qr >= R) {
        if (last <= val[R]) {
            ok = 1;
            return;
        }
        res += rmax[o];
        last = val[L];
        if (rmax[o] != R - L + 1) ok = 1;
        return;
    }
    const int mid = (L + R) >> 1;
    if (qr > mid) query_rmax(o << 1 | 1, mid + 1, R, ql, qr);
    if (ql <= mid) query_rmax(o << 1, L, mid, ql, qr);
}

// Decreasing counterpart of query_rmax. It walks [ql, qr] right-to-left,
// accumulating in res the strictly decreasing run that ends at position qr.
// It sets ok = 1 once the run breaks.
void query_rmin(int o, int L, int R, int ql, int qr)
{
    if (ok) return;
    if (ql <= L && qr >= R) {
        if (last >= val[R]) {
            ok = 1;
            return;
        }
        res += rmin[o];
        last = val[L];
        if (rmin[o] != R - L + 1) ok = 1;
        return;
    }
    const int mid = (L + R) >> 1;
    if (qr > mid) query_rmin(o << 1 | 1, mid + 1, R, ql, qr);
    if (ql <= mid) query_rmin(o << 1, L, mid, ql, qr);
}

// Length of the longest strictly increasing run of node values along the
// tree path from a to b, with values read in the a -> b direction.
//
// Both endpoints climb heavy chains until they share one. Positions w[] grow
// from a chain's head downward, so:
//   - the a -> LCA half is walked against position order; strictly
//     decreasing runs of val[] (the *min family) are increasing along the path;
//   - the LCA -> b half is walked with position order (the *max family).
// State carried between chains:
//   prea: length of the increasing path-run ending at lasta, the head of the
//         chain the a side last left (0 until the a side climbs);
//   preb: length of the increasing path-run starting at lastb, the head of
//         the chain the b side last left (0 until the b side climbs).
// Clobbers the globals ok, res and last used by the query_l*/query_r* walks.
// Returns 1 when a == b (a single-node path).
int solve(int a, int b)
{
    // Without this guard the join logic below would report 0 for a == b.
    if (a == b) return 1;

    int f1 = top[a], f2 = top[b];
    int ans = 0, lasta = a, lastb = b;
    int prea = 0, preb = 0;
    int lminv = 0, rminv = 0, lmaxv = 0, rmaxv = 0;
    while (f1 != f2) {
        if (dep[f1] < dep[f2]) {
            // b side: chain piece [w[f2], w[b]] lies on the LCA -> b half.
            ans = max(ans, query_max(1, 1, n, w[f2], w[b]));
            ok = res = 0, last = val[w[b]] + 1;
            query_rmax(1, 1, n, w[f2], w[b]);
            rmaxv = res;    // increasing run ending at b
            ok = res = 0, last = val[w[f2]] - 1;
            query_lmax(1, 1, n, w[f2], w[b]);
            lmaxv = res;    // increasing run starting at f2
            // Join this piece with the run that starts at lastb below it.
            if (val[w[lastb]] > val[w[b]]) ans = max(ans, preb + rmaxv);
            if (rmaxv == w[b] - w[f2] + 1 && val[w[b]] < val[w[lastb]]) preb = rmaxv + preb;
            else preb = lmaxv;
            lastb = f2;
            b = fa[f2], f2 = top[b];
        }
        else {
            // a side: chain piece [w[f1], w[a]] lies on the a -> LCA half.
            ans = max(ans, query_min(1, 1, n, w[f1], w[a]));
            ok = res = 0, last = val[w[a]] - 1;
            query_rmin(1, 1, n, w[f1], w[a]);
            rminv = res;    // decreasing (in position order) run ending at a
            ok = res = 0, last = val[w[f1]] + 1;
            query_lmin(1, 1, n, w[f1], w[a]);
            lminv = res;    // decreasing (in position order) run starting at f1
            // Join this piece with the run that ends at lasta below it.
            if (val[w[lasta]] < val[w[a]]) ans = max(ans, prea + rminv);
            if (rminv == w[a] - w[f1] + 1 && val[w[a]] > val[w[lasta]]) prea = rminv + prea;
            else prea = lminv;
            lasta = f1;
            a = fa[f1], f1 = top[a];
        }
    }

    if (a == b) {
        // The climbs met on a single node, which is the LCA.
        if (lasta == a) {
            if (prea == 0) prea = 1;
            if (val[w[a]] < val[w[lastb]]) ans = max(ans, prea + preb);
        }
        else if (lastb == b) {
            if (preb == 0) preb = 1;
            if (val[w[b]] > val[w[lasta]]) ans = max(ans, prea + preb);
        }
        else {
            if (val[w[a]] > val[w[lasta]]) ans = max(ans, prea + 1);
            if (val[w[a]] < val[w[lastb]]) ans = max(ans, preb + 1);
            if (val[w[a]] > val[w[lasta]] && val[w[a]] < val[w[lastb]]) ans = max(ans, prea + preb + 1);
        }
        return ans;
    }

    // Same chain, a != b: the shallower endpoint is the LCA; finish the deeper side.
    if (dep[a] < dep[b]) {
        ans = max(ans, query_max(1, 1, n, w[a], w[b]));
        ok = res = 0, last = val[w[a]] - 1;
        query_lmax(1, 1, n, w[a], w[b]);
        lmaxv = res;
        ok = res = 0, last = val[w[b]] + 1;
        query_rmax(1, 1, n, w[a], w[b]);
        rmaxv = res;
        if (val[w[lastb]] > val[w[b]]) ans = max(ans, preb + rmaxv);
        if (rmaxv == w[b] - w[a] + 1 && val[w[b]] < val[w[lastb]]) preb = rmaxv + preb;
        else preb = lmaxv;
        lastb = a;
        b = a;
    }
    else {
        ans = max(ans, query_min(1, 1, n, w[b], w[a]));
        ok = res = 0, last = val[w[b]] + 1;
        query_lmin(1, 1, n, w[b], w[a]);
        lminv = res;
        ok = res = 0, last = val[w[a]] - 1;
        query_rmin(1, 1, n, w[b], w[a]);
        rminv = res;
        if (val[w[lasta]] < val[w[a]]) ans = max(ans, prea + rminv);
        if (rminv == w[a] - w[b] + 1 && val[w[a]] > val[w[lasta]]) prea = rminv + prea;
        else prea = lminv;
        lasta = b;
        a = b;
    }

    // Join the two halves at the LCA (now a == b).
    if (a == lasta) {
        if (preb == 0) preb = 1;
        if (val[w[lastb]] > val[w[b]]) ans = max(ans, prea + preb);
    }
    else if (b == lastb) {
        if (prea == 0) prea = 1;
        if (val[w[lasta]] < val[w[a]]) ans = max(ans, prea + preb);
    }
    return ans;
}
// Process one test case after read():
//   1. run both heavy-light passes from root 1;
//   2. lay node values out in segment-tree order and build the tree;
//   3. answer m queries "a b", one answer per line (1 when a == b).
void work(void)
{
    fa[1] = 1;
    dfs1(1);
    dfs2(1, 1);
    for (int node = 1; node <= n; ++node)
        val[w[node]] = val1[node];
    build(1, 1, n);
    scanf("%d", &m);
    while (m--) {
        int a, b;
        scanf("%d%d", &a, &b);
        printf("%d\n", a == b ? 1 : solve(a, b));
    }
}

// Driver. Input may hold several blocks, each a test count T followed by T
// test cases. Each case prints "Case #k:" (k restarts at 1 in every block)
// followed by its answers. Consecutive cases in a block are separated by a
// blank line.
int main(void)
{
    int remaining, caseNo;
    while (scanf("%d", &remaining) != EOF) {
        for (caseNo = 1; remaining--; ++caseNo) {
            init();
            read();
            printf("Case #%d:\n", caseNo);
            work();
            if (remaining) printf("\n");
        }
    }

    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: