hdu 5411 CRB and Puzzle(矩阵快速幂)

来源:互联网 发布:js去除数组的重复元素 编辑:程序博客网 时间:2024/06/05 10:10

题目链接:

http://acm.hdu.edu.cn/showproblem.php?pid=5411

解题思路:

题目大意:

给定n个点 常数m

下面n行第i行第一个数字表示i点的出边数,后面给出这些出边。

问:图里存在多少条路径使得路径长度<=m,路径上的点可以重复。

官方题解:

We can count the number of different patterns by counting the number of different paths of length at most M-1M1.

This can be solved by multiply the following matrix M-1M1 times.

D=\begin{pmatrix} & & & 1\ & A& & 1\ & & & M\ 0& \Lambda &0 &1 \end{pmatrix}D=0AΛ011M1

Here, AA is adjacency matrix of the given graph.

Then,Ans\ =\ \sum_{i=1}^{N+1}\sum_{j=1}^{N+1}D^{M-1}[i][j]Ans = i=1N+1j=1N+1DM1[i][j]

Time complexity:O(N^{3}\cdot logM)O(N3logM)

算法思想:

首先能得到一个m*n*n的dp,dp[i][j]表示路径长度为i 路径的结尾为j的路径个数 。

答案就是sigma(dp[i][j]) for every i from 1 to m, j from 1 to n;

我们先计算 路径长度恰好为 i 的方法数。

用矩阵快速幂,会发现是

其中B矩阵是一个n*n的矩阵,也就是输入的邻接矩阵。

A是一个n行1列的矩阵 A[i][1]表示长度为1且以i结尾的路径个数,所以A矩阵是全1矩阵。

相乘得到的n*1 的矩阵求和就是路径长度恰好为i的条数。

那么<=m的路径就是:

把A提出来,里面就是一个关于B的矩阵等比数列。

AC代码:

#include <iostream>#include <cstdio>#include <cstring>using namespace std;const int MOD = 2015;struct matrix//矩阵{    int m[55][55];    matrix(){        memset(m,0,sizeof(m));    }};int n,m;void debug(matrix a){    for(int i = 1; i <= n+1;i++){        for(int j = 1; j <= n+1; j++){            cout<<a.m[i][j];        }        cout<<endl;    }}matrix multi(matrix a, matrix b){    matrix tmp;    for(int i = 0; i < 55; ++i)    {        for(int j = 0; j < 55; ++j)        {            for(int k = 0; k < 55; ++k)                tmp.m[i][j] = (tmp.m[i][j] + a.m[i][k] * b.m[k][j]) % MOD;        }    }    return tmp;}int fast_mod(matrix ans,matrix base, int m)  // 求矩阵 base 的  n 次幂{    while(m)    {        if(m & 1)  //实现 ans *= t; 其中要先把 ans赋值给 tmp,然后用 ans = tmp * t            ans = multi(ans, base);        base = multi(base, base);        m >>= 1;    }    return ans.m[1][n+1];}int main(){    int T;    scanf("%d",&T);    while(T--){        scanf("%d%d",&n,&m);        matrix ans,base;        for(int i = 1; i <= n+1; i++)           base.m[i][n + 1] = 1;        int x,xx;        for(int i = 1; i <= n; i++){            scanf("%d",&x);            for(int j = 1; j <= x; j++){                scanf("%d",&xx);                base.m[i][xx] = 1;            }        }        for(int i = 1; i <= n+1; i++)            ans.m[1][i] = 1;        if(m == 1)            printf("%d\n",n+1);        else            printf("%d\n",fast_mod(ans,base,m));    }    return 0;}


0 0
原创粉丝点击