您的位置:首页 > 其它

HDU 2296 Ring(AC自动机+DP)

2014-04-13 12:23 411 查看
HDU 2296 Ring(AC自动机+DP)
http://acm.hdu.edu.cn/showproblem.php?pid=2296
题意:

给你M个单词构成一个词典,每个单词有一个权值(单词出现多次算多个权值),现在要你构造一个不超过长度N的字符串,使得该字符串权值最大。如果出现多个答案,输出最短的,如果依然有多解,输出字典序最小的。

分析:

本题和之前几题类似。不过注意这题的AC自动机要用match, match[i]表示i节点的后缀单词权值总和.后缀单词指的是:所有可以做i节点表示的串的后缀的单词.

令dp[i][j]=x表示当前在i点走过了长j的路所生产的最大权值为x. 再用string path[i][j]来保存那个具有最大权值的最优字符串即可.不过要注意如果最后算的的最大权值是0,那么要输出空串.因为空串最短.

dp[i][j] = max(dp[k][j-1]+match[i]) 若dp[i][j]更新的时候,path[i][j]也要看看是否需要更新成:

if(path[i][j]!=””)path[i][j] = better(path[i][j] , path[k][j-1]+”x”)x字符表示从k走到i的那条边的字符是x.

初值path=空串,dp[0][0]=0.

注意:做这题的时候出现了几个错误的地方.

首先对于dp[i][j]不可达的节点一定要置-1,不能置1

其次对于当dp[ch[k][j]][len + 1] <dp[k][len] + match[ch[k][j]]时直接替换原来的path,而不是求最优了.

还有HDU提交这题的时候我一直编辑错误,不知道错在哪里?后来不得不改掉一些函数了.后来经过无数次的测试,发现少些了#include<string>
.


AC代码:

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<string>
using namespace std;
int N,M;
const int maxnode=1000+100;
const int sigma_size=26;
inline string better(string a,string b)
{
    if(a=="")
        return b;
    if(a.size()!=b.size())
        return a.size()<b.size() ? a:b;
    return a<b? a:b;
}
struct AC_Automata
{
    int ch[maxnode][sigma_size];
    int match[maxnode];
    int f[maxnode];
    int sz;
    int dp[maxnode][50+10];
    string path[maxnode][50+10];
    int ans;
    string res;
    void init()
    {
        sz=1;
        memset(ch[0],0,sizeof(ch[0]));
        match[0]=f[0]=0;
    }
    void insert(char *s,int v)
    {
        int n=strlen(s),u=0;
        for(int i=0;i<n;i++)
        {
            int id=s[i]-'a';
            if(ch[u][id]==0)
            {
                ch[u][id]=sz;
                memset(ch[sz],0,sizeof(ch[sz]));
                match[sz++]=0;
            }
            u=ch[u][id];
        }
        match[u]=v;
    }
    void getFail()
    {
        queue<int> q;
        for(int i=0;i<sigma_size;i++)
        {
            int u=ch[0][i];
            if(u)
            {
                f[u]=0;
                q.push(u);
            }
        }
        while(!q.empty())
        {
            int r=q.front();q.pop();
            for(int i=0;i<sigma_size;i++)
            {
                int u=ch[r][i];
                if(!u){ ch[r][i]=ch[f[r]][i]; continue; }
                q.push(u);
                int v=f[r];
                while(v && ch[v][i]==0) v=f[v];
                f[u]=ch[v][i];
                match[u] += match[f[u]];
            }
        }
    }
    void solve()
    {
        ans=0;
        res="";
        for(int i=0;i<sz;i++)
            for(int j=0;j<=N;j++)
            {
                dp[i][j]=-1;
                path[i][j]="";
            }
        dp[0][0]=0;
        for(int len=0;len<N;len++)//当前走的长度
            for(int k=0;k<sz;k++)if(dp[k][len]!=-1)//当前所在的节点
            {
                for(int j=0;j<sigma_size;j++)//下一步走的方向
                    if(dp[ch[k][j]][len+1] <= dp[k][len]+match[ch[k][j]])
                    {
                        char temp='a'+j;
                        if(dp[ch[k][j]][len+1] < dp[k][len]+match[ch[k][j]])
                        {
                            dp[ch[k][j]][len+1] = dp[k][len]+match[ch[k][j]];
                            path[ch[k][j]][len+1] = path[k][len]+temp;
                        }
                        else
                        {
                            path[ch[k][j]][len+1] = better(path[ch[k][j]][len+1],path[k][len]+temp);
                        }

                        if(ans < dp[ch[k][j]][len+1])
                        {
                            ans = dp[ch[k][j]][len+1];
                            res = path[ch[k][j]][len+1];
                        }
                        else if(ans == dp[ch[k][j]][len+1])
                            res = better(res,path[ch[k][j]][len+1]);
                    }
            }
    }
}ac;
char str[100+10][10+5];
int val[100+10];
int main()
{
    int T;
    scanf("%d",&T);
    while(T--)
    {
        ac.init();
        scanf("%d%d",&N,&M);
        for(int i=0;i<M;i++)
            scanf("%s",str[i]);
        for(int i=0;i<M;i++)
        {
            scanf("%d",&val[i]);
            ac.insert(str[i],val[i]);
        }
        ac.getFail();
        ac.solve();
        if(ac.ans==0)
            printf("\n");
        else
            cout<<ac.res<<endl;
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: