您的位置:首页 > 大数据 > 人工智能

HDU 5877 Weak Pair

2016-09-12 16:06 417 查看
$dfs$序,线段树。

可以统计每一个节点作为$root$的子树上对答案的贡献,可以将树转换成序列。问题就变成了一段区间上求小于等于某个值的数有几个。用线段树记录排好序之后的区间序列,询问的时候,属于询问区间的每个节点二分一下统计答案即可。

#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#include<map>
#include<set>
#include<queue>
#include<stack>
#include<iostream>
using namespace std;
typedef long long LL;
const double pi=acos(-1.0),eps=1e-6;
void File()
{
freopen("D:\\in.txt","r",stdin);
freopen("D:\\out.txt","w",stdout);
}
template <class T>
inline void read(T &x)
{
char c=getchar(); x=0;
while(!isdigit(c)) c=getchar();
while(isdigit(c)) {x=x*10+c-'0'; c=getchar();}
}

const int maxn=100010;
int T,n,h[maxn],sz,r[maxn],root;
LL k,v[maxn];
struct Edge { int u,v,nx; }e[maxn];
LL a[2*maxn],L[2*maxn],R[2*maxn];
vector<int>s[8*maxn];

void add(int u,int v)
{
e[sz].u=u; e[sz].v=v;
e[sz].nx=h[u]; h[u]=sz++;
}

void dfs(int x)
{
sz++; a[sz]=v[x]; L[x]=sz;
for(int i=h[x];i!=-1;i=e[i].nx) dfs(e[i].v);
sz++; a[sz]=v[x]; R[x]=sz;
}

void build(int l,int r,int rt)
{
if(l==r) { s[rt].push_back(a[l]); return; }
int m=(l+r)/2; build(l,m,2*rt); build(m+1,r,2*rt+1);

int sum=0,p1=0,p2=0;
while(sum<r-l+1)
{
if(p1<s[2*rt].size()&&p2<s[2*rt+1].size())
{
if(s[2*rt][p1]<s[2*rt+1][p2]) s[rt].push_back(s[2*rt][p1]), p1++;
else s[rt].push_back(s[2*rt+1][p2]), p2++;
}
else if(p1<s[2*rt].size()) s[rt].push_back(s[2*rt][p1]), p1++;
else s[rt].push_back(s[2*rt+1][p2]), p2++;
sum++;
}
}

int get(int L,int R,LL num,int l,int r,int rt)
{
if(L<=l&&r<=R)
{
int left=0,right=r-l,pos=-1;

while(left<=right)
{
int mid=(left+right)/2;
if((LL)s[rt][mid]>num) right=mid-1;
else left=mid+1,pos=mid;
}

return pos+1;
}

int m=(l+r)/2,x1=0,x2=0;
if(L<=m) x1=get(L,R,num,l,m,2*rt);
if(R>m) x2=get(L,R,num,m+1,r,2*rt+1);
return x1+x2;
}

int main()
{
scanf("%d",&T);
while(T--)
{
scanf("%d%lld",&n,&k);
for(int i=1;i<=n;i++) scanf("%lld",&v[i]);

memset(h,-1,sizeof h);
memset(r,sz=0,sizeof r);
for(int i=0;i<8*maxn;i++) s[i].clear();

for(int i=1;i<=n-1;i++)
{
int u,v; scanf("%d%d",&u,&v);
add(u,v); r[v]++;
}

for(int i=1;i<=n;i++) if(r[i]==0) root=i;
sz=0; dfs(root); build(1,2*n,1);

LL Ans=0;
for(int i=1;i<=n;i++)
{
if(L[i]+1==R[i]) continue;
if(v[i]==0) { Ans=Ans+(R[i]-L[i]-1); continue; }
Ans=Ans+get(L[i]+1,R[i]-1,k/v[i],1,2*n,1);
}
printf("%lld\n",Ans/2);
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: