跳转至

cublas中矩阵乘及其广播机制的实现与单元测试

一、准备工作

要想测试cuda函数,如果采用原生的验证方式就需要自己在cpu上将在gpu上的kernel全都实现一遍,这个工作量是很大的。此外对于显存的管理和cublas参数的配置也较为复杂,对这些过程简化的方式就是封装,对显存和内存的管理进行封装。为了解决验证的问题,还需要将运行在kernel的数据导出,采用python的numpy库进行验证。面对这样的问题,本项工作将cnpy封装进用于管理显存和内存的Tensor类中,Tensor类的原版参考TensorRT-Pro,我手动增加了一些类型支持和cnpy的导出功能。 验证流程如下:

  1. cpp/cu实现cuda/cublas函数,并将输入输出的tensor导出为npy/npz/bin
  2. 采用python中的numpy读取保存的二进制文件,并将cuda函数实现的功能采用寥寥几行python代码实现一边,将python计算的结果与cuda结果比对。

本项目全部代码开源在: https://github.com/thb1314/cublas_matmul

二、python端验证代码

这里以验证矩阵elementwise加法为例。

import numpy as np
def _load_tensor(file):
    with open(file, "rb") as f:
        binary_data = f.read()
    magic_number, ndims, dtype = np.frombuffer(binary_data, np.uint32, count=3, offset=0)
    assert magic_number == 0xFCCFE2E2, f"{file} not a tensor file."
    dims = np.frombuffer(binary_data, np.uint32, count=ndims, offset=3 * 4)
    if dtype == 0:
        np_dtype = np.float32
    elif dtype == 1:
        np_dtype = np.float16
    else:
        assert False, f"Unsupport dtype = {dtype}, can not convert to numpy dtype"
    return np.frombuffer(binary_data, np_dtype, offset=(ndims + 3) * 4).reshape(*dims)
def load_tensor(file):
    if file.endswith("npz"):
        return np.load(file)['data']
    elif file.endswith("npy"):
        return np.load(file)
    else:
        return _load_tensor(file)
def test():
    p_tensor = load_tensor('p_tensor.npz')
    q_tensor = load_tensor('q_tensor.npz')
    out_tensor = load_tensor('out_tensor.npz')
    out = p_tensor + q_tensor
    print(np.abs(out - out_tensor).max())

if __name__ == "__main__":
    test()

三、矩阵乘法的实现

3.1 普通矩阵乘

  1. 行主维与列主维

在C/C++中数据的存储是按照线性的,是按照行优先存储的。比如对于一个两行三列的二维矩阵[[1,2,3],[4,5,6]],那么按照行优先存储在内存中为[1,2,3,4,5,6]。但是在一些语言(比如Fortran)和cublas中是按照列优先存储来读取数据的。同样是表示[[1,2,3],[4,5,6]],按照列优先存储在内存中应该为[1,4,2,5,3,6]。如果是[1,2,3,4,5,6]按照列优先读取为两行三列的二维矩阵的话,应该表示为[[1,3,5],[2,4,6]],这显然与原来的矩阵对不上。但是,如果按照三行两列的二维矩阵来读取,可以表示为[[1,4],[2,5],[3,6]],会发现按照行优先存储的矩阵按照列优先方式读取为其转置形状的话,可以理解为其在行优先存储形式的矩阵的转置。即我们在申请一块显存按照行优先来存储矩阵,那么在cublas使用时要按照其转置的形状来理解。

  1. 矩阵乘的转换

给定C = A @ B,那么C.T = B.T @ A.TC.T是相对于cublas来说的,其内存分布按照行优先存储其实为C,所以我们在调用cublas函数时,需要将所有矩阵按照转置维度来理解,并且需要注意第一个矩阵的参数为B.T,而不是A了。这样就引出了第一个cublas的矩阵乘法函数。

cublasStatus_t cublasGemmEx(cublasHandle_t handle,
                           cublasOperation_t transa,
                           cublasOperation_t transb,
                           int m,
                           int n,
                           int k,
                           const void    *alpha,
                           const void     *A,
                           cudaDataType_t Atype,
                           int lda,
                           const void     *B,
                           cudaDataType_t Btype,
                           int ldb,
                           const void    *beta,
                           void           *C,
                           cudaDataType_t Ctype,
                           int ldc,
                           cublasComputeType_t computeType,
                           cublasGemmAlgo_t algo);
这里我们计算Q @ P,其中Q的形状为m x kP的形状为k x n。下面给出关键代码
float alpha = 1.0;
float beta = 0.0;
// the LDA is used to define the distance in memory between elements of two consecutive columns which have the same row index
// B [k,n] B.T [n,k]
// A [m,k] A.T [k,m]
// C [m,n] C.T [n,m]
cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
cudaDataType_t qType = CUDA_R_32F;
cudaDataType_t pType = CUDA_R_32F;
cudaDataType_t oType = CUDA_R_32F;
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
cublasGemmEx(mCublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha,
            pptr_gpu, pType, n,
            qptr_gpu, qType, k,
            &beta,
            outptr_gpu, oType, n,
            computeType, algo
);
详情可以参考git链接中的02cublas_test_matmul

3.2 batched 矩阵乘

