准备开一个专题,写写pytorch中常用的CUDA算子源码,作为基础储备,索性就从最基础的Reduce操作开始吧!
0x00 前言
在神经网络的计算中,对一组数据进行规约是非常常见的操作,譬如规约求和(ReduceSum)、规约求最大最小值(ReduceMax, RedeceMin)等基础算子。
在CUDA中进行数据规约是每一个CUDA学习者绕不开的训练,在《CUDA C编程权威指南》中,以在全局内存中进行规约作为baseline,作者分别介绍如下优化方法:
- 避免线程束中的分支分化
- 使用交错配对而不是相邻配对
- 使用展开循环
- 使用模板函数(编译时决定代码分支)
- 使用共享内存 + 线程展开
- 使用动态共享内存 + 线程展开
以上优化的具体内容我们这里不表,因为是很基础的操作了,并且在实际项目开发中这些技术也已经过时,不能满足实际需求,感兴趣的可以自行翻阅,只需要理解每一步的优化思路即可,毕竟思路是永不过时的。
这篇博客主要对Pytorch源码中使用的BlockReduceSum CUDA算子进行分析,源码路径如下: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/block_reduce.cuh#L71
0x01 Lane & Warp
BlockReduceSum是针对一个Block中的数据进行规约的,因此我们得先介绍CUDA执行模型中的一个重要概念warp,我们知道当一个block被调度到一个SM上时,block中的所有threads会被划分为一簇线程束(warp),线程束是GPU上一组并行执行的线程,每个线程独立执行指令,但在一个线程束中的线程会以SIMD(单指令多数据)方式并行执行相同的指令,但处理不同的数据。每个warp包含连续的32个threads,因此每个block中的线程束数量可以如下计算:
$$线程束数量 = ceil\begin{pmatrix} \frac {block中的线程数量} {线程束大小}\end{pmatrix}$$
一个block中最多有1024个threads,线程束大小一般都是32,因此一个block中的线程束数量最多也只有32个,如下图纵坐标所示;
lane意为车道,在并发系统中,一般指执行并发任务的单个执行单元,它可以是一个线程、一个进程或其他执行单元的抽象,在CUDA中,通常指的是在一个线程束(warp)中的单个线程,如下图中横坐标表示。
0x02 源码分析
源码:
template <typename T, typename B = Block1D>
__inline__ __device__ T BlockReduceSum(T val, T* shared) {
const int tid = B::Tid();
const int lid = tid % C10_WARP_SIZE;
const int wid = tid / C10_WARP_SIZE;
val = WarpReduceSum(val);
__syncthreads(); // prevent races when BlockReduces are called in a row.
if (lid == 0) {
shared[wid] = val;
}
__syncthreads();
val = (tid < B::Warps()) ? shared[lid] : T(0);
if (wid == 0) {
val = WarpReduceSum(val);
}
return val;
}
首先该函数使用了模板函数,T为数据类型,B为Block线程组织形状,这里默认使用1D的Block单元:
struct Block1D {
static __forceinline__ __device__ int Tid() { return threadIdx.x; }
static __forceinline__ __device__ int Warps() {
return blockDim.x / C10_WARP_SIZE;
}
};
- 回到源码继续分析,函数接受两个参数,val为初值,shared为T类型的指针,可传入共享内存地址;
- 获取block中的线程索引tid = threadidx.x;
- 通过tid对C10_WARP_SIZE取余获取lid;
- 用tid整除C10_WARP_SIZE获得wid;(lid, wid可结合上述示意图理解,应该很好理解)
- 调用WarpReduceSum函数做一次规约,下面我们走进源码继续看。
// Sums `val` across all threads in a warp.
//
// Assumptions:
// - The size of each block should be a multiple of `C10_WARP_SIZE`
template <typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
val += WARP_SHFL_DOWN(val, offset);
}
return val;
}
该函数使用一个for循环,offset初始化为C10_WARP_SIZE » 1,即右移一位,为16,以后每次继续右移一位,接下来循环体这里使用了一个宏函数WARP_SHFL_DOWN.
template <typename T>
__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if !defined(USE_ROCM)
return __shfl_down_sync(mask, value, delta, width);
#else
return __shfl_down(value, delta, width);
#endif
}
__shfl_down_sync(mask, value, delta, width) 是一个CUDA内置函数,属于Warp Shuffle Functions,这一系列函数可以对一个warp内的数据进行交换或者亦或操作,只在CUDA计算能力5.0及以上的卡上才被支持,这样的函数一共有四个,分别是:
T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize);
T __shfl_up_sync(unsigned mask, T var, unsigned int delta, int width=warpSize);
T __shfl_down_sync(unsigned mask, T var, unsigned int delta, int width=warpSize);
T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width=warpSize);
回到循环体中,代码逐步将变量 val 的值从线程束中的高id位向低id位进行归约求和,使用了 __shfl_down_sync 函数来进行数据交换和归约操作。具体步骤如下:
- 初始化变量 offset 为线程束大小的一半(C10_WARP_SIZE » 1)。这是因为归约操作是通过不断减少偏移量来逐步将数据从高位向低位进行求和。
- 进入循环,循环条件是 offset 大于 0。
- 在循环体内,__shfl_down_sync 函数用于将位于 val 的值在线程束内向下偏移 offset 个位置进行交换。该函数的参数解释如下:
mask = 0xffffffff 表示使用全掩码,即所有线程都参与数据交换。
val 是当前线程的值,将被交换到其他线程。
offset 是偏移量,用于指定交换的目标线程。
width 在这里表示线程束的大小。
-
将 __shfl_down_sync 返回的值与 val 相加,得到更新后的 val 值。
-
在每次循环迭代结束后,将 offset 右移一位(offset »= 1),以便下一次迭代时只考虑更低位的线程。
-
通过这个循环,线程束内的每个线程将与其他线程进行数据交换并进行逐步归约求和,直到warp中所有lane的数据都被累加到val中,并且这个值存放在lane0的位置,因为时从高往低规约。
-
接下来继续回到主调用函数中,这里进行了一次 __syncthreads() 进行一次块内同步,这里源码有一句注释,同步是为了防止连续调用 BlockReduceSum 函数时出现竞争条件。试想,当多个线程块按顺序调用 BlockReduceSum 函数时,第一个线程块的归约操作可能还没有完成,第二个线程块就开始进行归约操作,便有可能导致错误的结果。
-
接着将lid == 0 的线程中的数据都复制到shared memery中;
-
进行一次同步,确保shared memery拿到所有需要的线程结果;
-
通过条件语句 (tid < B::Warps()) ? shared[lid] : T(0) ,根据当前线程的ID判断是否属于有效的线程束。如果是有效线程束内的线程,从shared数组中读取归约结果;否则,将val设为类型T的默认值T(0)。
-
如果线程束ID为0(即wid == 0),对val再次进行线程束内的归约求和操作。
-
返回最终的归约结果val。
0x03 参考链接
「真诚赞赏,手留余香」
真诚赞赏,手留余香
使用微信扫描二维码完成支付
