您的位置:首页 > 其它

hdu 5411 2015多校十1006 ~矩阵快速幂

2015-08-30 18:33 399 查看

题意:

[code]给定n个碎片的转移关系,问最多使用m个碎片组成的不同的序列个数是多少。


思路:

容易想到dp的方法,以
dp[i][j]
表示长度为i以j号碎片结尾的不同序列的数量。那么
dp[i + 1][k] += dp[i][j](j -> k 可以转移)dp[0][0] = 1;
最后的答案即为整个dp[i][j]数组所有元素的和。但是一算复杂度,超时。

于是乎想到矩阵,关系矩阵A的k次方可以求出任意长度为k+1的序列总的数量。但是这道题要求最多长度为m的总数量,也就是说要求A^0 + A^1 + A^2… + A^m - 1。如果分别用快速幂,算一算复杂度依旧超时。然后就发现了构造矩阵的神奇方法。构造的矩阵B如下:

A . . . A 1

A . . . A 1

. .. . …. A 1

. . . . A 1

A . . . A 1

0 . . . 0 1

也就是在原来的转移矩阵最右边加上一列1,用于保存上一步以及之前的总的答案,

最后的答案为矩阵B^(m-1) 的所有项的和。时间复杂度O(n^3 *log(m)),注意特判m = 1(论构造矩阵的神奇)

附搓代码:

[code]#include <cstdio>
#include <cstring>
using namespace std;
const int piece = 55;
int ch[piece][piece];
int ans[piece][piece];
int tp[piece][piece];
int n, m;
void cal(int a[][piece], int b[][piece], int c[][piece])
{
    for(int i = 1; i <= n + 1; i++)
        for(int j = 1; j <= n + 1; j++)
        {
            int tmp = 0;
            for(int k = 1; k <= n + 1; k++)
                tmp = (tmp + a[i][k] * b[k][j]) % 2015;
            tp[i][j] = tmp;
        }
    memcpy(c, tp, sizeof(tp));
}
void rapid(int a[][piece], int k, int b[][piece])
{
    for(int i = 1; i <= n + 1; i++) b[i][i] = 1;
    while(k)
    {
        if(k & 1)
            cal(b, a, b);
        k >>= 1;
        cal(a, a, a);
    }
}
int main()
{
    // freopen("5411.in","r",stdin);
    // freopen("5411out.txt","w",stdout);
    int t;
    scanf("%d", &t);
    while(t--)
    {
        scanf("%d%d", &n, &m);
        memset(ch, 0, sizeof(ch));

        for(int i = 1; i <= n; i++)
        {
            int k;
            scanf("%d", &k);
            for(int j = 0; j < k; j++)
            {
                int tmp;
                scanf("%d", &tmp);
                ch[i][tmp] = 1;
            }
        }
        if(m == 1)
            printf("%d\n", n + 1);
        else
        {
            memset(ans, 0, sizeof(ans));
            for(int i = 1; i <= n + 1; i ++)
                ch[i][n + 1] = 1;
            for(int i = 1 ; i <= n; i++)
                ch[n + 1][i] = 0;
            rapid(ch, m - 1, ans);
            int ls = 0;
            for(int i = 1; i <= n + 1; i++)
                for(int j = 1; j <= n + 1; j++)
                    ls = (ls + ans[i][j]) % 2015;
            printf("%d\n", ls);
        }
    }
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: