您的位置:首页 > 其它

[BZOJ5110]Yazid的新生舞会

2018-01-08 20:25 337 查看

题目大意:
  给你一个长度为$n(n\leq 5\times 10^5)$的序列$A_{1\sim n}$。求满足区间众数在区间内出现次数严格大于$\lfloor\frac{r-l+1}{2}\rfloor$的区间$[l,r]$的个数。

思路:
  分治。
  对于一个区间$[l,r]$,设$mid=\lfloor\frac{l+r}{2}\rfloor$,我们可以求出所有经过$mid$的区间内能够成为众数的所有数。
  不难发现所有的区间众数满足如下一个性质:如果$x$是区间$[l,r]$的众数,那么对于$l\leq x\leq r$,$x$一定是区间$[l,k]$或区间$(k,r]$的众数。
  利用这一性质,我们可以令$k=mid$,这样就可以$O(n)$从$mid$出发往左右两边扫,求出能够成为众数的所有数。
  接下来枚举每个众数$x$,求一下当前$[l,r]$区间中,以$x$作为众数的子区间个数。
  具体我们可以先从$mid$往左扫,设往左扫到的端点为$b$,记录一下对于不同的$b$,$mid-b+1-cnt[x]$不同取值的出现次数。然后再往右扫,求出对于当前右端点$e$,求出满足$e-b+1-cnt[x]>\lfloor\frac{e-b+1}{2}\rfloor$的区间$[b,e]$的个数,这可以用前缀和快速求出。
  这样我们就统计了区间$[l,r]$,经过$mid$的所有子区间。
  对于不经过$mid$的子区间可以递归求解。
  递归树中,每一层区间长度加起来是$n$,可能的众数个数有$\log n$个,每一层的时间复杂度是$O(n\log n)$。总共有$\log n$层,总的时间复杂度是$O(n\log^2 n)$。

#include<cstdio>
#include<cctype>
#include<algorithm>
typedef long long int64;
inline int getint() {
register char ch;
while(!isdigit(ch=getchar()));
register int x=ch^'0';
while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0');
return x;
}
const int N=500001;
int a
,pos
,num
,cnt[N*2];
int64 ans;
void solve(const int &l,const int &r) {
if(l==r) {
ans++;
return;
}
const int mid=(l+r)/2;
solve(l,mid);
solve(mid+1,r);
for(register int i=mid;i>=l;i--) {
if(++cnt[a[i]]>(mid-i+1)/2) {
if(!pos[a[i]]) {
num[pos[a[i]]=++num[0]]=a[i];
}
}
}
for(register int i=mid+1;i<=r;i++) {
if(++cnt[a[i]]>(i-mid)/2) {
if(!pos[a[i]]) {
num[pos[a[i]]=++num[0]]=a[i];
}
}
}
for(register int i=l;i<=r;i++) {
pos[a[i]]=cnt[a[i]]=0;
}
for(register int i=1;i<=num[0];i++) {
int sum=r-l+1,max=sum,min=sum;
cnt[sum]=1;
for(register int j=l;j<mid;j++) {
if(a[j]==num[i]) {
sum++;
} else {
sum--;
}
max=std::max(max,sum);
min=std::min(min,sum);
cnt[sum]++;
}
if(a[mid]==num[i]) {
sum++;
} else {
sum--;
}
for(register int i=min;i<=max;i++) {
cnt[i]+=cnt[i-1];
}
for(register int j=mid+1;j<=r;j++) {
if(a[j]==num[i]) {
sum++;
} else {
sum--;
}
ans+=cnt[std::min(max,sum-1)];
}
for(register int i=min;i<=max;i++) {
cnt[i]=0;
}
}
num[0]=0;
}
int main() {
const int n=getint(); getint();
for(register int i=1;i<=n;i++) {
a[i]=getint();
}
solve(1,n);
printf("%lld\n",ans);
return 0;
}

 

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: