您的位置:首页 > 其它

BZOJ 4919 大根堆(LIS)

2018-03-29 16:59 295 查看
题目链接:BZOJ 4919

题目大意:一棵n个节点的数,从中选择尽可能多的节点,满足:对于任意两个点i,j,如果i在树上是j的祖先,那么v_i>v_j。求可选的最多的点数。

题解:先考虑一条链的情况,就是求一个LIS。再考虑一棵树。再次回想LIS经典求法,维护一个序列,即每次新加一个元素时,在维护的序列中找到比它大的第一个元素,替换,最后维护序列的长度就是LIS的长度。所以可以在每个节点开一个multiset维护LIS的情况。并且,各子树之间是互不影响的,可以直接合并。

具体到代码上的话,就像这样:

void merge(int x,int y)
{
if (s[x].size()>s[y].size()) swap(s[x],s[y]);
multiset<int>::iterator it=s[x].begin();
while (it!=s[x].end()) s[y].insert(*it),it++;
s[x].clear();
}
void dfs(int now)
{
for (int i=head[now];i;i=e[i].ne)
{
int v=e[i].to;
dfs(v); merge(v,now);
}
multiset<int>::iterator it=s[now].lower_bound(a[now]);
if (it!=s[now].end()) s[now].erase(it);
s[now].insert(a[now]);
}


这个
erase()
操作,之前一直不太理解,总觉得要用
while()
把所有大于等于
a[now]
的全都删掉。被dalao教育之后才意识到,加入新的元素后LIS的长度一定是不降的┭┮﹏┭┮。并且,如果被替换掉的是multiset里最大的元素,就表示当前点的值加入当前维护的LIS中,并在维护的LIS中删除被替换的值;如果替换之后,multiset里有比当前节点的值更大的元素,就表示不将当前节点的值加入LIS,LIS的长度不变;如果没有更大的元素,就直接加入当前的值,LIS长度+1。

感受一下的话,每次新加入元素的时候其实只跟当前LIS里最大的元素有关,要么是加入新元素、最大元素变小、LIS长度不变,要么是加入新元素、最大元素变大、LIS长度+1,要么是不加入新元素、LIS不变。我们最后需要的只是LIS的长度,所以
a[now]
*it
是等效的,可以这样替换。

总结一下,这个题比较神奇的就是发现子树互不影响和这种等效替换(也就是之前提到的经典求LIS的原理)。

code(有参考WerKeyTom_FTD的BLOG (✺ω✺))


#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<set>
#define N 200005
using namespace std;
inline int read()
{
char c=getchar(); int num=0,f=1;
while (c<'0'||c>'9') { if (c=='-') f=-1; c=getchar(); }
while (c<='9'&&c>='0') { num=num*10+c-'0'; c=getchar(); }
return num*f;
}
struct edge{
int to,ne;
}e[N<<1];
multiset<int> s
;
int n,m,tot,a
,head
;
void push(int x,int y) { e[++tot].to=y; e[tot].ne=head[x]; head[x]=tot; }
void merge(int x,int y) { if (s[x].size()>s[y].size()) swap(s[x],s[y]); multiset<int>::iterator it=s[x].begin(); while (it!=s[x].end()) s[y].insert(*it),it++; s[x].clear(); } void dfs(int now) { for (int i=head[now];i;i=e[i].ne) { int v=e[i].to; dfs(v); merge(v,now); } multiset<int>::iterator it=s[now].lower_bound(a[now]); if (it!=s[now].end()) s[now].erase(it); s[now].insert(a[now]); }
int main()
{
n=read();
for (int i=1;i<=n;i++) a[i]=read(),push(read(),i);
dfs(1);
printf("%d",s[1].size());
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: