您的位置:首页 > 其它

HDOJ 4670: Cube number on a tree

2013-08-24 17:27 423 查看
题目链接:

http://acm.hdu.edu.cn/showproblem.php?pid=4670

题目大意:

树的每个节点有一个点权,所有的点权都可以被给定的30个质数表示出来。

在树上找合法点对。

合法点对指的是,两点间路径上的所有点(含端点)的点权乘积是立方数的点对。

注意:点对中的两个点可以是相同的,这个坑了我好久,切~

算法:

树的点分治。

每次处理的时候,用一个map保存之前遍历过的子树中的节点到根的路径值,另一个保存当前正在遍历的这棵子树里的路径值。

然后很容易就可以求出以这个根为LCA的合法点对。

PS:这么个题还写了4K,有点儿伤哦。。下回果断搞个FOREACH宏

树分治技能初步get,耶!

代码:

#pragma comment(linker,"/STACK:102400000,102400000")
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<sstream>
#include<cstdlib>
#include<cstring>
#include<string>
#include<climits>
#include<cmath>
#include<queue>
#include<vector>
#include<stack>
#include<set>
#include<map>
#define INF 0x3f3f3f3f
#define eps 1e-8
#define mp make_pair
#define pb push_back
#define st first
#define nd second
using namespace std;

const int MAXN = 50010;
const int MAXM = 30;
typedef long long LL;

int siz[MAXN], maxb[MAXN];
bool vis[MAXN];
vector <int> node;
vector <int> mm[MAXN];
LL tnum[MAXN];
LL prim[MAXM];
LL pow3[MAXM];
map <LL, LL> map1, map2;
typedef map <LL, LL> :: iterator mapit;
int n,k;
LL ans;

inline LL hash(LL x)
{
LL tmp = 0LL;
for (int i = 0; i < k; i ++)
{
int cot = 0;
while (x % prim[i] == 0)
{
x /= prim[i];
cot ++;
}
tmp += (cot % 3) * pow3[i];
}
return tmp;
}

inline LL fadd(LL x, LL y)
{
LL tmp = 0LL;
for (int i = 0 ; i < k; i ++)
{
LL ret1 = (x / pow3[i]) % 3;
LL ret2 = (y / pow3[i]) % 3;
tmp += (ret1 + ret2) % 3 * pow3[i];
}
return tmp;
}

inline LL frev(LL x)
{
LL tmp = 0LL;
for (int i = 0 ; i < k; i ++)
{
LL ret = (x / pow3[i]) % 3;
tmp += ((3 - ret) %3) * pow3[i];
}
return tmp;
}

void pre_dfs(int u, int p)
{
node.pb(u);
maxb[u] = 0;
siz[u] = 1;
for (int i = 0; i < mm[u].size(); i ++)
{
int v = mm[u][i];
if (v == p || vis[v])
{
continue;
}
pre_dfs(v, u);
siz[u] += siz[v];
maxb[u] = max(maxb[u], siz[v]);
}
}

void dfs(int u, int p, LL tmp)
{
mapit it;
(it = map2.find(tmp)) == map2.end() ?
map2[tmp] = 1 : (it -> nd) ++;
for (int i = 0; i < mm[u].size(); i ++)
{
int v = mm[u][i];
if (v == p || vis[v])
{
continue;
}
dfs(v, u, fadd(tmp, tnum[v]));
}
}

void cal(int root)
{
map1[0] = 1LL;
if(! tnum[root])
{
ans ++;
}
for (int i = 0; i < mm[root].size(); i ++)
{
int v = mm[root][i];
if (vis[v])
{
continue;
}
dfs(v, root, tnum[v]);
for (mapit it1 = map2.begin(); it1 != map2.end(); it1 ++)
{
mapit it2 = map1.find(frev(fadd(it1 -> st, tnum[root])));
if (it2 != map1.end())
{
ans += (it1 -> nd) * (it2 -> nd);
}
}
for (mapit it1 = map2.begin(); it1 != map2.end(); it1 ++)
{
mapit it2;
(it2 = map1.find(it1 -> st)) == map1.end() ?
map1[it1 -> st] = it1 -> nd : (it2 -> nd) += (it1 -> nd);
}
map2.clear();
}
map1.clear();
}

void solve(int u)
{
node.clear();
pre_dfs(u, -1);
int num = node.size();
int root, tmp = INT_MAX;
for (int i = 0; i < num; i ++)
{
maxb[node[i]] = max(maxb[node[i]], num - maxb[node[i]] - 1);
if (tmp > maxb[node[i]])
{
tmp = maxb[node[i]];
root = node[i];
}
}
vis[root] = true;
cal(root);
for (int i = 0; i < mm[root].size(); i ++)
{
int v = mm[root][i];
if (!vis[v])
{
solve(v);
}
}
}

int main()
{
pow3[0] = 1LL;
for(int i = 1; i < MAXM; i ++)
{
pow3[i] = pow3[i - 1] * 3;
}
while (scanf("%d %d", &n, &k) == 2)
{
ans = 0LL;
memset(vis, 0, sizeof(vis));
for (int i = 0; i < n; i ++)
{
mm[i].clear();
}
for (int i = 0; i < k; i ++)
{
scanf("%I64d", &prim[i]);
}
for (int i = 0; i < n; i ++)
{
scanf("%I64d", &tnum[i]);
tnum[i] = hash(tnum[i]);
}
for (int i = 1; i < n; i ++)
{
int u, v;
scanf("%d %d", &u, &v);
u --;
v --;
mm[u].push_back(v);
mm[v].push_back(u);
}
solve(0);
printf("%I64d\n", ans);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: