A handy calculator tool (CUDA occupancy calculator):
https://xmartlabs.github.io/cuda-calculator/
What is GeMM?
To be honest, I initially assumed GeMM and MM were the same thing; only after looking closely did I realize there is a difference: the former (GEneral Matrix Multiply) is essentially the computer-systems abstraction of the latter.
Viewed purely mathematically, matrix multiplication has a rich optimization space that is, frankly, hard to follow; viewed from the computer's side, however, there is a lot of room for more efficient computation and data handling. The rest of this post works its way up from CPU hardware architecture to GPU architecture.
Optimization on the CPU
Without further ado: the simplest implementation is the textbook triple loop, but it takes $O(n^3)$ time. The corresponding pseudocode is simple:
for (int m = 0; m < M; m++) {
for (int n = 0; n < N; n++) {
C[m][n] = 0;
for (int k = 0; k < K; k++) {
C[m][n] += A[m][k] * B[k][n];
}
}
}
Optimizations of this algorithm generally fall into two categories:
- Algorithmic optimizations: e.g. Strassen's algorithm and the Coppersmith–Winograd algorithm
- Software-level optimizations: reordering the computation to match the memory hierarchy, typically via loop splitting/tiling, data-layout rearrangement in memory, and so on
I will not say much about the former; that really is the mathematicians' job. Our job is to exploit the architecture to optimize the latter as far as possible.
Exploiting spatial and temporal locality
Honestly, one look at the memory-layout figure makes it clear what is going on: the matrices are stored linearly and accessed through the cache, so each access pulls a whole cache line (here, four consecutive elements) into the cache. We can therefore unroll the loop over n to take advantage of this (how far to unroll really depends on the memory hierarchy):
for (int m = 0; m < M; m++) {
for (int n = 0; n < N; n += 4) {
C[m][n + 0] = 0;
C[m][n + 1] = 0;
C[m][n + 2] = 0;
C[m][n + 3] = 0;
for (int k = 0; k < K; k++) {
C[m][n + 0] += A[m][k] * B[k][n + 0];
C[m][n + 1] += A[m][k] * B[k][n + 1];
C[m][n + 2] += A[m][k] * B[k][n + 2];
C[m][n + 3] += A[m][k] * B[k][n + 3];
}
}
}
Up to this step the transformation is still easy to follow. The next step applies the same idea to the output: rows of C have exactly the same cache behavior, so the m loop gets blocked as well. Here `n + 0..3` is pseudocode shorthand for the four consecutive columns n through n+3, and the code becomes:
for (int m = 0; m < M; m += 4) {
for (int n = 0; n < N; n += 4) {
C[m + 0][n + 0..3] = 0;
C[m + 1][n + 0..3] = 0;
C[m + 2][n + 0..3] = 0;
C[m + 3][n + 0..3] = 0;
for (int k = 0; k < K; k++) {
C[m + 0][n + 0..3] += A[m + 0][k] * B[k][n + 0..3];
C[m + 1][n + 0..3] += A[m + 1][k] * B[k][n + 0..3];
C[m + 2][n + 0..3] += A[m + 2][k] * B[k][n + 0..3];
C[m + 3][n + 0..3] += A[m + 3][k] * B[k][n + 0..3];
}
}
}
Because the inner loop is a reduction into C, the accumulators are fastest when kept in registers, so the work is further decomposed into $4 \times 4$ sub-blocks of C. Pseudocode cannot really express the register part, so picture the overall blocking scheme instead (a concrete C sketch with explicit register accumulators follows the pseudocode below).
In addition, depending on the precision of the matrix elements, the blocking can be tuned further to match the memory-block size (at this point you really have to go down to how the data structures are laid out in memory). The code then becomes:
for (int mo = 0; mo < M; mo += 8) {
for (int no = 0; no < N; no += 8) {
        for (int mi = 0; mi < 2; mi++) {
for (int ni = 0; ni < 2; ni++) {
int m = mo + mi * 4;
int n = no + ni * 4;
                C[m + 0..3][n + 0..3] = 0;
for (int k = 0; k < K; k += 4) {
C[m + 0..3][n + 0..3] += A[m + 0..3][k + 0] * B[k + 0][n + 0..3];
C[m + 0..3][n + 0..3] += A[m + 0..3][k + 1] * B[k + 1][n + 0..3];
C[m + 0..3][n + 0..3] += A[m + 0..3][k + 2] * B[k + 2][n + 0..3];
C[m + 0..3][n + 0..3] += A[m + 0..3][k + 3] * B[k + 3][n + 0..3];
}
}
}
}
}
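As promised above, here is a minimal C sketch of the 4×4 register-blocked micro-kernel (the function name sgemm_4x4 is mine, and it assumes M, N and K are multiples of 4; the explicit local accumulators are what we hope the compiler keeps in registers):
void sgemm_4x4(const float *A, const float *B, float *C,
               const int M, const int N, const int K) {
    for (int m = 0; m < M; m += 4) {
        for (int n = 0; n < N; n += 4) {
            float acc[4][4] = {0};                    // 16 accumulators, ideally held in registers
            for (int k = 0; k < K; k++) {
                for (int i = 0; i < 4; i++) {
                    float a = A[(m + i) * K + k];     // one element of A reused across 4 columns
                    for (int j = 0; j < 4; j++) {
                        acc[i][j] += a * B[k * N + (n + j)];
                    }
                }
            }
            for (int i = 0; i < 4; i++)               // write the 4x4 block back only once
                for (int j = 0; j < 4; j++)
                    C[(m + i) * N + (n + j)] = acc[i][j];
        }
    }
}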
Honestly, that is about it for optimization on the CPU, but the same ideas will shine even brighter on the GPU. In any case, let's first write a reference C implementation of the CPU computation:
#define OFFSET(row, col, ld) ((row) * (ld) + (col))
void cpuSgemm(
float *a, float *b, float *c, const int M, const int N, const int K) {
for (int m = 0; m < M; m++) {
for (int n = 0; n < N; n++) {
float psum = 0.0;
for (int k = 0; k < K; k++) {
psum += a[OFFSET(m, k, K)] * b[OFFSET(k, n, N)];
}
c[OFFSET(m, n, N)] = psum;
}
}
}
On to CUDA
The CPU approach scales only so far; these days CUDA-based parallel acceleration is the more powerful option. Let's start with the simplest possible GPU GeMM:
__global__ void naiveSgemm(
float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
const int M, const int N, const int K) {
int n = blockIdx.x * blockDim.x + threadIdx.x;
int m = blockIdx.y * blockDim.y + threadIdx.y;
if (m < M && n < N) {
float psum = 0.0;
#pragma unroll
for (int k = 0; k < K; k++) {
psum += a[OFFSET(m, k, K)] * b[OFFSET(k, n, N)];
}
c[OFFSET(m, n, N)] = psum;
}
}
const int BM = 32, BN = 32;
const int M = 512, N = 512, K = 512;
dim3 blockDim(BN, BM);
dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);
This is about as simple as it gets: each thread handles one element of the output matrix. But clearly the code above does nothing to exploit the memory hierarchy or the compute structure, so there is a great deal of optimization left to do.
Let's walk through the memory and compute flow of the code above:
- Allocate storage for the three matrices in global memory.
- Every element of C is computed independently, so one thread is assigned to each output element.
- Configure the launch accordingly (honestly, this style of configuration is much easier to understand than the ones I used when first learning, though intuitively the performance should not be great, since consecutive warps do not access adjacent data):
$$gridDim.x \times blockDim.x = N$$
$$gridDim.y \times blockDim.y = M$$
Each thread's workflow is: read a length-K row vector from A, read a length-K column vector from B, accumulate their dot product in a loop, and write the result back to C. Since every thread reads 2K elements and writes one, the total global-memory traffic is substantial:
$$(2 \times K \times M \times N + M \times N) \times 4\,\text{Bytes}$$
Because 32 threads form a warp (depending on the architecture), reads from matrix B can be coalesced so that 32 consecutive columns are fetched at once. Even so, this is still far from good enough.
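For completeness, a minimal host-side driver for the naive kernel might look like the sketch below (sizes are the ones used above; initialization of the input data, error checking and result verification are omitted):
#include <cstdlib>
#include <cuda_runtime.h>

int main() {
    const int M = 512, N = 512, K = 512;
    const size_t size_a = M * K * sizeof(float);
    const size_t size_b = K * N * sizeof(float);
    const size_t size_c = M * N * sizeof(float);

    // host buffers (fill h_a / h_b with test data before launching)
    float *h_a = (float *)malloc(size_a);
    float *h_b = (float *)malloc(size_b);
    float *h_c = (float *)malloc(size_c);

    // device buffers
    float *d_a, *d_b, *d_c;
    cudaMalloc(&d_a, size_a);
    cudaMalloc(&d_b, size_b);
    cudaMalloc(&d_c, size_c);
    cudaMemcpy(d_a, h_a, size_a, cudaMemcpyHostToDevice);
    cudaMemcpy(d_b, h_b, size_b, cudaMemcpyHostToDevice);

    // one thread per element of C, 32 x 32 threads per block
    const int BM = 32, BN = 32;
    dim3 blockDim(BN, BM);
    dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);
    naiveSgemm<<<gridDim, blockDim>>>(d_a, d_b, d_c, M, N, K);

    cudaMemcpy(h_c, d_c, size_c, cudaMemcpyDeviceToHost);

    cudaFree(d_a); cudaFree(d_b); cudaFree(d_c);
    free(h_a); free(h_b); free(h_c);
    return 0;
}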
Optimization step 1: shared memory
Each multiply-accumulate currently requires two loads from global memory, and that access pattern makes performance quite poor. The fix is to stage data in shared memory.
First, partition C into tiles of size $BM \times BN$; each tile is computed by one thread block, and within the block each thread computes a $TM \times TN$ sub-tile. The operands can then be read from shared memory (all threads of a block share it), which is far cheaper than going to global memory every time.
Next comes an analysis that is a bit involved but very important.
After tiling, each tile of C requires:
- Compute: $BM \times BN \times K \times 2$ FLOPs
- Global-memory traffic: $(BM + BN) \times K \times 4$ Bytes
Taking the ratio of the two gives the compute-to-memory ratio
$$\frac{2\,BM \cdot BN \cdot K}{4\,(BM+BN)\,K} = \frac{BM\cdot BN}{2(BM+BN)} = \frac{1}{2\left(\frac{1}{BM}+\frac{1}{BN}\right)}$$
Clearly, the larger BM and BN are, the higher the compute-to-memory ratio and the better the performance. But as we already learned in the basics, various constraints keep these numbers from growing without bound.
The first constraint is shared-memory capacity: on a V100, each SM has only 96 KB of shared memory, while one block's tiles occupy $BK \times (BM+BN) \times 4$ Bytes of it.
Second, TM and TN are constrained too. Each architecture caps the number of threads per block (on a V100 a block may not exceed 1024 threads), and making blocks too small hurts the parallelism between blocks on an SM. Registers are also limited: a thread needs $TM \times TN$ registers just for its accumulators, and since the per-thread register budget is capped (at most 255 registers per thread), pushing TM and TN too far also hurts parallelism.
The analysis above is somewhat tedious, but it is very effective and very useful.
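As a sanity check, plugging the configuration chosen below ($BM=BN=128$, $BK=8$, $TM=TN=8$) into these constraints:
- Compute-to-memory ratio: $\frac{128 \cdot 128}{2(128+128)} = 32$ FLOPs per Byte of global-memory traffic
- Shared memory per block: $8 \times (128+128) \times 4\,\text{B} = 8\,\text{KB}$ (16 KB once double buffering is added later), comfortably under 96 KB
- Threads per block: $(128/8)\times(128/8) = 256 \le 1024$
- Accumulator registers per thread: $8 \times 8 = 64$, well under the 255-register cap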
Based on these considerations, choose $BM=BN=128$, $BK=8$, $TM=TN=8$. The code then looks like this (best read together with the tiling figure):
#define FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])
/*
 * FLOAT4(pointer) casts the address of `pointer` to float4* and dereferences it,
 * i.e. it yields a reference to the 4 consecutive floats (16 bytes) starting at
 * `pointer`, so a single instruction can load or store them together.
 */
__global__ void sgemm_V1(
float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
const int M, const int N, const int K) {
const int BM = 128;
const int BN = 128;
const int BK = 8;
const int TM = 8;
const int TN = 8;
const int bx = blockIdx.x;
const int by = blockIdx.y;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int tid = ty * blockDim.x + tx;
__shared__ float s_a[BM][BK];
__shared__ float s_b[BK][BN];
float r_c[TM][TN] = {0.0};
int load_a_smem_m = tid >> 1; // tid/2, row of s_a
int load_a_smem_k = (tid & 1) << 2; // (tid % 2 == 0) ? 0 : 4, col of s_a
int load_b_smem_k = tid >> 5; // tid/32, row of s_b
int load_b_smem_n = (tid & 31) << 2; // (tid % 32) * 4, col of s_b
int load_a_gmem_m = by * BM + load_a_smem_m; // global row of a
int load_b_gmem_n = bx * BN + load_b_smem_n; // global col of b
for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {
int load_a_gmem_k = bk * BK + load_a_smem_k; // global col of a
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
FLOAT4(s_a[load_a_smem_m][load_a_smem_k]) = FLOAT4(a[load_a_gmem_addr]);
int load_b_gmem_k = bk * BK + load_b_smem_k; // global row of b
int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(b[load_b_gmem_addr]);
__syncthreads();
#pragma unroll
for (int k = 0; k < BK; k++) {
#pragma unroll
for (int m = 0; m < TM; m++) {
#pragma unroll
for (int n = 0; n < TN; n++) {
int comp_a_smem_m = ty * TM + m;
int comp_b_smem_n = tx * TN + n;
r_c[m][n] += s_a[comp_a_smem_m][k] * s_b[k][comp_b_smem_n];
}
}
}
__syncthreads();
}
#pragma unroll
for (int i = 0; i < TM; i++) {
int store_c_gmem_m = by * BM + ty * TM + i;
#pragma unroll
for (int j = 0; j < TN; j += 4) {
int store_c_gmem_n = bx * BN + tx * TN + j;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][j]);
}
}
}
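For reference, a sketch of the launch configuration this kernel assumes (not shown in the kernel itself): each thread owns a $TM \times TN$ sub-tile, so a block needs $(BN/TN)\times(BM/TM)$ threads, and since there are no bounds checks, M, N and K are assumed to be multiples of the tile sizes. The device pointers d_a, d_b, d_c stand for buffers allocated as in the earlier host sketch:
const int BM = 128, BN = 128, TM = 8, TN = 8;
dim3 blockDim(BN / TN, BM / TM);                     // 16 x 16 = 256 threads per block
dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);  // one block per 128 x 128 tile of C
sgemm_V1<<<gridDim, blockDim>>>(d_a, d_b, d_c, M, N, K);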
Honestly this looks a bit messier than the earlier tutorial code, but that does not stop us from analyzing it. The pipeline is:
- Load the tiles $A_{[BM, BK]}$ and $B_{[BK, BN]}$ into shared memory.
This part is genuinely fiddly. For the A tile, each thread has to move $\frac{BK \cdot TM \cdot TN}{BN}$ elements, which for this configuration is 4 floats; that fits CUDA's float4 vector type exactly (4 floats × 4 B = 16 B, a nicely aligned access). With these tile sizes, the index mapping is the one shown on the left of the figure. The data then has to be staged into shared memory, which is the indexing-and-copy step involving the s_a and s_b buffers: the whole point is to keep part of the data in shared memory so that global-memory accesses are reduced. Under this scheme, load_a_smem_m = tid / 2 = tid >> 1 is the row of s_a a thread writes to, and load_a_smem_k = (tid % 2 == 0) ? 0 : 4 = (tid & 1) << 2 is its starting column; together they form the thread's index into shared memory. The same reasoning gives matrix B's layout: load_b_smem_k = tid >> 5 and load_b_smem_n = (tid & 31) << 2.
All of the above only describes a single block. Once multiple blocks index their own tiles, the correspondence to global memory changes as well. Taking matrix A as the example again, the tile $A_{[BM,BK]}$ slides along the rows, so the row index is determined first: from the 2-D grid indexing, the tile's starting row is by * BM, so the global row is load_a_gmem_m = by * BM + load_a_smem_m. The column is different: because the tile slides in the row direction, the column changes every iteration and must be computed inside the loop. The tile's starting column bk * BK plus the in-tile column load_a_smem_k gives load_a_gmem_k = bk * BK + load_a_smem_k, which pins down the tile's location in the original data as OFFSET(load_a_gmem_m, load_a_gmem_k, K). (A worked example of this thread-to-index mapping follows the list.)
- Compute the per-thread tile $C_{[TM, TN]}$: once s_a and s_b are in place, accumulate the corresponding r_c and then write it out to global memory. This, of course, is another round of fiddly index transformations.
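To make that mapping concrete, here is what the expressions evaluate to for a few of the 256 threads in a block (tid = ty * 16 + tx, so for example tid 37 means ty = 2, tx = 5):
tid =   0: load_a_smem_m =   0, load_a_smem_k = 0; load_b_smem_k = 0, load_b_smem_n =   0
tid =   1: load_a_smem_m =   0, load_a_smem_k = 4; load_b_smem_k = 0, load_b_smem_n =   4
tid =  37: load_a_smem_m =  18, load_a_smem_k = 4; load_b_smem_k = 1, load_b_smem_n =  20
tid = 255: load_a_smem_m = 127, load_a_smem_k = 4; load_b_smem_k = 7, load_b_smem_n = 124
Each thread writes one float4 (4 floats) into s_a and one into s_b, so the 256 threads jointly cover exactly the 128 × 8 A tile and the 8 × 128 B tile.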
Optimization step 2: resolving bank conflicts
The previous step greatly improved memory-access efficiency and hence performance; the next step is to keep optimizing how shared memory is used.
This optimization left quite an impression on me. Shared memory is divided into 32 banks, each 4 B wide; when threads of a warp hit different addresses in the same bank in one access, you get a bank conflict. The fix I had seen before is to offset (pad) the layout.
First, let's look at the bank conflicts the kernel above actually causes:
- Reading the A tile requires a column vector, but A is stored row-major in shared memory, so the accesses conflict.
- Moreover, with TM = TN = 8 a thread needs 8 consecutive floats from shared memory; one instruction loads only four, so it takes two load instructions. The two loads of one thread target consecutive addresses, while across the threads of a warp the addresses of the same load instruction are strided apart, so several threads end up hitting the same bank in a single access, again producing bank conflicts.
So two changes are made:
- When storing A into shared memory, transpose it and store it column-wise, so that reading a column of A becomes a contiguous access.
- Split the TM × TN tile each thread computes into two halves; since one instruction performs one load from the A tile, the two loads can then be issued without conflicting.
(Although not used here, we know from earlier exercises that bank conflicts can also be resolved by padding/offsetting the layout.)
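Just for reference, the padding alternative mentioned in that parenthesis would look roughly like the sketch below. It is not what the kernel here does, partly because an odd pad would break the 16-byte alignment the FLOAT4 stores rely on:
__shared__ float s_a[BM][BK + 1];  // hypothetical padded layout: the extra column shifts each row,
                                   // so threads reading down a column no longer hit the same bank
The kernel below instead takes the transpose-and-split route: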
__global__ void sgemm_V2(
float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
const int M, const int N, const int K) {
const int BM = 128;
const int BN = 128;
const int BK = 8;
const int TM = 8;
const int TN = 8;
const int bx = blockIdx.x;
const int by = blockIdx.y;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int tid = ty * blockDim.x + tx;
    __shared__ float s_a[BK][BM];  // note: the A tile is now stored transposed, s_a[k][m]
    __shared__ float s_b[BK][BN];  // the B tile keeps its original layout, s_b[k][n]
float r_load_a[4];
float r_load_b[4];
float r_comp_a[TM];
float r_comp_b[TN];
float r_c[TM][TN] = {0.0};
int load_a_smem_m = tid >> 1;
int load_a_smem_k = (tid & 1) << 2;
int load_b_smem_k = tid >> 5;
int load_b_smem_n = (tid & 31) << 2;
int load_a_gmem_m = by * BM + load_a_smem_m;
int load_b_gmem_n = bx * BN + load_b_smem_n;
for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {
int load_a_gmem_k = bk * BK + load_a_smem_k;
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
int load_b_gmem_k = bk * BK + load_b_smem_k;
int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
s_a[load_a_smem_k ][load_a_smem_m] = r_load_a[0];
s_a[load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
s_a[load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
s_a[load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
__syncthreads();
#pragma unroll
for (int tk = 0; tk < BK; tk++) {
            // each thread's 8 values are split into two float4 loads, half a tile apart
            FLOAT4(r_comp_a[0]) = FLOAT4(s_a[tk][ty * TM / 2         ]);
            FLOAT4(r_comp_a[4]) = FLOAT4(s_a[tk][ty * TM / 2 + BM / 2]);
            FLOAT4(r_comp_b[0]) = FLOAT4(s_b[tk][tx * TN / 2         ]);
            FLOAT4(r_comp_b[4]) = FLOAT4(s_b[tk][tx * TN / 2 + BN / 2]);
#pragma unroll
for (int tm = 0; tm < TM; tm++) {
#pragma unroll
for (int tn = 0; tn < TN; tn++) {
r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
}
}
}
__syncthreads();
}
#pragma unroll
for (int i = 0; i < TM / 2; i++) {
int store_c_gmem_m = by * BM + ty * TM / 2 + i;
int store_c_gmem_n = bx * BN + tx * TN / 2;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]);
FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]);
}
#pragma unroll
for (int i = 0; i < TM / 2; i++) {
int store_c_gmem_m = by * BM + BM / 2 + ty * TM / 2 + i;
int store_c_gmem_n = bx * BN + tx * TN / 2;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]);
FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]);
}
}
Optimization step 3: pipelining with double buffering
Adding a second buffer turns the whole process into a pipeline: loading the next tile overlaps with computing on the current one, which cuts down waiting and improves efficiency.
From the code alone it is not obvious at first where the change is:
__global__ void sgemm_V3(
float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
const int M, const int N, const int K) {
const int BM = 128;
const int BN = 128;
const int BK = 8;
const int TM = 8;
const int TN = 8;
const int bx = blockIdx.x;
const int by = blockIdx.y;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int tid = ty * blockDim.x + tx;
__shared__ float s_a[2][BK][BM];
__shared__ float s_b[2][BK][BN];
float r_load_a[4];
float r_load_b[4];
float r_comp_a[TM];
float r_comp_b[TN];
float r_c[TM][TN] = {0.0};
int load_a_smem_m = tid >> 1;
int load_a_smem_k = (tid & 1) << 2;
int load_b_smem_k = tid >> 5;
int load_b_smem_n = (tid & 31) << 2;
int load_a_gmem_m = by * BM + load_a_smem_m;
int load_b_gmem_n = bx * BN + load_b_smem_n;
{
int load_a_gmem_k = load_a_smem_k;
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
int load_b_gmem_k = load_b_smem_k;
int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
s_a[0][load_a_smem_k ][load_a_smem_m] = r_load_a[0];
s_a[0][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
s_a[0][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
s_a[0][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
FLOAT4(s_b[0][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
    }
    __syncthreads();  // make sure buffer 0 is fully populated before the first compute iteration
for (int bk = 1; bk < (K + BK - 1) / BK; bk++) {
        int smem_sel = (bk - 1) & 1;   // buffer we compute from in this iteration
        int smem_sel_next = bk & 1;    // buffer the next tile is being written into
int load_a_gmem_k = bk * BK + load_a_smem_k;
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
int load_b_gmem_k = bk * BK + load_b_smem_k;
int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
#pragma unroll
for (int tk = 0; tk < BK; tk++) {
FLOAT4(r_comp_a[0]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2 ]);
FLOAT4(r_comp_a[4]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2 + BM / 2]);
FLOAT4(r_comp_b[0]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2 ]);
FLOAT4(r_comp_b[4]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2 + BN / 2]);
#pragma unroll
for (int tm = 0; tm < TM; tm++) {
#pragma unroll
for (int tn = 0; tn < TN; tn++) {
r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
}
}
}
s_a[smem_sel_next][load_a_smem_k ][load_a_smem_m] = r_load_a[0];
s_a[smem_sel_next][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
s_a[smem_sel_next][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
s_a[smem_sel_next][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
FLOAT4(s_b[smem_sel_next][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
__syncthreads();
}
#pragma unroll
for (int tk = 0; tk < BK; tk++) {
FLOAT4(r_comp_a[0]) = FLOAT4(s_a[1][tk][ty * TM / 2 ]);
FLOAT4(r_comp_a[4]) = FLOAT4(s_a[1][tk][ty * TM / 2 + BM / 2]);
FLOAT4(r_comp_b[0]) = FLOAT4(s_b[1][tk][tx * TN / 2 ]);
FLOAT4(r_comp_b[4]) = FLOAT4(s_b[1][tk][tx * TN / 2 + BN / 2]);
#pragma unroll
for (int tm = 0; tm < TM; tm++) {
#pragma unroll
for (int tn = 0; tn < TN; tn++) {
r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
}
}
}
#pragma unroll
for (int i = 0; i < TM / 2; i++) {
int store_c_gmem_m = by * BM + ty * TM / 2 + i;
int store_c_gmem_n = bx * BN + tx * TN / 2;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]);
FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]);
}
#pragma unroll
for (int i = 0; i < TM / 2; i++) {
int store_c_gmem_m = by * BM + BM / 2 + ty * TM / 2 + i;
int store_c_gmem_n = bx * BN + tx * TN / 2;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]);
FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]);
}
}
The core of the change is actually this part:
__shared__ float s_a[2][BK][BM];
__shared__ float s_b[2][BK][BN];
float r_load_a[4];
float r_load_b[4];
float r_comp_a[TM];
float r_comp_b[TN];
float r_c[TM][TN] = {0.0};
int load_a_smem_m = tid >> 1;
int load_a_smem_k = (tid & 1) << 2;
int load_b_smem_k = tid >> 5;
int load_b_smem_n = (tid & 31) << 2;
int load_a_gmem_m = by * BM + load_a_smem_m;
int load_b_gmem_n = bx * BN + load_b_smem_n;
{
int load_a_gmem_k = load_a_smem_k;
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
int load_b_gmem_k = load_b_smem_k;
int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
s_a[0][load_a_smem_k ][load_a_smem_m] = r_load_a[0];
s_a[0][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
s_a[0][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
s_a[0][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
FLOAT4(s_b[0][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
    }
    __syncthreads();  // make sure buffer 0 is fully populated before the first compute iteration
for (int bk = 1; bk < (K + BK - 1) / BK; bk++) {
        int smem_sel = (bk - 1) & 1;   // buffer we compute from in this iteration
        int smem_sel_next = bk & 1;    // buffer the next tile is being written into
int load_a_gmem_k = bk * BK + load_a_smem_k;
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
int load_b_gmem_k = bk * BK + load_b_smem_k;
int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
#pragma unroll
for (int tk = 0; tk < BK; tk++) {
FLOAT4(r_comp_a[0]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2 ]);
FLOAT4(r_comp_a[4]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2 + BM / 2]);
FLOAT4(r_comp_b[0]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2 ]);
FLOAT4(r_comp_b[4]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2 + BN / 2]);
#pragma unroll
for (int tm = 0; tm < TM; tm++) {
#pragma unroll
for (int tn = 0; tn < TN; tn++) {
r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
}
}
}
s_a[smem_sel_next][load_a_smem_k ][load_a_smem_m] = r_load_a[0];
s_a[smem_sel_next][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
s_a[smem_sel_next][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
s_a[smem_sel_next][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
FLOAT4(s_b[smem_sel_next][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
__syncthreads();
}
A separate scope is deliberately carved out with those braces. It looks like nothing more than peeling the first iteration out of the loop, but the change is substantial. If everything sat in one loop body, the instructions could not be treated as independent of one another and threads would end up waiting; once a shared-memory tile has been consumed it can be discarded, and by splitting out this prologue, each iteration first issues the global loads for the next tile into registers, computes on the current buffer independently of those loads, and only afterwards writes the registers into the other buffer, followed by a single __syncthreads(). The loads and the math overlap, and that is where the performance gain comes from.
With all of the above, the kernel can genuinely reach roughly cuBLAS-level performance. (Worth revisiting from time to time.)
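For reference, the cuBLAS baseline such kernels are usually compared against can be invoked roughly as in the sketch below (d_a, d_b, d_c are the same hypothetical device buffers as before; cuBLAS is column-major, so the row-major product C = A·B is requested as C^T = B^T·A^T, and error checking is omitted):
#include <cublas_v2.h>

cublasHandle_t handle;
cublasCreate(&handle);

const float alpha = 1.0f, beta = 0.0f;
// row-major C[M][N] = A[M][K] * B[K][N], expressed in cuBLAS's column-major convention
cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            N, M, K,
            &alpha, d_b, N, d_a, K,
            &beta, d_c, N);

cublasDestroy(handle);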