关于矩阵乘法优化dp(入门+斐波那契模板题)

来源:互联网 发布:淘宝网店开店步骤2016 编辑:程序博客网 时间:2024/05/17 06:12

矩阵乘法就是指一个a*b的矩阵和一个b*c的矩阵相乘得到一个a*c的矩阵
我们分别叫做A矩阵,B矩阵和C矩阵。
C[i][j]=nk=1a[i][k]+b[k][j]
而用矩阵乘法优化dp时,实际上是一个矩阵自己与自己相乘,所以可以看做是矩阵的幂运算。
而关于幂运算,自然可以用快速幂来优化。

——————————————————————————————————————————–前话到此为止

我们都知道dp的实质就是递推
先举个例子吧,比如用矩阵乘法优化快速求斐波那契数列。
我们都知道斐波那契数列的递推式为f[i]=f[i-2]+f[i-1]
我们可以将这个式子转化成为一个矩阵A
F[i]    F[i-1]
F[i-1]  F[i-2]
与另一个矩阵B
1   1
1   0
相乘
得出来矩阵C
F[i]+F[i-1]   F[i]
F[i]         F[i-1]
最后实际上就变成了
F[i+1]   F[i]
F[i]     F[i-1]
实际上是i++的原矩阵。
而当矩阵C继续与B相乘的时候,显然又会的出F[i+2]相关的矩阵
那么实际上我们进行的操作就是由一个矩阵
F[3]   F[2]
F[2]   F[1]
与矩阵不断的相乘。
那么实际上就只需要求出B矩阵的幂再与A相乘。

关于实现上,我们知道正常的快速幂中res/ans(即记录返回值的量)的初始值是1
那么矩阵的快速幂中,记录返回矩阵的初始矩阵是什么样的呢。
经过大量推理(julizi),我发现实际上这个矩阵是一个除了左上→右下对角线为1,其余都为0的矩阵,大家也可以自己推一推(jujulizi),其实也很好理解,毕竟c[i][j]=a[i][k]b[k][j],而由于只有b[j][j]存在值,所以这个式子就变成了c[i][j]=a[i][j]b[j][j],而b[j][j]又是1.

模板题
http://poj.org/problem?id=3070

看到我以上口胡没看懂的可以自己用代码print一下矩阵什么的..不过毕竟入门题没啥难度

#include<bits/stdc++.h>#define fer(i,j,n) for(int i=j;i<=n;i++)#define far(i,j,n) for(int i=j;i>=n;i--)#define ll long longconst int maxn=4010;const int INF=1e9+7;const int mod=10000;using namespace std;/*----------------------------------------------------------------------------*/inline ll read(){    char ls;ll x=0,sng=1;    for(;ls<'0'||ls>'9';ls=getchar())if(ls=='-')sng=-1;    for(;ls>='0'&&ls<='9';ls=getchar())x=x*10+ls-'0';    return x*sng;}/*----------------------------------------------------------------------------*/int fib[]={0,1};struct kaga{    ll v[2][2];    kaga friend operator *(kaga a,kaga b)    {        kaga c;        fer(i,0,1)            fer(j,0,1)            {                c.v[i][j]=0;                fer(k,0,1)                    c.v[i][j]=(c.v[i][j]+a.v[i][k]*b.v[k][j])%mod;            }        return c;    }}a,b,c;void print(kaga a){    fer(i,0,1)    {        fer(j,0,1)        cout<<a.v[i][j]<<" ";        cout<<endl;    }}void init(){    a.v[0][0]=2;a.v[0][1]=a.v[1][0]=a.v[1][1]=1;    b.v[0][0]=b.v[1][0]=b.v[0][1]=1;b.v[1][1]=0;    c.v[0][1]=c.v[1][0]=0;c.v[1][1]=c.v[0][0]=1;    //print(a);    //print(b);    //print(c);}int main(){    int n;    while(scanf("%d",&n))    {        if(n==-1)return 0;        if(n<=1)        {            cout<<fib[n]<<endl;            continue;        }        n--;        init();        for(;n;n>>=1,b=b*b)            if(n&1)c=b*c;        a=a*c;        cout<<a.v[1][1]<<endl;    }}

(这代码是真的丑)

原创粉丝点击