您的位置:首页 > 理论基础 > 数据结构算法

SPOJ Count on a tree II(树上莫队)

2017-11-07 23:38 323 查看
 debug 到想吐....

 各种撒比错误,一晚上就没有了,

 总结如下几点:

两个不同参数的数组,(n,m) 的最大值不一样,最好开到同样大

树上莫队注意重复节点的拆分

 树型数据简单生成技巧:

* i rand()%i

树上莫队的桶是 (q[i].l/S) not u/S( saaaa…)

题目链接

Count on a tree II:

分析

 如果你学了树上莫队,对这题应该不会陌生,将树搞成链

AC code

#include <bits/stdc++.h>
using namespace std;
#define ms(x,v) (memset((x),(v),sizeof(x)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
typedef long long LL;
typedef pair<int,int > Pair;
const int maxn = 1e5+10;
const int max_query = 1e5+10;
const int MAX_LOG = 19;

int n,m;
int S;

struct Query{
int id,l,r,lc,backet;
bool operator<(const Query & o)const{
return backet==o.backet?(backet&1?r>o.r : r <o.r) : l<o.l;
}
};
Query q[max_query];
int ret[max_query],ans=0;
int dfn[maxn<<1],st[maxn],ed[maxn],dft=0;
int tmp[maxn],a[maxn];
int dep[maxn];
int fa[maxn][MAX_LOG];
int vis[maxn],cnt[maxn];
std::vector<int> G[maxn];
void dfs(int u,int f) {
dep[u] = dep[f]+1;
fa[u][0] = f;
st[u] = ++dft;
dfn[dft] = u;
for(int i=1 ; (1<<i) <= dep[u] ; ++i)fa[u][i] = fa[fa[u][i-1]][i-1];
for(auto v : G[u]){
if(v == f)continue;
dfs(v,u);
}
ed[u] = ++dft;
dfn[dft] = u;
}
inline int lca(int u,int v){
if(dep[u]<dep[v])swap(u,v);
int bin = dep[u] - dep[v];
for(int i=0 ; i<MAX_LOG ; ++i)if(bin>>i & 1)u = fa[u][i];
if(u==v) return u;
for(int i= MAX_LOG-1 ; i>=0 ; --i)
if(fa[u][i]!=fa[v][i]) u = fa[u][i],v = fa[v][i];
return fa[u][0];
}
inline int move(int node){
if(vis[node] && --cnt[a[node]]==0)ans--;
else if(!vis[node] && ++cnt[a[node]]==1) ans++;
vis[node] ^=1;
}
void mo() {
int curL = q[0].l ,curR = q[0].l-1;
for(int i=0 ; i<m ; ++i){
int L = q[i].l,R=q[i].r;
while (curL > L)move(dfn[--curL]);
while (curR < R)move(dfn[++curR]);
while (curL < L)move(dfn[curL++]);
while (curR > R)move(dfn[curR--]);
ret[q[i].id] = ans;
if(!cnt[a[q[i].lc]])ret[q[i].id]++;
}
}

int main(int argc, char const *argv[]) {
scanf("%d%d",&n,&m );
ms(vis,0);
ms(cnt,0);
for(int i=1 ; i<=n ; ++i)scanf("%d", a+i),tmp[i] =a[i];
sort(tmp+1,tmp+n+1);
for(int i=1 ; i<=n ; ++i)a[i] = lower_bound(tmp+1,tmp+n+1,a[i]) - tmp;
for(int i=1 ; i<n ; ++i){
int u,v;
scanf("%d%d",&u,&v );
G[u].pb(v);G[v].pb(u);
}
S = sqrt(2.0*n)+1;
dfs(1,0);
// for(int i=1 ; i<=dft ; ++i)std::cout << dfn[i] << ' ';std::cout  << '\n';
// for(int i=1 ; i<=dft ; ++i)std::cout << a[dfn[i]] << ' ';std::cout  << '\n';

for(int i=0 ; i<m ; ++i){
int u,v;
scanf("%d%d",&u,&v );
// std::cout << u << " " << v << '\n';
int p = lca(u,v);
if(st[u]>st[v])swap(u,v);
q[i].id = i;
q[i].lc = p;
if(p==u){q[i].l = st[u];q[i].r = st[v];}
else {q[i].l = ed[u];q[i].r = st[v];}
q[i].backet = q[i].l/S;
// std::cout << q[i].l << " "<< q[i].r<< " " << q[i].backet<< '\n';
// std::cout << q[i].lc << '\n';
}
sort(q,q+m);
mo();
for(int i=0 ; i<m ; ++i)printf("%d\n",ret[i]);
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息