您的位置:首页 > 其它

[HDU4812]D Tree

2018-02-26 14:37 169 查看
vjudge

题意:给一棵树,每个点上有一个权值,求一条路径使得路径上权值的乘积膜\(10^6+3\)的结果为\(K\),输出路径的两个端点\(x,y\)。如有多解,设\(x<y\),输出\(x\)最小的,若仍有多解输出\(y\)最小的。

sol

点分。

每次考虑所有过重心的路径,开一个桶\(T[x]\)表示到根路径权值乘积(不算根的权值)为\(x\)的最小节点编号。

注意要先查出所有点到根的权值乘积,全部更新答案,再去更新桶\(T\)

更新答案的时候用逆元。逆元可以线性预处理出来。

记得要设\(T[1]=u\),做完这一层之后也要把\(T[1]\)清空

code

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int gi()
{
int x=0,w=1;char ch=getchar();
while ((ch<'0'||ch>'9')&&ch!='-') ch=getchar();
if (ch=='-') w=0,ch=getchar();
while (ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return w?x:-x;
}
const int N = 1e5+5;
const int mod = 1e6+3;
int n,k,inv[mod],val
,to[N<<1],nxt[N<<1],head
,cnt;
int sz
,w
,root,sum,vis
,T[mod],dep
,tmp
,top,ans1,ans2;
void link(int u,int v){to[++cnt]=v;nxt[cnt]=head[u];head[u]=cnt;}
void getroot(int u,int f)
{
sz[u]=1;w[u]=0;
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (v==f||vis[v]) continue;
getroot(v,u);
sz[u]+=sz[v];w[u]=max(w[u],sz[v]);
}
w[u]=max(w[u],sum-sz[u]);
if (w[u]<w[root]) root=u;
}
void getdeep(int u,int f,int sta)
{
dep[u]=sta;tmp[++top]=u;
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (v==f||vis[v]) continue;
getdeep(v,u,1ll*sta*val[v]%mod);
}
}
void solve(int u)
{
vis[u]=1;T[1]=u;
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (vis[v]) continue;
top=0;getdeep(v,0,val[v]);
for (int i=1;i<=top;++i)
{
int xx=1ll*dep[tmp[i]]*val[u]%mod,yy=1ll*k*inv[xx]%mod;
int x=tmp[i],y=T[1ll*k*inv[1ll*dep[tmp[i]]*val[u]%mod]%mod];
if (!y) continue;
if (x>y) swap(x,y);
if (x<ans1||(x==ans1&&y<ans2)) ans1=x,ans2=y;
}
for (int i=1;i<=top;++i) if (!T[dep[tmp[i]]]||tmp[i]<T[dep[tmp[i]]]) T[dep[tmp[i]]]=tmp[i];
}
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (vis[v]) continue;
top=0;getdeep(v,0,val[v]);
for (int i=1;i<=top;++i) T[dep[tmp[i]]]=0;
}
T[1]=0;
for (int e=head[u];e;e=nxt[e])
{
int v=to[e];if (vis[v]) continue;
sum=sz[v];root=0;
getroot(v,0);
solve(root);
}
}
int main()
{
inv[1]=1;
for (int i=2;i<mod;++i) inv[i]=mod-1ll*(mod/i)*inv[mod%i]%mod;
while (scanf("%d %d",&n,&k)!=EOF)
{
memset(head,0,sizeof(head));cnt=0;
memset(vis,0,sizeof(vis));ans1=ans2=1e9;
for (int i=1;i<=n;++i) val[i]=gi();
for (int i=1;i<n;++i)
{
int u=gi(),v=gi();
link(u,v);link(v,u);
}
root=0;sum=w[0]=n;
getroot(1,0);
solve(root);
if (ans1==1e9) puts("No solution");
else printf("%d %d\n",ans1,ans2);
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: