您的位置:首页 > 其它

「6月雅礼集训 2017 Day2」A

2017-06-18 16:57 375 查看
【题目大意】

给出一棵树,求有多少对点(u,v)满足其路径上不存在两个点a,b满足(a,b)=1

n<=10^5

【题解】

考虑找出所有不符合的点对,共有n*ln(n)对,他们要么是祖先->儿子边,要么是不是。

考虑祖先->儿子边,那么一个点在祖先以上,一个点在儿子以下的点对全部无法访问。

考虑另外一种边,就是LCA不是两个端点的,这就比较好统计了,两个点在这两棵子树的点对无法访问。

考虑用DFS序,这样子树就是连续的一段(祖先以上是连续两段)

然后就是一个二维覆盖问题,用扫描线+线段树即可解决。

复杂度O(nln(n)logn)

注意。。扫描线数组要开到 4 * n * ln(n) 不然。。会奇怪的WA/RE。。。

# include <stdio.h>
# include <string.h>
# include <iostream>
# include <algorithm>

using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;

# define RG register
# define ST static

const int M = 2e5 + 10, N = 1e5 + 10, Max = 8 * M;
const int mod = 998244353;

int n, head
, nxt[M], to[M], tot = 0;
inline void add(int u, int v) {
++tot; nxt[tot] = head[u]; head[u] = tot; to[tot] = v;
}
inline void adde(int u, int v) {
add(u, v), add(v, u);
}

int in
, out
, DFN = 0;
int dep
, fa
[19];
inline void dfs(int x, int fat = 0) {
in[x] = ++DFN; dep[x] = dep[fat] + 1;
fa[x][0] = fat;
for (int i=1; i<=18; ++i) fa[x][i] = fa[fa[x][i-1]][i-1];
for (int i=head[x]; i; i=nxt[i]) {
if(to[i] == fat) continue;
dfs(to[i], x);
}
out[x] = DFN;
}

inline int lca(int u, int v) {
if(dep[u] < dep[v]) swap(u, v);
for (int i=18; ~i; --i)
if((dep[u] - dep[v]) & (1<<i)) u = fa[u][i];
if(u == v) return u;
for (int i=18; ~i; --i)
if(fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
return fa[u][0];
}

inline int jump(int u, int anc) {
for (int i=18; ~i; --i)
if(dep[fa[u][i]] > dep[anc]) u = fa[u][i];
return u;
}

struct pa {
int x, yl, yr, d;
pa() {}
pa(int x, int yl, int yr, int d) : x(x), yl(yl), yr(yr), d(d) {}
friend bool operator < (pa a, pa b) {
return a.x < b.x;
}
}p[Max * 4]; int pn = 0;

inline void ADD(int xl, int xr, int yl, int yr) {
p[++pn] = pa(xl, yl, yr, 1);
p[++pn] = pa(xr+1, yl, yr, -1);
}

inline void doit(int x, int y) {
int par = lca(x, y);
//    if(par == -1) cout << x << ' ' << y << endl;
if(dep[x] > dep[y]) swap(x, y);
if(x == par) {
int pars = jump(y, par);
ADD(1, in[pars] - 1, in[y], out[y]);
ADD(in[y], out[y], out[pars] + 1, n);
} else {
if(in[x] > in[y]) swap(x, y);
ADD(in[x], out[x], in[y], out[y]);
}
}

struct SMT {
int w[Max], tag[Max];
# define ls (x<<1)
# define rs (x<<1|1)
inline void set() {
memset(w, 0, sizeof w);
memset(tag, 0, sizeof tag);
}
inline int gs(int x, int l, int r) {
if(tag[x]) return r-l+1;
else return w[x];
}
inline void edt(int x, int l, int r, int L, int R, int d) {
if(L > R) return ;
if(L <= l && r <= R) {tag[x] += d; return ;}
int mid = l+r>>1;
if(L <= mid) edt(ls, l, mid, L, R, d);
if(R > mid) edt(rs, mid+1, r, L, R, d);
w[x] = gs(ls, l, mid) + gs(rs, mid+1, r);
}
inline int sum(int x, int l, int r, int L, int R) {
if(L > R) return 0;
if(tag[x]) return min(R, r) - max(L, l) + 1;
if(L <= l && r <= R) return gs(x, l, r);
int mid = l+r>>1, ret = 0;
if(L <= mid) ret += sum(ls, l, mid, L, R);
if(R > mid) ret += sum(rs, mid+1, r, L, R);
return ret;
}
# undef ls
# undef rs
}T;

int main() {
//    freopen("A.in", "r", stdin);
//    freopen("A.out", "w", stdout);
cin >> n;
for (int i=1, u, v; i<n; ++i) {
scanf("%d%d", &u, &v);
adde(u, v);
}
dfs(1, 0);
for (int i=1; i<=n; ++i)
for (int j=i+i; j<=n; j+=i) doit(i, j);

sort(p+1, p+pn+1); T.set();
ll ans = (ll)n * (n-1) / 2;
for (int i=1, j=1; i<=n; ++i) {
while(j<=pn && p[j].x == i) T.edt(1, 1, n, p[j].yl, p[j].yr, p[j].d), ++j;
ans -= T.sum(1, 1, n, i+1, n);
}
cout << ans;
return 0;
}


View Code
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: