您的位置:首页 > 其它

Codeforces 868F 分治优化Dp

2017-10-22 02:02 417 查看
原题链接:http://codeforces.com/problemset/problem/868/F

大致题意:给出有n(n<=10^5)个元素的序列,元素值ai<=n,需要将其分为m(m<=min( 20, n))段,每段的费用是∑( calc[i]-1)*calc[i]/2,其中calc[i]为元素值为i的个数。

我的理解,首先这题似乎和合并类DP很相似,如果定下f[ i ][ j ]为将前i个元素分割为j段的最小代价,有个显然地转移方程就是f[ i ] [ j ] = MAX{ f[ i1 ] [ j-1 ]+cost( i1+1,i)  }(i1<i)

然后似乎有个单调性可以试着证明一下:如果i1 是满足f[ i ][ j ]=f[ i1 ][ j-1 ]+cost( i1+1,i)的最小值,则对于使得f[ k ][ j ]=f[ i2 ][ j-1 ]+cost( i2+1 , k ) (k>i)的i2必然有i2>=i1。

因为对于所有的i0 < i1 , f[ i0 ][ j-1 ]+cost( i0+1 , i ) >f[ i1 ][ j-1 ]+cost( i1+1 , i ),然后显然有cost( i0+1 , k )  > cost( i1+1 , k )

于是就得到了 f[ i0 ][ j-1 ]+cost( i0+1 , k ) >f[ i1 ][ j-1 ]+cost( i1+1 , k ),即对于任意的j,f[
i ][ j ]的转移点对于i具有单调性

换种简洁点的表述就是,令f[ i ][ j ]=f[  from[ i ][ j ]  ][ j-1 ]+cost( from[ i ][ j ]+1 , i ),则from[ i ][ j ]>=from[ i0 ][ j ]( i0< i )

这题的精妙之处就在于巧妙的利用了这个单调性,如果要求得f[ i ][ j ](L<= i <= R )的值,则可以先求from[ (L+R)/2 ][ j ]的值,然后对于i<(L+R)/2,就有from[ (L+R)/2 ][ j ]>from[ i0 ][ j ] ,对于i>(L+R)/2 ,就有from[ (L+R)/2 ][ j ],然后再递归处理i在[ L , (L+R)/2
-1 ]和[ (L+R)/2+1 , R ]区间时的情况。显然这样递归的层数是logn层,每一层都有总长度为n的扫描[L, R ]区间求from[ (L+R)/2 ][ j ]的花销,这部分总耗时为O(m*n*log n )。

但是怎么高效地求cost( l , r )的值,考虑到既然都是暴力扫描求from[ (L+R)/2 ][ j ],如果可以相同复杂度的消耗来完成就好了。观察分治后的各个子段内部,因为是从左到右依次暴力枚举,而cost的值的增量只和已有的数值数量有关,只要维护当前枚举段内各值的出现次数同时把增量就可以了,于是对于每个字段,求cost也是O(子段的长度)的复杂度了,也就是说,求cost的总复杂度和扫描求from[ (L+R)/2
][ j ]的总复杂度一样都是O(m*n*log n )。

整体算法就完成了

代码:

#include <bits/stdc++.h>
using namespace std;
inline void read(int &x){
char ch;
bool flag=false;
for (ch=getchar();!isdigit(ch);ch=getchar())if (ch=='-') flag=true;
for (x=0;isdigit(ch);x=x*10+ch-'0',ch=getchar());
x=flag?-x:x;
}

inline void read(long long &x){
char ch;
bool flag=false;
for (ch=getchar();!isdigit(ch);ch=getchar())if (ch=='-') flag=true;
for (x=0;isdigit(ch);x=x*10+ch-'0',ch=getchar());
x=flag?-x:x;
}
inline void write(int x){
static const int maxlen=100;
static char s[maxlen];
if (x<0) { putchar('-'); x=-x;}
if(!x){ putchar('0'); return; }
int len=0; for(;x;x/=10) s[len++]=x % 10+'0';
for(int i=len-1;i>=0;--i) putchar(s[i]);
}

const int MAXN = 120000;
const int MAXM = 22 ;
typedef long long ll;

int n , m;
int num[ MAXN ];
int calc[ MAXN ];
ll f[ MAXN ];
ll pre[ MAXN ];
ll st,ed,sum;

void solve(int ans_l,int ans_r,int aim_l,int aim_r){
if ( aim_l > aim_r )
return ;
int mid=(aim_l+aim_r)/2;
int ans_m=ans_l;
ll tmp=1ll<<60;
for (int i=ans_l;i<=min(mid-1, ans_r);i++)
{
while ( st< i+1 ) calc[ num[st] ]-- , sum-=calc[ num[st] ] ,st++ ;
while ( ed> mid ) calc[ num[ed] ]-- , sum-=calc[ num[ed] ] ,ed-- ;
while ( st> i+1 ) st-- ,sum+=calc[ num[st] ] , calc[ num[st] ]++ ;
while ( ed< mid )
{
ed++ ,sum+=calc[ num[ed] ] , calc[ num[ed] ]++ ;
//printf("----%d %d %d %d\n",ed, mid ,num[ed],calc[ num[ed] ]);
}
//printf("%d %d %d %d %d\n",st,ed,i,mid,sum);
//printf("%d %d %d %d\n",pre[i],sum,pre[ans_m],tmp);
if ( pre[i] + sum < pre[ ans_m ] + tmp )
ans_m=i,tmp=sum;
}
f[mid]=pre[ans_m]+tmp;
//printf("%d --> %d\n",ans_m,mid);
solve( ans_l, ans_m , aim_l , mid-1 );
solve( ans_m, ans_r , mid+1 , aim_r );
}

int main(){
read(n); read(m);
for (int i=1;i<=n;i++)
read(num[i]);
pre[0]=0;
for (int i=1;i<=n;i++)
pre[i]=1ll<<60;
sum=0;
st=1;ed=1;sum=0;calc[ num[1] ]=1;
while ( m--)
{
solve(0,n-1,1,n);
for (int i=1;i<=n;i++)
pre[i]=f[i];
/*
for (int i=1;i<=n;i++)
printf("%d ",pre[i]);
puts("");
*/
memset(f,0,sizeof(f));
}
printf("%I64d\n",pre
);
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: