您的位置:首页 > 其它

BZOJ 3522 & 4543: [POI2014]Hotel

2016-10-21 23:48 465 查看
指针优化树形DP

BZOJ3522的n只有5000,可以随便用一个n^2的大暴力搞过去,就不讲了。

BZOJ4543好丧啊,看到题解我当时就震惊了,题解点我

n有十万,使人一看到就会放弃一些树形DP的思路,而正解正是树形DP,难点就在内存和时间的优化。

可以证明复杂度是O(n),详见上面的题解呀。

#include<cstdio>
#include<cstring>
#define N 5005
#define ll long long
using namespace std;
struct edge{int next,to;}e[N<<1];
int ecnt=0, last
, cnt
, vis
, timer, mx, mmx;
ll f
[4];
void addedge(int a, int b)
{
e[++ecnt]=(edge){last[a],b};
last[a]=ecnt;
}
void dfs(int x, int fa, int dep)
{
if(mx<dep) mx=dep;
if(vis[dep] < timer)vis[dep] = timer, cnt[dep] = 1;
else cnt[dep]++;
for(int i = last[x]; i; i=e[i].next)
{
int y=e[i].to;
if(y==fa)continue;
dfs(y,x,dep+1);
}
}
int main()
{
int n;
ll ans=0;
scanf("%d",&n);
for(int i = 1; i < n; i++)
{
int a, b;
scanf("%d%d",&a,&b);
addedge(a,b);
addedge(b,a);
}
for(int x = 1; x <= n; x++)
{
memset(f,0,sizeof(f));
for(int i = last[x]; i; i=e[i].next)
{
timer++;
int y=e[i].to;
mx=0;
dfs(y,x,1);
if(mmx<mx)mmx=mx;
for(int j = 1; j <= mx; j++)
{
f[j][3] += f[j][2]*cnt[j];
f[j][2] += f[j][1]*cnt[j];
f[j][1] += cnt[j];
}
}
for(int i = 1; i <= mmx; i++)
ans += f[i][3];
}
printf("%lld\n",ans);
}


#include<cstdio>
#define ll long long
#define N 100005
using namespace std;
ll *f
, *g
, temp[N*20], *cur, ans;
int n,ecnt, last
, under
, dep
;
struct edge{int next,to;}e[N<<1];
void addedge(int a, int b)
{
e[++ecnt]=(edge){last[a],b};
last[a]=ecnt;
}
void dfs(int x, int fa)
{
under[x]=x;
dep[x]=dep[fa]+1;
for(int i = last[x]; i; i=e[i].next)
{
int y=e[i].to;
if(y==fa)continue;
dfs(y,x);
if(dep[under[y]] > dep[under[x]])
under[x]=under[y];
}
for(int i = last[x]; i; i=e[i].next)
{
int y = e[i].to;
if(y==fa || (under[x] == under[y] && x!=1))continue;
int z = under[y];
cur += dep[z] - dep[x] + 5;
f[z] = cur;
g[z] = (cur+=5);
cur += (dep[z] - dep[x])*2 + 5;
}
}
void dp(int x, int fa)
{
for(int i = last[x]; i; i=e[i].next)
{
int y=e[i].to;
if(y==fa)continue;
dp(y,x);
if(under[y] == under[x])
{
f[x] = f[y] - 1;
g[x] = g[y] + 1;
}
}
f[x][0] = 1;
ans += g[x][0];
for(int i = last[x]; i; i=e[i].next)
{
int y=e[i].to;
if(y==fa || under[y] == under[x])continue;
for(int j = 0; j <= dep[under[y]] - dep[x]; j++)
{
ans += g[y][j] * f[x][j-1];//用y的f更新
ans += f[y][j] * g[x][j+1];//用y的g更新
}
for(int j = 0; j <= dep[under[y]] - dep[x]; j++)
{
g[x][j-1] += g[y][j];
g[x][j+1] += f[x][j+1] * f[y][j];
f[x][j+1] += f[y][j];
}
}
}
int main()
{
scanf("%d",&n);
for(int i = 1, a, b; i < n; i++)
{
scanf("%d%d",&a,&b);
addedge(a,b);
addedge(b,a);
}
cur = temp + 5;
dfs(1,0);
dp(1,0);
printf("%lld\n",ans);
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: