矩阵乘法最佳实现算法 (by CSAPP)


矩阵乘法最佳实现算法 (by CSAPP)

背景

注意看,这个男人叫矩阵乘法定义式

alt text

C[m, n] += A[m, k] * B[k, n]

不难写出一个 python 程序来实现:

1
2
3
4
for m in range(M):
for n in range(N):
for k in range(K):
C[m, n] += A[m, k] * B[k, n]

但是,

这是最快的算法吗?

不是。

推理过程

最基础的算法如上,这里 copy 过来:

1
2
3
4
for m in range(M):
for n in range(N):
for k in range(K):
C[m, n] += A[m, k] * B[k, n]

我们画图思考一下,数据的访问顺序


这是普通人能想到的,最简单算法。

A 的局部性原理利用的很好,
所以很快。

但是 B 的局部性原理排第二,
B 是按照列来访问的。
这很慢。跨越了局部性 cache

所以 B 也应该按照行访问。但是公式是死的。

C[m, n] += A[m, k] * B[k, n]

我们注意到,对于 C 来说,虽然没有使用 k,但它也是满足局部性原理的

我们的最终目的是,优先按行访。问
也就是优先迭代 n, k ,n。
这似乎是矛盾的。

k 是 A 的列索引(2 级索引),B 的行索引(1 级索引)。
不可能最优先访问。
所以就有了第二个想法。
迭代顺序为: m->k->n

1
2
3
4
for m in range(M):
for k in range(K):
for n in range(N):
C[m, n] += A[m, k] * B[k, n]

那么这是不是最快的呢?

alt text

这样我们实现了 BC 的按行访问

那么 A 呢?

由于 m 仍然在 k 外层:
M -> K -> N
所以 _A 也是按行访问的_。

其实所有的迭代顺序种类为 A(3, 3) = 3! = 6

CSAPP 测试了所有访问顺序

由于我们有三个行索引:m, m, k
三个列索引:n, k, n

所以,只要满足:
任取行索引 x,列索引 y,
满足 x 在 y 的外层,
就是最快算法;
反之是最慢算法

重要表格

所以所有算法的排序为

rank 访问顺序 m > n m > k k > n score
1 m->k->n 1 1 1 3
2 k->m->n 1 0 1 2
2 m->n->k 1 1 0 2
3 n->m->k 0 1 0 1
3 n->m->k 0 1 0 1
4 n->k->m 0 0 0 0
m n k
1 级索引 2 级索引 1 级索引+2 级索引

C[m, n] += A[m, k] * B[k, n]

结论

因此,最快矩阵乘法为

1
2
3
4
for m in range(M):
for k in range(K):
for n in range(N):
C[m, n] += A[m, k] * B[k, n]