batched矩阵乘法git中给出两种api,第一个是cublasGemmBatchedEx,第二个是cublasGemmStridedBatchedExcublasGemmBatchedEx需要将一个批次中的单个数据理解为一个指针,整个batch存储为一个指针数据。 cublasGemmStridedBatchedEx是将一个批次的数据理解为一块连续内存,单个矩阵数据之间相隔stride的距离。 理解了上面的这些知识点和列主维存储,对于A[b,m,k] @ B[b,k,n]的实现就相对简单了。 下面给出关键代码

1
2
3
4
5
6
7
8
CUBLASASSERT(cublasGemmBatchedEx(mCublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k,
                                &alpha,
                                (void**)pptr_gpu_arr_tensor.to_gpu(true).gpu<float*>(), pType, n,
                                (void**)qptr_gpu_arr_tensor.to_gpu(true).gpu<float*>(), qType, k,
                                &beta,
                                (void**)outptr_gpu_arr_tensor.to_gpu().gpu<float*>(), oType, n,
                                b, computeType, algo
                                ));
1
2
3
4
5
6
7
8
CUBLASASSERT(cublasGemmStridedBatchedEx(mCublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k,
                                        &alpha,
                                        pptr_gpu, pType, n, p_tensor.size(1) * p_tensor.size(2),
                                        qptr_gpu, qType, k, q_tensor.size(1) * q_tensor.size(2),
                                        &beta,
                                        outptr_gpu1, oType, n, out_tensor1.size(1) * out_tensor1.size(2),
                                        b, computeType, algo
                                        ));

3.3 广播的矩阵乘

  1. 较为简单的形式

对于 A[b,m,k] @ B[1,k,n],在采用cublasGemmStridedBatchedEx API时我们只需要将B的stride设置为0,这样就可以实现b个batch读取的都是相同的B。 关键代码如下

1
2
3
4
5
6
7
8
9
// Q[0:1] @ P    Q[0:1] 会广播到 P 的 batch
    CUBLASASSERT(cublasGemmStridedBatchedEx(mCublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k,
                                            &alpha,
                                            pptr_gpu, pType, n, p_tensor.size(1) * p_tensor.size(2),
                                            qptr_gpu, qType, k, 0,
                                            &beta,
                                            outptr_gpu2, oType, n, out_tensor2.size(1) * out_tensor2.size(2),
                                            b, computeType, algo
                                            ));

  1. 较为复杂的形式

对于A[1,s,m,k] @ B[b,s,k,n],numpy的做法时将B广播为B[b,s,k,n],那么该如何模拟这种行为呢?实际上还是拆解为已知的方法。 拆解方法如下

1
2
3
4
C[0, s, :m, :n] = A[0, s, :m, :k] @ B[0, s, :k, :n]
C[1, s, :m, :n] = A[0, s, :m, :k] @ B[1, s, :k, :n]
C[2, s, :m, :n] = A[0, s, :m, :k] @ B[2, s, :k, :n]
...
这样我们需要循环b次来完成上面的广播,并且batch参数设置为s,A的stride为mk。B的stride为kn,B的地址需要动态计算。C的stride为m*n,C的地址在每次循环时也需要动态计算。 相关代码实现如下
// 代码中Q=A,P=B
for(int i = 0; i < b; ++i) {
        CUBLASASSERT(cublasGemmStridedBatchedEx(mCublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k,
        &alpha,
        pptr_gpu + i * p_tensor.size(1) * p_tensor.size(2) * p_tensor.size(3), pType, n, p_tensor.size(2) * p_tensor.size(3),
        qptr_gpu, qType, k, q_tensor.size(2) * q_tensor.size(3),
        &beta,
        outptr_gpu2 + i * out_tensor2.size(1) * out_tensor2.size(2) * out_tensor2.size(3), oType, n, out_tensor2.size(2) * out_tensor2.size(3),
        s, computeType, algo
        ));
    }
下面看另外一种广播形式的实现, A[b,1,m,k] @ B[b,s,k,n],照旧还是需要转换为已有带stride形式。
1
2
3
4
C[b, 0, :m, :n] = A[b, 0, :m, :k] @ B[b, 0, :k, :n]
C[b, 1, :m, :n] = A[b, 0, :m, :k] @ B[b, 1, :k, :n]
C[b, 2, :m, :n] = A[b, 0, :m, :k] @ B[b, 2, :k, :n]
...
相应的stride和地址都需要做更改,这里A的stride为s x m x k的原因是想偷个懒,采用原来的数据做一些切片。 B的stride为s x k x n,C的stride为s x m x n。 关键代码如下:
for(int i = 0; i < s; ++i) {
        CUBLASASSERT(cublasGemmStridedBatchedEx(mCublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k,
        &alpha,
        pptr_gpu + (0 * p_tensor.size(1) +  i) * p_tensor.size(2) * p_tensor.size(3), pType, n, p_tensor.size(1) * p_tensor.size(2) * p_tensor.size(3),
        qptr_gpu, qType, k, q_tensor.size(1) * q_tensor.size(2) * q_tensor.size(3),
        &beta,
        outptr_gpu3 + (0 * out_tensor3.size(1) +  i) * out_tensor3.size(2) * out_tensor3.size(3), oType, n, out_tensor3.size(1) * out_tensor3.size(2) * out_tensor3.size(3),
        b, computeType, algo
        ));
}

四、总结

本文逐步给出cublas中矩阵乘法及其广播形式的实现,并给出一种验证cuda函数的较为通用的框架,希望能给读者带来一些启发。


最后更新: March 21, 2024
创建日期: March 21, 2024