您的位置:首页 > 其它

【JZOJ3872】【NOIP2014八校联考第4场第1试10.19】圣诞树(tree)

2017-01-18 16:29 423 查看

Description

圣诞节到了,小可可送给小薰一棵圣诞树。这棵圣诞树很奇怪,它是一棵多叉树,有n个点,n-1条边。它的每个结点都有一个权值。小可可和小薰想用这棵树玩一个游戏。

定义(s,e)为树上从s到e的简单路径,我们可以记下在这条路径上经过的结点,定义这个结点序列为S(s,e)。

我们按照如下方法定义这个序列S(s,e)的权值G(S(s,e)):假设这个序列中结点的权值为Z0,Z1,…,Z(L-1),其中L为序列的长度,我们定义G(S(s,e))=Z0 × k^0 + Z1 × k^1 + … + Z(L-1) × k^(L-1)。

如果路径(s,e)满足G(S(s,e)) ≡ x (mod y) ,那么这条路径属于小可可,否则这条路径属于小薰。小可可和小薰很显然不希望这个游戏变得那么简单。小薰认为如果路径(p1,p2)和(p2,p3)都属于他,那么路径(p1,p3)也属于他,反之如果路径(p1,p2)和(p2,p3)都属于小可可,那么路径(p1,p3)也属于小可可。然而这个性质并不总是正确的。所以小薰想知道到底有多少三元组(p1,p2,p3)满足这个性质。

小薰表示她看一眼就知道这道题怎么做了。你会吗?

Data Constraint

对于20%的数据,n ≤ 200;

对于50%的数据,n ≤ 10^4;

对于100%的数据,1 ≤ n ≤ 10^5,2 ≤ y ≤ 10^9,1 ≤ k ≤ y,0 ≤ x < y。

Solution

我们设出in0[i]表示i出发的最后能做到≡ x (mod y)的个数,in1[i]表示i出发的最后不能做到≡ x (mod y)的个数。out0[i],out1[i]同理。直接求满足性质的数目很难求,所以我们正难则反,求不满足性质的个数。那么就有t=∑ni=1in1[i]∗in0[i]∗2+out0[i]∗out1[i]∗2+in0[i]∗out1[i]+in1[i]∗out0[i]

由于每个三角形最后被算了两遍,所以ans=n3−t/2。求in1[i],out1[i]很难求,但我们发现in1[i]+in0[i]=out0[i]+out1[i]=n,所以现在只要求出in0[i],out0[i]即可。

我们可以用树分治来解决。每次跳到当前子树的重心,然后把这棵子树暴力遍历一遍,求出子树内每个点x到重心y的G(S(x,y))和G(S(y,x))。那么现在考虑一下合并。显然我们想要求的还有经过y的两两子节点的G值。那么就有G(S(x,x1))=G(S(x,y))+G(S(y,x1))∗klenS(x,y)。我们要满足G(S(x,x1))≡ x (mod y),那么G(S(y,x1))=X−G(S(x,y))klenS(x,y)。我们只要在事先求出的G(S(y,x))中二分查找一下即可求出对于该G(S(x,y))满足G(S(x,x’))≡ x (mod y)的数量,加在out0[x]上,同时在对应的y上加1。问题来了,怎样在在对应的y上加1用是最短呢?假定[a,b]的区间均要加1,我们直接在a加1,b+1减1,最后扫一遍即可。同时,由于两个点x,y可能来自同一棵子树,所以我们要在重心的每个儿子各自再遍历一遍,做一遍上述操作,减去这些情况。最后再跳到重心的儿子上,找出那个子树的的重心……时间复杂度O(Nlog2N)

Code

#include<iostream>
#include<cmath>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;
const ll maxn=1e5+5;
ll first[maxn],last[2*maxn],next[2*maxn],a[maxn],f[maxn],g[maxn],size[maxn],bz[maxn];
ll mx[maxn],h[maxn],in[maxn],out[maxn],d[maxn],er[maxn],f1[maxn];
struct code{
ll a,b,c;
}b[maxn],c[maxn];
ll n,m,i,t,j,k,l,y,x,z,num,p,q,num1,ln,ans;
void lian(ll x,ll y){
last[++num]=y;next[num]=first[x];first[x]=num;
}
bool cmp(code x,code y){
return x.a<y.a;
}
ll mi(ll x,ll y){
if (y==1) return x;
ll t=mi(x,y/2);
if (y%2) return t*t%p*x%p;return t*t%p;
}
void dg(int x,int y){
int t;size[x]=1;mx[x]=0;
for (t=first[x];t;t=next[t]){
if (last[t]==y || bz[last[t]]) continue;
dg(last[t],x);
mx[x]=max(mx[x],size[last[t]]);
size[x]+=size[last[t]];
}
}
int dg1(int x,int y,int z){
int t,k;mx[x]=max(mx[x],size[z]-size[x]);
if (mx[x]<=size[z]/2) return x;
for (t=first[x];t;t=next[t]){
if (last[t]==y || bz[last[t]]) continue;
k=dg1(last[t],x,z);
if (k) return k;
}
return 0;
}
void dg2(int x,int y,ll z){
int t,k=num;
for (t=first[x];t;t=next[t]){
if (last[t]==y || bz[last[t]]) continue;
b[++num].a=(b[k].a+z*a[last[t]])%p;b[num].b=b[k].b+1;
c[num].a=(c[k].a*m+a[last[t]])%p;c[num].b=c[k].b+1;
b[num].c=c[num].c=last[t];
dg2(last[t],x,z*m%p);
}
}
void out0(ll x){
int i,j,t,k,l,r,mid;b[num+1].a=p+1;b[0].a=-1;
for (i=1;i<=num;i++){
l=0;
r=num+1;
y=(q-c[i].a+p)%p*f1[c[i].b]%p;
while (l<r){
mid=(l+r)/2;
if (b[mid].a>=y) r=mid;
else l=mid+1;
}
t=l;
l=0;
r=num+1;
while (l<r){
mid=(l+r+1)/2;
if (b[mid].a>y) r=mid-1;
else l=mid;
}
k=l;
if (t<=k) d[k+1]-=x,d[t]+=x,out[c[i].c]+=x*(k-t+1);
}
for (i=1;i<=num;i++)
d[i]+=d[i-1],in[b[i].c]+=d[i];
for (i=1;i<=num+2;i++)
d[i]=0;
}
void make(ll x){
int t,k,i;
dg(x,0);
h[x]=dg1(x,0,x);bz[h[x]]=1;num=1,b[1].a=0,c[1].a=a[h[x]],b[1].b=0,c[1].b=1;b[1].c=c[1].c=h[x];
dg2(h[x],0,1);
sort(b+1,b+num+1,cmp);sort(c+1,c+num+1,cmp);
out0(1);
for (t=first[h[x]];t;t=next[t])
if (!bz[last[t]]){
num=1,b[1].a=0,c[1].a=a[last[t]],b[1].b=0,c[1].b=1;b[1].c=c[1].c=last[t];
dg2(last[t],0,1);
for (i=1;i<=num;i++)
c[i].a=(c[i].a+a[h[x]]*er[c[i].b])%p,b[i].b++,b[i].a=(b[i].a*m+a[last[t]])%p,c[i].b++;
sort(b+1,b+num+1,cmp);sort(c+1,c+num+1,cmp);
out0(-1);
}
for (t=first[h[x]];t;t=next[t])
if (!bz[last[t]]) make(last[t]);
}
int main(){
//  freopen("data.in","r",stdin);
scanf("%lld%lld%lld%lld",&n,&p,&m,&q);er[0]=1;
for (i=1;i<=n;i++)
er[i]=er[i-1]*m%p,f1[i]=mi(er[i],p-2);
for (i=1;i<=n;i++)
scanf("%lld",&a[i]);
for (i=1;i<n;i++)
scanf("%lld%lld",&x,&y),lian(x,y),lian(y,x);
make(1);memset(bz,0,sizeof(bz));t=0;ans=n*n*n;
for (i=1;i<=n;i++){
x=n-in[i];y=n-out[i];
t+=2*(x*in[i]+y*out[i])+in[i]*y+x*out[i];
}
ans-=t/2;
printf("%lld\n",ans);
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