矩阵乘法最佳实现算法 (by CSAPP)
矩阵乘法最佳实现算法 (by CSAPP)
背景
注意看,这个男人叫矩阵乘法定义式
C[m, n] += A[m, k] * B[k, n]
不难写出一个 python 程序来实现:
1 |
|
但是,
这是最快的算法吗?
不是。
推理过程
最基础的算法如上,这里 copy 过来:
1 |
|
我们画图思考一下,数据的访问顺序
这是普通人能想到的,最简单算法。
A 的局部性原理利用的很好,
所以很快。
但是 B 的局部性原理排第二,
B 是按照列来访问的。
这很慢。跨越了局部性 cache
所以 B 也应该按照行访问。但是公式是死的。
C[m, n] += A[m, k] * B[k, n]
我们注意到,对于 C 来说,虽然没有使用 k,但它也是满足局部性原理的
我们的最终目的是,优先按行访。问
也就是优先迭代 n, k ,n。
这似乎是矛盾的。
k 是 A 的列索引(2 级索引),B 的行索引(1 级索引)。
不可能最优先访问。
所以就有了第二个想法。
迭代顺序为: m->k->n
1 |
|
那么这是不是最快的呢?
这样我们实现了 BC 的按行访问
那么 A 呢?
由于 m 仍然在 k 外层:M -> K -> N
所以 _A 也是按行访问的_。
其实所有的迭代顺序种类为 A(3, 3) = 3! = 6
CSAPP 测试了所有访问顺序
由于我们有三个行索引:m, m, k
三个列索引:n, k, n
所以,只要满足:
任取行索引 x,列索引 y,
满足 x 在 y 的外层,
就是最快算法;
反之是最慢算法
重要表格
所以所有算法的排序为
rank | 访问顺序 | m > n | m > k | k > n | score |
---|---|---|---|---|---|
1 | m->k->n | 1 | 1 | 1 | 3 |
2 | k->m->n | 1 | 0 | 1 | 2 |
2 | m->n->k | 1 | 1 | 0 | 2 |
3 | n->m->k | 0 | 1 | 0 | 1 |
3 | n->m->k | 0 | 1 | 0 | 1 |
4 | n->k->m | 0 | 0 | 0 | 0 |
m | n | k |
---|---|---|
1 级索引 | 2 级索引 | 1 级索引+2 级索引 |
C[m, n] += A[m, k] * B[k, n]
结论
因此,最快矩阵乘法为
1 |
|