您的位置:首页 > 其它

金华邀请赛 B题 poj 4045

2012-07-07 21:06 281 查看
昨晚没能把这题A掉,很懊悔,且不管什么树形DP,只要dfs水平过硬,这题也就变成水的了。大神们都说是比较简单但是却是比较经典的树形dp,还有用两次DFS扫描全树就行。赛后请教了孟神,才明白就是这么个样。重新打了代码,T了两次,就A了。
       题意:n个点之间有n-1条边相连,形成一棵树,在这n个点上选一个点建一个发电站,要求该发电站到所有点的代价总和最小,每一个点到发电站的代价为I*I*R*D,D为该点到发电站的距离(每条边的权值均视为1)。
      题目分析:dfs1扫描并算出树中每个节点有多少个孩子,dfs2重下往上算出子节点到根节点的花费代价,dfs3从上往下扫描,算出根到节点花费的代价,加起来就可以了,,,太深,,,脑子还不够灵活我,做的时候我老想一个dfs解决问题,最终失败。

#include <iostream>
#include <cstdio>
#include <cmath>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=50005;
int child[maxn];
long long  val[maxn];
int vis[maxn];
int n,m,k;
vector<int>mp[maxn];
void init()
{
    for(int i=1;i<=n;i++)
    mp[i].clear();
    memset(val,0,sizeof(val));
    memset(vis,0,sizeof(vis));
}
void dfs_1(int u)
{
    child[u]=0;
    vis[u]=1;
    for(int i=0;i<mp[u].size();i++)
    {
        int v=mp[u][i];
        if(!vis[v])
        {
          dfs_1(v);
          child[u]+=child[v]+1;
        }
    }
}

void dfs_2(int u)
{
    vis[u]=1;
    val[0]=0;
    for(int i=0;i<mp[u].size();i++)
    {
        int v=mp[u][i];
        if(!vis[v])
        {
            dfs_2(v);
            val[u]+=val[v]+child[v]+1;
        }
    }
}

void dfs_3(int u)
{
    vis[u]=1;
    for(int i=0;i<mp[u].size();i++)
    {
        int v=mp[u][i];
        if(!vis[v])
        {
            val[v]+=val[u]-child[v]-val[v]-1+n-child[v]-1;
            dfs_3(v);
        }
    }
}

int main()
{
    int cas;
    int u,v;
    scanf("%d",&cas);
    while(cas--)
    {
       scanf("%d%d%d",&n,&m,&k);
       init();
       for(int i=0;i<n-1;i++)
       {
           scanf("%d%d",&u,&v);
           mp[u].push_back(v);
           mp[v].push_back(u);
       }

        dfs_1(1);//算出直接和间接孩子总个数
//        for(int i=1;i<=n;i++)
//        cout<<child[i]<<endl;
        memset(vis,0,sizeof(vis));
        dfs_2(1);//可以求根的代价//重下往上算
      //  cout<<val[1]<<endl;
        memset(vis,0,sizeof(vis));
        dfs_3(1);//从上往下算出每个节点的代价
//        for(int i=1;i<=n;i++)
//        cout<<val[i]<<endl;
        long long  ans=99999999999;
        for(int i=1;i<=n;i++)
        {
            ans=min(val[i],ans);
        }
        printf("%lld\n",ans*m*m*k);
        for(int i=1;i<=n;i++)
        {
            if(ans==val[i])
            printf("%d ",i);
        }
        cout<<endl<<endl;
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: