您的位置:首页 > 其它

CodeChef PrimeDST【点分治】【FFT】

2015-07-09 08:57 627 查看
/* I will wait for you */

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <ctime>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <vector>
#include <queue>
#include <deque>
#include <set>
#include <map>
#include <string>
// Shorthand macros carried over from the author's template
// (the solution below does not use them).
#define make(a,b) make_pair(a,b)
#define fi first
#define se second

using namespace std;

// Type aliases.
typedef long long ll;            // used for read() and the answer accumulator
typedef unsigned long long ull;
typedef pair<int, int> pii;
typedef map<int, int> mii;

// Capacities: maxn sizes every node / edge / FFT buffer in this file.
const int maxn = 200010;
const int maxm = 1010;
const int maxs = 26;

// Numeric constants (only Pi is used, by the FFT twiddle computation).
const int inf = 0x3f3f3f3f;
const int P = 1000000007;
const double error = 1e-9;
const double Pi = 3.1415926535897932;

// Read the next integer from stdin.
// Non-digit characters before the number are skipped; a '-' acts as a sign
// only when it immediately precedes the first digit.  Returns 0 when EOF is
// reached before any digit (instead of spinning forever on EOF).
inline long long read()
{
	long long x = 0, f = 1;
	int ch = getchar();  // int, not char: EOF must stay distinguishable
	while (ch != EOF && (ch > '9' || ch < '0'))
		f = (ch == '-' ? -1 : 1), ch = getchar();
	while (ch <= '9' && ch >= '0')
		x = x * 10 + ch - '0', ch = getchar();
	return f * x;
}

// Minimal complex number for the hand-written FFT.
// _x is the FFT work buffer; w[0] / w[1] receive the twiddle factors for the
// forward (+sin) and inverse (-sin) transform, filled in by FFT().
struct complex
{
	double re;  // real part
	double im;  // imaginary part
} _x[maxn], w[2][maxn];

// Component-wise complex addition.
complex operator + (complex a, complex b)
{
	complex res = { a.re + b.re, a.im + b.im };
	return res;
}

// Component-wise complex subtraction.
complex operator - (complex a, complex b)
{
	complex res = { a.re - b.re, a.im - b.im };
	return res;
}

// Complex product (a.re + I*a.im) * (b.re + I*b.im).
complex operator * (complex a, complex b)
{
	const double real = a.re * b.re - a.im * b.im;
	const double imag = a.re * b.im + a.im * b.re;
	complex prod = { real, imag };
	return prod;
}

// Edge of the adjacency-list pool e[]: each vertex's list starts at
// head[vertex] and is chained through `next` until -1.
struct edge
{
	int v;     // vertex this edge points to
	int next;  // index of the next edge from the same source, or -1
} e[maxn];

// Global state shared by the centroid decomposition and the FFT counting.
//   n         number of vertices
//   root      centroid chosen by _find (0 is a sentinel start value, see main)
//   sum       size of the component _find is currently searching
//   _maxdeep  deepest distance seen over all subtrees of the current centroid
//   maxdeep   deepest distance seen in the subtree _deep is walking
//   head      first edge index of each vertex's adjacency list (-1 = empty)
//   pri       pri[i] != 0 means i is NOT prime (filled by init)
//   _max      largest remaining piece if that vertex were removed (_find)
//   size      subtree sizes computed by _size / _find
//   cnt       number of edges stored in e[] so far
//   g         depth histogram of one subtree of the current centroid
//   f         combined depth histogram (f[0] = the centroid itself)
//   deep      distance from the current centroid
//   del       set to 1 once a vertex has served as a centroid
//   rev       bit-reversal scratch for FFT
// NOTE(review): with `using namespace std;`, the unqualified name `size` can
// become ambiguous once a header declares C++17 std::size; renaming it
// (e.g. to `sz`) throughout the file would avoid that.
int n, root, sum, _maxdeep, maxdeep, head[maxn],
pri[maxn], _max[maxn], size[maxn], cnt, g[maxn],
f[maxn], deep[maxn], del[maxn], rev[maxn];
ll ans;  // ordered vertex pairs found at prime distance (see _solve)

// Append the directed edge u -> v at index cnt of the edge pool and make it
// the new head of u's adjacency list.
// Plain member assignment is used because a compound literal such as
// `(edge) {...}` is a GNU extension, not standard C++.
void insert(int u, int v)
{
	e[cnt].v = v;
	e[cnt].next = head[u];
	head[u] = cnt++;
}

// Walk the alive part of the tree below u (p = vertex we came from), adding
// one to g[] at each visited vertex's depth and tracking the largest depth
// in maxdeep.  The caller must have set deep[u] before the call.
void _deep(int u, int p)
{
	const int d = deep[u];
	g[d]++;
	if (d > maxdeep) maxdeep = d;

	for (int id = head[u]; id != -1; id = e[id].next) {
		const int to = e[id].v;
		if (to == p || del[to]) continue;
		deep[to] = d + 1;
		_deep(to, u);
	}
}

// Fill size[] for every alive vertex reachable from u without passing
// through p or a deleted vertex; size[u] ends up as the component size.
void _size(int u, int p)
{
	int total = 1;
	for (int id = head[u]; id != -1; id = e[id].next) {
		const int to = e[id].v;
		if (to == p || del[to]) continue;
		_size(to, u);
		total += size[to];
	}
	size[u] = total;
}

// Centroid search in the alive component containing u (total size `sum`).
// Recomputes size[] on the way, stores in _max[x] the largest piece left if
// x were removed, and moves `root` to any vertex whose _max is strictly
// smaller than the current root's.
void _find(int u, int p)
{
	int subtotal = 1;
	int heaviest = 0;
	for (int id = head[u]; id != -1; id = e[id].next) {
		const int to = e[id].v;
		if (to == p || del[to]) continue;
		_find(to, u);
		subtotal += size[to];
		if (size[to] > heaviest) heaviest = size[to];
	}
	size[u] = subtotal;

	const int rest = sum - subtotal;  // the piece on the parent's side
	_max[u] = rest > heaviest ? rest : heaviest;
	if (_max[u] < _max[root]) root = u;
}

// In-place iterative radix-2 FFT of a[0..n-1]; n must be a power of two
// (the only caller, _solve, passes 1 << len).
// f == 0: transform with twiddles w[0][k] = exp(+2*pi*I*k/n).
// f == 1: transform with the conjugate twiddles w[1][k] = exp(-2*pi*I*k/n),
//         then divide the real parts by n.  Imaginary parts are not scaled;
//         _solve only reads .re after the inverse transform.
// Uses the global scratch arrays rev[] and w[][]; twiddles are recomputed on
// every call.
void FFT(complex *a, int n, int f)
{
// Bit-reversal permutation: rev[i] is i with its log2(n) low bits reversed.
for (int i = 0; i < n; i++) {
rev[i] = 0;
for (int j = i, k = 1; k < n; k <<= 1, j >>= 1)
(rev[i] <<= 1) |= (j & 1);
// rev is an involution, so swapping only when rev[i] > i swaps each pair once.
if (rev[i] > i) swap(a[i], a[rev[i]]);
}

// Twiddle tables for both directions (w[1] is the complex conjugate of w[0]).
for (int i = 0; i < n; i++) {
w[0][i].re = cos(2 * Pi * i / n);
w[0][i].im = sin(2 * Pi * i / n);
w[1][i].re = cos(2 * Pi * i / n);
w[1][i].im = -sin(2 * Pi * i / n);
}

// Butterflies: i is the half-length of the blocks being merged, l the stride
// into the length-n twiddle table, so w[f][t] = exp(+-2*pi*I*k/(2*i)).
for (int i = 1; i < n; i <<= 1)
for (int j = 0, l = n / (i << 1); j < n; j += (i << 1))
for (int k = 0, t = 0; k < i; k += 1, t += l) {
complex x = a[j + k], y = w[f][t] * a[i + j + k];
a[j + k] = x + y, a[i + j + k] = x - y;
}
// Inverse-transform normalisation (real part only).
for (int i = 0; f && i < n; i++) a[i].re /= n;
}

void _solve(int *a, int n, int f)
{
int len = 1;
while (1 << len < n << 1) len += 1;

for (int i = 0; i < 1 << len; i++)
_x[i].re = _x[i].im = 0;
for (int i = 0; i < n; i++)
_x[i].re = a[i];

FFT(_x, 1 << len, 0);
for (int i = 0; i < 1 << len; i++)
_x[i] = _x[i] * _x[i];
FFT(_x, 1 << len, 1);

for (int i = 0; i < 1 << len; i++) {
if (!pri[i] && f == 1)
ans += (ll) (_x[i].re + 0.5);
if (!pri[i] && f == -1)
ans -= (ll) (_x[i].re + 0.5);
}
}

// One centroid-decomposition step at centroid u.
// Pairs are counted through u with _solve: the squared histogram of all
// subtrees (plus u) is added, and each subtree's own squared histogram is
// subtracted, because pairs inside one subtree do not pass through u.
// Afterwards every remaining component gets its own centroid and recursion.
void solve(int u)
{
	del[u] = 1;
	_size(u, 0);

	// f[d] = number of alive vertices at distance d from u (u itself at d = 0).
	f[0] = 1;
	for (int d = 1; d <= size[u]; d++) f[d] = 0;
	_maxdeep = 0;

	for (int id = head[u]; id != -1; id = e[id].next) {
		const int child = e[id].v;
		if (del[child]) continue;

		// Depth histogram of this subtree alone.
		maxdeep = 0;
		for (int d = 0; d <= size[child]; d++) g[d] = 0;
		deep[child] = 1;
		_deep(child, u);

		_solve(g, maxdeep + 1, -1);

		if (maxdeep > _maxdeep) _maxdeep = maxdeep;
		for (int d = 0; d <= size[child]; d++) f[d] += g[d];
	}
	_solve(f, _maxdeep + 1, 1);

	// Recurse: find the centroid of each remaining component and solve it.
	for (int id = head[u]; id != -1; id = e[id].next) {
		const int child = e[id].v;
		if (del[child]) continue;
		sum = size[child];
		root = 0;
		_find(child, u);
		solve(root);
	}
}

void init()
{
pri[0] = pri[1] = 1;
for (int i = 2; i < n; i++)
if (!pri[i])
for (int j = 2 * i; j < n; j += i)
pri[j] = 1;
}

// Read a tree (n, then n - 1 edges as vertex pairs; vertices start at 1).
// Count ordered pairs of distinct vertices at prime distance with centroid
// decomposition and FFT, then print that count divided by n * (n - 1): the
// fraction of ordered distinct-vertex pairs that are a prime distance apart.
int main()
{
	n = read(), init();

	// Fewer than two vertices: there are no pairs, and the division below
	// would be 0 / 0 (NaN).
	if (n < 2) {
		printf("%.6f\n", 0.0);
		return 0;
	}

	memset(head, -1, sizeof head);
	for (int i = 1; i < n; i++) {
		int u = read(), v = read();
		insert(u, v), insert(v, u);
	}

	// Vertex 0 is not a real vertex; _max[0] = n makes it a start root that
	// any real vertex beats in _find.
	sum = _max[0] = n, root = 0;
	_find(1, 0), solve(root);

	printf("%.6f\n", 1.0 * ans / n / (n - 1));

	return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: