
Commit bdad707

fix(kernel): still use cub::BlockReduce and fix Attention
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent d9826f5 commit bdad707

4 files changed: +37 −47 lines


src/04kernel/cuda/include/kernel/cuda/reduce.cuh

−31 lines · This file was deleted.

src/04kernel/src/kernels/attention/cuda_kernel.cu

+17 −7
@@ -1,7 +1,8 @@
 #include "../../utilities/cuda/cublaslt_utils.cuh"
 #include "cuda_kernel.hh"
 #include "hardware/functions.h"
-#include "kernel/cuda/reduce.cuh"
+#include "kernel/cuda/functions.cuh"
+#include <cub/block/block_reduce.cuh>
 
 namespace refactor::kernel {
     using K = AttentionCuda;
@@ -27,7 +28,7 @@ namespace refactor::kernel {
 
     // gridDim.x = batch * nHead
     // gridDim.y = seqLen
-    // blockDim.x = min(1024, attLen)
+    // blockDim.x = 1024
     // sizeof(shared) = attLen * sizeof(float)
     template<class T, class Mask>
     static __global__ void softmax(
@@ -36,25 +37,34 @@
         uint32_t attLen,
         uint32_t bufLen) {
         // locate the attention region this thread block works on
-        att += (blockIdx.x * gridDim.x + gridDim.y) * bufLen;
+        att += (blockIdx.x * gridDim.y + blockIdx.y) * bufLen;
         // load the input into shared memory, with cast + mask
         extern __shared__ float shared[];// size = attLen = pastSeqLen + seqLen
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -__FLT_MAX__;
         }
 
+        using BlockReduce = cub::BlockReduce<float, 1024>;
+        __shared__ typename BlockReduce::TempStorage tempStorage;
+        __shared__ float sharedMax, sharedSum;
+
         float localMax = -1e20;
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             localMax = cub::Max()(localMax, shared[i]);
         }
-        localMax = cuda::blockReduce(localMax, -1e20f, cub::Max());
+        localMax = BlockReduce(tempStorage).Reduce(localMax, cub::Max(), attLen);
+        if (threadIdx.x == 0) { sharedMax = localMax; }
+        __syncthreads();
 
         float localSum = 1e-20;
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
-            localSum += shared[i] = expf(shared[i] - localMax);
+            localSum += shared[i] = expf(shared[i] - sharedMax);
         }
-        localSum = cuda::blockReduce(localSum, 1e-20f, cub::Sum());
-        auto reciprocal = fdividef(1, localSum);
+        localSum = BlockReduce(tempStorage).Reduce(localSum, cub::Sum(), attLen);
+        if (threadIdx.x == 0) { sharedSum = localSum; }
+        __syncthreads();
+
+        auto reciprocal = fdividef(1, sharedSum);
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             att[i] = shared[i] * reciprocal;
         }
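
Two things are going on in this hunk. First, the old offset att += (blockIdx.x * gridDim.x + gridDim.y) * bufLen used grid dimensions where block indices belong, so every block computed the same wrong row offset; the corrected expression indexes the attention buffer by (batch*head, query position). Second, cub::BlockReduce::Reduce returns the block-wide aggregate only in thread 0, which is why each reduction is followed by a write to shared memory and a __syncthreads() before all threads use the value. Below is a minimal, self-contained sketch of that reduce-then-broadcast pattern; it is not the repository's kernel, and the block size 256 and the name rowSoftmax are only for illustration.

// Sketch only: one block normalizes one row of length n with a numerically
// stable softmax, using the same cub::BlockReduce broadcast pattern as above.
#include <cub/block/block_reduce.cuh>

constexpr int BLOCK = 256;// illustrative block size, not the repository's

__global__ void rowSoftmax(float *row, int n) {
    using BlockReduce = cub::BlockReduce<float, BLOCK>;
    __shared__ typename BlockReduce::TempStorage tempStorage;
    __shared__ float rowMax, rowSum;

    // each thread folds a strided slice of the row
    float localMax = -1e20f;
    for (int i = threadIdx.x; i < n; i += blockDim.x)
        localMax = fmaxf(localMax, row[i]);
    // the aggregate is valid only in thread 0, hence the shared-memory broadcast
    float m = BlockReduce(tempStorage).Reduce(localMax, cub::Max());
    if (threadIdx.x == 0) rowMax = m;
    __syncthreads();// also makes tempStorage safe to reuse below

    float localSum = 0.f;
    for (int i = threadIdx.x; i < n; i += blockDim.x)
        localSum += row[i] = expf(row[i] - rowMax);
    float s = BlockReduce(tempStorage).Reduce(localSum, cub::Sum());
    if (threadIdx.x == 0) rowSum = s;
    __syncthreads();

    for (int i = threadIdx.x; i < n; i += blockDim.x)
        row[i] /= rowSum;
}

For a single row this would be launched as rowSoftmax<<<1, BLOCK>>>(d_row, n); with one block per row the kernel would first offset row by blockIdx.x, as the attention kernel above offsets att.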

src/04kernel/src/kernels/softmax/cuda_kernel.cu

+7 −5
@@ -1,5 +1,5 @@
 #include "cuda_kernel.hh"
-#include "kernel/cuda/reduce.cuh"
+#include <cub/cub.cuh>
 
 namespace refactor::kernel {
     using namespace runtime;
@@ -18,8 +18,8 @@ namespace refactor::kernel {
     template<> __device__ __forceinline__ nv_bfloat16 reciprocal<nv_bfloat16>(nv_bfloat16 x) { return hrcp(x); }
 
     // blockDim.x === BLOCK_DIM
-    template<class T>
-    __global__ void blockSoftmaxKernel(
+    template<int BLOCK_DIM, class T>
+    __launch_bounds__(BLOCK_DIM) __global__ void blockSoftmaxKernel(
         T const *__restrict x,
         T *__restrict y,
         int mid,
@@ -40,8 +40,10 @@ namespace refactor::kernel {
         for (int i = threadIdx.x + blockDim.x; i < mid; i += blockDim.x) {
             maxSumThread = MaxSum::reduce(maxSumThread, {x[id + i * stride], 1});// reduce the data to one block
         }
+        using BlockReduce = cub::BlockReduce<MaxSum, BLOCK_DIM>;
+        __shared__ typename BlockReduce::TempStorage tempStorage;
         __shared__ MaxSum maxSumTotal;
-        auto maxSumBlock = cuda::blockReduce(maxSumThread, {-__FLT_MAX__, 0}, MaxSum::reduce);
+        auto maxSumBlock = BlockReduce(tempStorage).Reduce(maxSumThread, MaxSum::reduce);
         if (threadIdx.x == 0) {
             maxSumTotal = maxSumBlock;// must set threadIdx.x = 0 write the output to memory
         }
@@ -111,7 +113,7 @@ namespace refactor::kernel {
         auto y = reinterpret_cast<T *>(outputs[0]);
         int numBlocks = info.pre * info.post;
         if (info.mid > 1024) {
-            blockSoftmaxKernel<<<numBlocks, 1024>>>(x, y, info.mid, info.post);
+            blockSoftmaxKernel<1024><<<numBlocks, 1024>>>(x, y, info.mid, info.post);
         } else {
             int blockDimX, mid = static_cast<int>(info.mid);
             for (blockDimX = 32; blockDimX > 4 && mid < blockDimX; blockDimX /= 2) {}
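
The block size moves into the template parameter list because cub::BlockReduce<T, BLOCK_DIM> needs it as a compile-time constant; the launch site therefore has to name a concrete instantiation (blockSoftmaxKernel<1024>) instead of passing a runtime value. A rough sketch of that host-side dispatch pattern, with hypothetical names (reduceMaxKernel, launchReduceMax) that are not part of this repository and assuming n <= 1024:

// Sketch only: a templated block size lets cub::BlockReduce be instantiated,
// and the host picks among a few fixed instantiations at runtime.
#include <cub/block/block_reduce.cuh>
#include <cfloat>

template<int BLOCK_DIM>
__launch_bounds__(BLOCK_DIM) __global__ void reduceMaxKernel(float const *x, float *out, int n) {
    using BlockReduce = cub::BlockReduce<float, BLOCK_DIM>;
    __shared__ typename BlockReduce::TempStorage tempStorage;
    // assumes n <= BLOCK_DIM of the chosen instantiation, kept short on purpose
    float v = threadIdx.x < n ? x[threadIdx.x] : -FLT_MAX;
    float m = BlockReduce(tempStorage).Reduce(v, cub::Max(), n < BLOCK_DIM ? n : BLOCK_DIM);
    if (threadIdx.x == 0) *out = m;// only thread 0 holds the block-wide result
}

// host-side dispatch: the block size must be known at compile time,
// so choose the instantiation that covers n
void launchReduceMax(float const *x, float *out, int n) {
    if (n > 512) reduceMaxKernel<1024><<<1, 1024>>>(x, out, n);
    else if (n > 256) reduceMaxKernel<512><<<1, 512>>>(x, out, n);
    else reduceMaxKernel<256><<<1, 256>>>(x, out, n);
}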

src/04kernel/test/kernels/attention/test_cuda.cpp

+13 −4
@@ -13,7 +13,7 @@ using namespace hardware;
 TEST(kernel, AttentionCudaNoKvCache) {
     // build routine
     AttentionInfo info{
-        .dataType = DataType::FP16,
+        .dataType = DataType::F32,
         .batch = 1,
         .nHead = 4,
         .nKVHead = 4,
@@ -23,9 +23,9 @@ TEST(kernel, AttentionCudaNoKvCache) {
         .concatCache = false,
         .resetCache = false,
     };
-    auto q = Tensor::share(DataType::FP16, Shape{info.batch, info.nHead, info.seqLen, info.headDim}),
-         k = Tensor::share(DataType::FP16, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
-         v = Tensor::share(DataType::FP16, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
+    auto q = Tensor::share(DataType::F32, Shape{info.batch, info.nHead, info.seqLen, info.headDim}),
+         k = Tensor::share(DataType::F32, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
+         v = Tensor::share(DataType::F32, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
          o = q;
     auto kernel = AttentionCuda::build(info);
     ASSERT_TRUE(kernel);
@@ -38,6 +38,15 @@ TEST(kernel, AttentionCudaNoKvCache) {
          vGpu = dev.malloc(v->bytesSize()),
          oGpu = dev.malloc(o->bytesSize()),
          workspace = dev.malloc(workspaceSize);
+    // put input data
+    std::vector<float>
+        q_(q->elementsSize(), 1),
+        k_(k->elementsSize(), 1),
+        v_(v->elementsSize(), 1),
+        o_(o->elementsSize());
+    qGpu->copyFromHost(q_.data());
+    kGpu->copyFromHost(k_.data());
+    vGpu->copyFromHost(v_.data());
     // inference
     {
         void const *inputs[]{*qGpu, *kGpu, *vGpu};
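
The test now fills Q, K and V with ones, which makes the expected result easy to state: every unmasked attention score is identical, softmax of identical scores gives uniform weights, and a weighted average of an all-ones V is exactly 1, so the output should be all ones whatever the mask keeps. A tiny host-only sketch of that arithmetic; the sequence length below is a placeholder, not the test's actual value.

// Sketch only: expected attention output for all-ones Q, K, V is all ones.
#include <cstdio>

int main() {
    int seqLen = 7;// placeholder size; the real test dimensions are not shown in this hunk
    for (int qi = 0; qi < seqLen; ++qi) {
        int attended = qi + 1;          // e.g. a causal mask: query qi sees qi+1 keys
        float weight = 1.0f / attended; // softmax of equal scores is uniform
        float out = 0.0f;
        for (int ki = 0; ki < attended; ++ki) out += weight * 1.0f;// V is all ones
        std::printf("row %d -> %.3f\n", qi, out);// prints 1.000 for every row
    }
    return 0;
}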
