CoderXL's Blog

Back

朴素的矩阵乘法是 O(n3)O(n^3) 的. 以 Am×kBk×n=Cm×n\boldsymbol{A}^{m\times k} \boldsymbol{B}^{k\times n} = \boldsymbol{C}^{m\times n} 为例,计算 C\boldsymbol{C} 中的每个元素,都需要做 kk 次乘加运算,而 C\boldsymbol{C} 中有 m×nm\times n 个元素,因此时间复杂度是 O(n3)O(n^3).

在此基础上,想要连续做 pp 次矩阵乘法,复杂度为 O(pn3)O(pn^3). 这也能部分解释为何当下神经网络算力消耗大。

但是,一些特殊的情况,我们会需要计算某个方阵的 pp 次幂,此时参考快速幂的做法(还未学习快速幂的同学可以看看昨天的题解),进行矩阵快速幂,将时间复杂度优化到 O(n3logp)O(n^3 \log p).


昨天 GaP 介绍了快速幂的分治理解,这里补充一个递推理解,其实可以看到二者的本质是相同的:

欲计算 abmodpa^b \mod p,朴素做法是将 aa 连乘 bb 次,复杂度 O(b)O(b).

但如果我们把 bb 按二进制拆位成: b=b0+b1×2+b2×22++bt×2t, bi{0,1}b=b_0 + b_1\times 2 + b_2\times 2^2 + \cdots + b_t\times2^t,\ b_i\in\{0,1\},其中 t=log2bt=\lfloor \log_2 b \rfloor, 那么就有 ab=ab0(a2)b1(a4)b2(a2t)bta^b = a^{b_0} \cdot (a^2)^{b_1} \cdot (a^4)^{b_2} \cdots (a^{2^t})^{b_t}. 考虑递推:每个 a2ia^{2^i} 都能由 (a2i1)2(a^{2^{i-1}})^2 立刻得到。 因此在计算 aba^b 时,只需要 O(t)O(t) 地递推 a2ia^{2^i},然后根据 bib_i 是否等于 11,决定是否将 a2ia^{2^i} 乘到答案里。


代码非常直白,利用了一些重载语法以提高可读性:

#include <bits/stdc++.h>
using namespace std;

using ll = long long;
const int N = 200;
const ll Mod = 1e9+7;
int n;
ll k;

struct Matrix
{
	ll v[N][N];
    int m,n; //m 行 n 列
	inline static ll tmp[N][N]; //整个类共享的临时空间
	ll* operator[](const int&i){return v[i];}
    const ll* operator[](const int&i)const{return v[i];}
	Matrix&operator*=(const Matrix&y)
	{
        if(n!=y.m)throw "Shape mismatch.";
		memset(tmp,0,sizeof(tmp));
	    for(int i=1;i<=m;i++)
	        for(int l=1;l<=n;l++)
	            for(int j=1;j<=y.n;j++) //合适的 for 循环顺序能够提高内存访问效率
	            {
	                tmp[i][j]+=v[i][l]*y[l][j];
	                tmp[i][j]%=Mod;
	            }
        n=y.n;
	    for(int i=1;i<=m;i++)
	    {
	        for(int j=1;j<=n;j++)
	        {
	            v[i][j]=tmp[i][j];
	        }
	    }
		return *this;
	}
}a,ans;

int main()
{
	cin>>n>>k;
	a.m=a.n=ans.m=ans.n=n;
	for(int i=1;i<=n;i++)
    {
        ans[i][i]=1; // 初始化为单位矩阵
        for(int j=1;j<=n;j++)
        {
            cin>>a[i][j];
        }
    }
    while(k)
    {
        if(k&1)
        {
            ans*=a;
        }
        a*=a;
        k>>=1;
    }
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=n;j++)
        {
            cout<<ans[i][j]<<' ';
        }
        cout<<'\n';
    }
	return 0;
}
cpp
P3390 矩阵快速幂
https://blog.leosrealms.top/blog/miscellaneous/daily-luogu/2026-03-30-p3390-matrix-quick-power
Author CoderXL
Published at 2026年3月30日
Comment seems to stuck. Try to refresh?✨