@@ -1,7 +1,8 @@
 #include "../../utilities/cuda/cublaslt_utils.cuh"
 #include "cuda_kernel.hh"
 #include "hardware/functions.h"
-#include "kernel/cuda/reduce.cuh"
+#include "kernel/cuda/functions.cuh"
+#include <cub/block/block_reduce.cuh>
 
 namespace refactor::kernel {
     using K = AttentionCuda;
@@ -27,7 +28,7 @@ namespace refactor::kernel {
 
     // gridDim.x = batch * nHead
     // gridDim.y = seqLen
-    // blockDim.x = min(1024, attLen)
+    // blockDim.x = 1024
     // sizeof(shared) = attLen * sizeof(float)
     template<class T, class Mask>
     static __global__ void softmax(
@@ -36,25 +37,34 @@ namespace refactor::kernel {
         uint32_t attLen,
         uint32_t bufLen) {
         // locate the attention region this thread block works on
-        att += (blockIdx.x * gridDim.x + gridDim.y) * bufLen;
+        att += (blockIdx.x * gridDim.y + blockIdx.y) * bufLen;
         // load the input into shared memory, casting and applying the mask
         extern __shared__ float shared[];// size = attLen = pastSeqLen + seqLen
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -__FLT_MAX__;
         }
 
+        using BlockReduce = cub::BlockReduce<float, 1024>;
+        __shared__ typename BlockReduce::TempStorage tempStorage;
+        __shared__ float sharedMax, sharedSum;
+
         float localMax = -1e20;
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             localMax = cub::Max()(localMax, shared[i]);
         }
-        localMax = cuda::blockReduce(localMax, -1e20f, cub::Max());
+        localMax = BlockReduce(tempStorage).Reduce(localMax, cub::Max(), attLen);
+        if (threadIdx.x == 0) { sharedMax = localMax; }
+        __syncthreads();
 
         float localSum = 1e-20;
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
-            localSum += shared[i] = expf(shared[i] - localMax);
+            localSum += shared[i] = expf(shared[i] - sharedMax);
         }
-        localSum = cuda::blockReduce(localSum, 1e-20f, cub::Sum());
-        auto reciprocal = fdividef(1, localSum);
+        localSum = BlockReduce(tempStorage).Reduce(localSum, cub::Sum(), attLen);
+        if (threadIdx.x == 0) { sharedSum = localSum; }
+        __syncthreads();
+
+        auto reciprocal = fdividef(1, sharedSum);
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             att[i] = shared[i] * reciprocal;
         }
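
For context, the updated header comments pin the launch shape: one block per (batch·head, query row), a fixed 1024-thread block, and `attLen * sizeof(float)` bytes of dynamic shared memory. A minimal host-side sketch of a matching launch follows; the kernel's full parameter list is not visible in this hunk, so the call signature and the names `batch`, `nHead`, `seqLen`, `mask`, and `stream` are assumptions, not code from this commit.

```cpp
// Hypothetical launch site (assumed signature: softmax(T *att, Mask mask,
// uint32_t attLen, uint32_t bufLen)); not part of this diff.
dim3 grid(batch * nHead, seqLen);// gridDim.x = batch * nHead, gridDim.y = seqLen
softmax<float, Mask><<<grid, 1024, attLen * sizeof(float), stream>>>(
    att, mask, attLen, bufLen);
```

Fixing `blockDim.x` at 1024 instead of `min(1024, attLen)` is what makes `cub::BlockReduce<float, 1024>` legal: its block-size template parameter must equal the actual block size, and the `attLen` argument passed to `Reduce` covers the short-sequence case where fewer than 1024 threads hold valid partials.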
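The store from thread 0 followed by `__syncthreads()` after each reduction is not optional: `cub::BlockReduce::Reduce` returns the block-wide aggregate only on thread 0, and every thread needs the max (and later the sum) to continue. A self-contained sketch of this reduce-then-broadcast pattern, with hypothetical names (`blockMaxDemo`, `blockMax`), illustrative rather than code from this commit:

```cpp
#include <cfloat>
#include <cub/block/block_reduce.cuh>

// Block-wide max over a strided range, broadcast to all threads.
__global__ void blockMaxDemo(float const *in, float *out, int n) {
    using BlockReduce = cub::BlockReduce<float, 1024>;
    __shared__ typename BlockReduce::TempStorage tempStorage;
    __shared__ float blockMax;// broadcast slot for the aggregate

    // each thread folds its strided slice into a local partial
    float local = -FLT_MAX;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        local = fmaxf(local, in[i]);
    }
    // result is valid on thread 0 only; `n` marks how many lanes are
    // valid when n < 1024 (threads past n never loaded an element)
    local = BlockReduce(tempStorage).Reduce(local, cub::Max(), n);
    if (threadIdx.x == 0) { blockMax = local; }
    __syncthreads();// now every thread can read blockMax

    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        out[i] = in[i] - blockMax;
    }
}
```

The same `TempStorage` can be reused for the second reduction in the kernel above because the `__syncthreads()` that publishes `sharedMax` also guarantees the first reduction has fully drained the scratch space.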