您的位置:首页 > 其它

poj3233Matrix Power Series(矩阵快速幂,两种写法)

2015-01-16 20:55 155 查看
Matrix Power Series
Time Limit:3000MS     Memory Limit:131072KB     64bit IO Format:%I64d & %I64u
Submit Status

Description

Given a n × n matrix A and a positive integer k, find the sum S = A + A2 + A3 + … + Ak.

Input

The input contains exactly one test case. The first line of input contains three positive integers n (n ≤ 30), k (k ≤ 109) and m (m < 104). Then follow n lines each containing n nonnegative
integers below 32,768, giving A’s elements in row-major order.

Output

Output the elements of S modulo m in the same way as A is given.

Sample Input

2 2 4
0 1
1 1


Sample Output

1 2
2 3


第一种,令 B = A  E     E为单位矩阵 

 0   E

那么B^(n+1)  = A^n    A^n+A*(n-1)....A+1

0          E                               注意,当取余是一定要判断,如果此时的i==j 应该先-1再取余再+1,因为多了一个单位矩阵。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;
#define LL __int64

struct node{
LL k[32][32] ;
int n ;
};
struct node1{
node a , b , c , d ;
};
node mul(node p,node q,int m)
{
int i , j , l ;
node s ;
s.n = p.n ;
for(i = 0 ; i < p.n ; i++)
for(j = 0 ; j < p.n ; j++)
{
s.k[i][j] = 0 ;
for(l = 0 ; l < p.n ; l++)
{
if(i == j)
s.k[i][j] = (s.k[i][j] + p.k[i][l]*q.k[l][j]-1)%m+1 ;
else
s.k[i][j] = (s.k[i][j] + p.k[i][l]*q.k[l][j])%m ;
}
}
return s ;
}
node add(node p,node q,int m)
{
int i , j ;
for(i = 0 ; i < p.n ; i++)
for(j = 0 ; j < p.n ; j++)
{
if(i == j)
p.k[i][j] = (p.k[i][j] + q.k[i][j]-1 )%m+1;
else
p.k[i][j] = (p.k[i][j] + q.k[i][j] )%m;

}
return p ;
}
node1 pow(node1 o,int k,int m)
{
if( k == 1 )
return o ;
node1 temp = pow(o,k/2,m) , s ;
node p , q ;
p = mul(temp.a,temp.a,m) ; q = mul(temp.b,temp.c,m) ;
s.a = add( p,q,m ) ;
p = mul(temp.a,temp.b,m) ; q = mul(temp.b,temp.d,m) ;
s.b = add( p,q,m ) ;
p = mul(temp.c,temp.a,m) ; q = mul(temp.d,temp.c,m) ;
s.c = add( p,q,m ) ;
p = mul(temp.c,temp.b,m) ; q = mul(temp.d,temp.d,m) ;
s.d = add( p,q,m ) ;
temp = s ;
if( k%2 )
{
p = mul(temp.a,o.a,m) ; q = mul(temp.b,o.c,m) ;
s.a = add( p,q,m ) ;
p = mul(temp.a,o.b,m) ; q = mul(temp.b,o.d,m) ;
s.b = add( p,q,m ) ;
p = mul(temp.c,o.a,m) ; q = mul(temp.d,o.c,m) ;
s.c = add( p,q,m ) ;
p = mul(temp.c,o.b,m) ; q = mul(temp.d,o.d,m) ;
s.d = add( p,q,m ) ;
}
return s ;
}
int main()
{
int n , k , m ;
int i , j ;
node1 p , s ;
while( scanf("%d %d %d", &n, &k, &m) != EOF )
{
p.a.n = p.b.n = p.c.n = p.d.n = n ;
for(i = 0 ; i < n ; i++)
{
for(j = 0 ; j < n ; j++)
{
scanf("%I64d", &p.a.k[i][j]) ;
p.b.k[i][j] = p.c.k[i][j] = p.d.k[i][j] = 0 ;
}
p.b.k[i][i] = p.d.k[i][i] = 1 ;
}
s = pow(p,k+1,m) ;
for(i = 0 ; i < n ; i++)
{
for(j = 0 ; j < n ; j++)
{
if( i == j )
printf("%I64d", s.b.k[i][j]-1) ;
else
printf("%I64d", s.b.k[i][j]);
if( j == n-1 )
printf("\n") ;
else
printf(" ") ;
}
}
}
return 0;
}


第二种写法
当n为偶数时  A + A^2 + A^3 ......A^n = ( A + A^2 + A^3...A^(k/2) ) + A^(k/2)*(  A + A^2 + A^3...A^(k/2) )

当n为奇数时  A + A^2 + A^3 ......A^n = ( A + A^2 + A^3...A^(k/2) ) + A^(k/2+1)+ A^(k/2+1)*(  A + A^2 + A^3...A^(k/2) )

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std ;
#define LL __int64
struct node
{
LL a[32][32] ;
int n ;
};
node mul(node p,node q,int m)
{
node s ;
s.n = p.n ;
int i , j , k ;
for(i = 0 ; i < p.n ; i++)
for(j = 0 ; j < p.n ; j++)
{
s.a[i][j] = 0 ;
for(k = 0 ; k < p.n ; k++)
s.a[i][j] = ( s.a[i][j] + p.a[i][k]*q.a[k][j] ) % m ;
}
return s ;
}
node add(node p,node q,int m)
{
int i , j ;
node s ;
s.n = p.n ;
for(i = 0 ; i < p.n ; i++)
for(j = 0 ; j < p.n ; j++)
s.a[i][j] = ( p.a[i][j] + q.a[i][j] ) % m ;
return s ;
}
node pow(node p,int k,int m)
{
if( k == 1 )
return p ;
node s = pow(p,k/2,m) ;
s = mul(s,s,m) ;
if( k%2 )
s = mul(s,p,m) ;
return s ;
}
node f(node p,int k,int m)
{
if( k == 1 )
return p ;
node s = f(p,k/2,m) , q , temp ;
int i , j ;
for(i = 0 , temp.n = p.n; i < p.n ; i++)
for(j = 0 ; j < p.n ; j++)
{
if( i == j )
temp.a[i][j] = 1 ;
else
temp.a[i][j] = 0 ;
}
if( k%2 )
{
q = pow(p,k/2+1,m) ;
s = add( q, mul( s,add(q,temp,m),m ) ,m ) ;
}
else
{
q = pow(p,k/2,m) ;
s = mul(s, add(q,temp,m) ,m);
}
return s ;
}
int main()
{
int n , k , m ;
int i , j ;
node p , s ;
while( scanf("%d %d %d", &n, &k, &m) != EOF )
{
p.n = n ;
for(i = 0 ; i < n ; i++)
for(j = 0 ; j < n ; j++)
scanf("%I64d", &p.a[i][j]) ;
s = f(p,k,m) ;
for(i = 0 ; i < n ; i++)
{
for(j = 0 ; j < n ; j++)
{
if( j == n-1 )
printf("%I64d\n", s.a[i][j]) ;
else
printf("%I64d ", s.a[i][j]) ;
}
}
}
return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: