
Commit 04b6894

feat(kernel): support kv cache attention

Signed-off-by: YdrMaster <ydrml@hotmail.com>

1 parent bdad707

5 files changed: +162 -8 lines

src/04kernel/include/kernel/attributes/attention_info.h (+1)

@@ -12,6 +12,7 @@ namespace refactor::kernel {
 
         dim_t attLen(dim_t pastSeqLen) const noexcept;
         size_t attSize(dim_t pastSeqLen) const noexcept;
+        size_t maxAttSize() const noexcept;
     };
 
 }// namespace refactor::kernel

src/04kernel/src/attributes/attention_info.cc (+4)

@@ -10,4 +10,8 @@ namespace refactor::kernel {
         return batch * nHead * seqLen * attLen(pastSeqLen) * dataType.size();
     }
 
+    size_t AttentionInfo::maxAttSize() const noexcept {
+        return batch * nHead * seqLen * (cacheLen ? cacheLen : seqLen) * dataType.size();
+    }
+
 }// namespace refactor::kernel
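
Note: attSize depends on the current window attLen(pastSeqLen), while the new maxAttSize gives a one-time upper bound for workspace allocation: with a cache the window never exceeds cacheLen, and without one it never exceeds seqLen. A minimal standalone sketch of the arithmetic (the struct shape and the attLen body here are simplified assumptions, not the commit's definitions):

#include <cstddef>
#include <cstdint>

// Simplified stand-in for AttentionInfo, for illustration only.
struct AttentionInfoSketch {
    uint32_t batch, nHead, seqLen, cacheLen;
    size_t elementSize;// bytes per element, e.g. 2 for fp16

    // Assumed definition: the window covers past tokens plus new tokens.
    uint32_t attLen(uint32_t pastSeqLen) const {
        return pastSeqLen + seqLen;
    }
    size_t attSize(uint32_t pastSeqLen) const {
        return size_t(batch) * nHead * seqLen * attLen(pastSeqLen) * elementSize;
    }
    // Upper bound used to reserve the attention buffer once.
    size_t maxAttSize() const {
        return size_t(batch) * nHead * seqLen * (cacheLen ? cacheLen : seqLen) * elementSize;
    }
};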

src/04kernel/src/kernels/attention/cuda_kernel.cu (+151 -3)

@@ -70,6 +70,20 @@ namespace refactor::kernel {
         }
     }
 
+    static __global__ void concatCache(
+        void *__restrict__ cache,
+        void const *__restrict__ value,
+        dim_t pageStrideI,
+        dim_t pageStrideO,
+        dim_t lineStride,
+        dim_t pastOffset) {
+
+        auto tid = blockIdx.x * blockDim.x + threadIdx.x,
+             dst = tid / pageStrideO * pageStrideI + pastOffset + tid % pageStrideO;
+        reinterpret_cast<float4 *>(cache)[dst] = reinterpret_cast<float4 const *>(value)[tid];
+    }
+    constexpr uint64_t DYNAMIC_WORKSPACE_SIZE = 40 << 20;// 40 MiB turned out to be enough in practice
+
     RoutineWorkspace K::lower(Resources &res) const {
         auto handle = res.fetchOrStore<CublasLtContext>()->handle;
 
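Note: concatCache appends the incoming K or V tensor to the persistent cache in float4-sized items, one page per (batch, head), landing at line offset past within each cache page. A host-side sketch of that append mapping (names and the plain-loop form are illustrative; the kernel flattens the loops into a single float4-granular thread index):

#include <cstdint>
#include <vector>

// Host-side reference for the KV-cache append, in float4-sized items.
// value: [pages][seqLen * itemsPerLine]    (pages = batch * nHead)
// cache: [pages][cacheLen * itemsPerLine], written at line offset `past`
void concatCacheRef(std::vector<float> &cache, std::vector<float> const &value,
                    size_t pages, size_t seqLen, size_t cacheLen,
                    size_t itemsPerLine, size_t past) {
    for (size_t p = 0; p < pages; ++p)// one page per (batch, head)
        for (size_t i = 0; i < seqLen * itemsPerLine; ++i)
            cache[(p * cacheLen + past) * itemsPerLine + i] =
                value[p * seqLen * itemsPerLine + i];
}
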
@@ -125,8 +139,8 @@
                    .batchCount = static_cast<int32_t>(info.batch * info.nHead),
                    .batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
                }) {
-            auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
-            auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
+            auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att, DYNAMIC_WORKSPACE_SIZE);
+            auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q, DYNAMIC_WORKSPACE_SIZE);
             algoQK = algoQK_;
             algoAV = algoAV_;
             workspaceSizeQK = workspaceSizeQK_;
@@ -187,12 +201,146 @@
                            &d->algoAV,
                            workspaceAV, d->workspaceSizeAV,
                            stream);
-                    };
+                    }
                };
 
                return {std::move(routine), workspaceSize};
            }
+            TODO("");
        }
+        if (info.concatCache && !info.resetCache) {
+            if (info.nHead == info.nKVHead) {
+
+                // RAII for closure
+                struct Descriptors {
+                    MatMulDescriptor mul;
+
+                    Descriptors(AttentionInfo info)
+                        : mul(computeTypeConvert(info.dataType),
+                              dataTypeConvert(info.dataType)) {}
+                };
+
+                auto const &context = *res.fetchOrStore<CublasLtContext>();
+                auto d = std::make_shared<Descriptors>(info);
+                auto attentionSize = info.maxAttSize();
+                auto workspaceSize = DYNAMIC_WORKSPACE_SIZE + attentionSize;
+
+                auto routine = [d = std::move(d), info = this->info]//
+                    (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
+                        auto handle = res.fetchOrStore<CublasLtContext>()->handle;
+                        auto q = inputs[0];
+                        auto k = inputs[1];
+                        auto v = inputs[2];
+                        auto past = *reinterpret_cast<int64_t const *>(inputs[3]);
+                        auto attLen = info.attLen(past);
+                        auto o = reinterpret_cast<half *>(outputs[0]);
+                        auto kCache = reinterpret_cast<half *>(outputs[1]);
+                        auto vCache = reinterpret_cast<half *>(outputs[2]);
+                        auto att = reinterpret_cast<half *>(reinterpret_cast<uint8_t *>(workspace) + DYNAMIC_WORKSPACE_SIZE);
+                        auto stream = cudaStreamLegacy;
+                        {
+                            auto itemsPerLine = info.headDim * sizeof(half) / sizeof(float4);
+                            auto threads = info.batch * info.nHead * info.seqLen * itemsPerLine;
+                            auto blocks = (threads + 1023) / 1024;
+
+                            concatCache<<<blocks, 1024, 0, stream>>>(
+                                kCache, k,
+                                info.seqLen * itemsPerLine,
+                                info.cacheLen * itemsPerLine,
+                                itemsPerLine,
+                                past * itemsPerLine);
+                            concatCache<<<blocks, 1024, 0, stream>>>(
+                                vCache, v,
+                                info.seqLen * itemsPerLine,
+                                info.cacheLen * itemsPerLine,
+                                itemsPerLine,
+                                past * itemsPerLine);
+                        }
+                        MatrixDescriptor
+                            q_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(info.seqLen),
+                                .cols = static_cast<uint64_t>(info.headDim),
+                                .majorStride = static_cast<int64_t>(info.headDim),
+                                .order = ROW_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
+                            }),
+                            k_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(info.headDim),
+                                .cols = static_cast<uint64_t>(attLen),
+                                .majorStride = static_cast<int64_t>(info.headDim),
+                                .order = COL_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
+                            }),
+                            v_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(attLen),
+                                .cols = static_cast<uint64_t>(info.headDim),
+                                .majorStride = static_cast<int64_t>(info.headDim),
+                                .order = ROW_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
+                            }),
+                            att_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(info.seqLen),
+                                .cols = static_cast<uint64_t>(attLen),
+                                .majorStride = static_cast<int64_t>(info.cacheLen),
+                                .order = ROW_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.cacheLen * info.seqLen),
+                            });
+                        {
+                            auto [algo, workspaceSize] = tune(
+                                handle, d->mul,
+                                q_, k_, att_,
+                                DYNAMIC_WORKSPACE_SIZE);
+                            half alpha = rsqrtf(info.headDim), beta = 0;
+                            cublasLtMatmul(
+                                handle, d->mul.get(),
+                                &alpha,
+                                q, q_.get(),
+                                kCache, k_.get(),
+                                &beta,
+                                att, att_.get(),
+                                att, att_.get(),
+                                &algo,
+                                workspace, workspaceSize,
+                                stream);
+                        }
+                        softmax<<<dim3(info.batch * info.nHead, info.seqLen),
+                                  std::min(1024u, attLen),
+                                  attLen * sizeof(float),
+                                  stream>>>(
+                            att, AttentionCausualMask(), attLen, info.cacheLen);
+                        {
+                            auto [algo, workspaceSize] = tune(
+                                handle, d->mul,
+                                att_, v_, q_,
+                                DYNAMIC_WORKSPACE_SIZE);
+                            half alpha = 1, beta = 0;
+                            cublasLtMatmul(
+                                handle, d->mul.get(),
+                                &alpha,
+                                att, att_.get(),
+                                vCache, v_.get(),
+                                &beta,
+                                o, q_.get(),
+                                o, q_.get(),
+                                &algo,
+                                workspace, workspaceSize,
+                                stream);
+                        }
+                    };
+
+                return {std::move(routine), workspaceSize};
+            }
+            TODO("");
+        }
+
         TODO("");
     }
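Note: the matrix layouts above deliberately decouple the logical window from the physical cache: K and V batch strides step over whole cache pages (cacheLen * headDim), and att rows use leading stride cacheLen, so one buffer preallocated via maxAttSize serves every step. A worked decode-step example with invented numbers, assuming attLen(past) = past + seqLen:

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical decode step: one new token against a 1024-line cache.
    uint32_t batch = 1, nHead = 16, headDim = 64;
    uint32_t seqLen = 1, cacheLen = 1024, past = 100;
    uint32_t attLen = past + seqLen;// assumed attLen(past)

    // Q: [seqLen x headDim] row-major, one page per (batch, head).
    std::printf("Q  : %u x %u, batchStride %u\n", seqLen, headDim, seqLen * headDim);
    // K viewed column-major as [headDim x attLen]; the batch stride spans a
    // full cache page, not just the attLen columns currently in use.
    std::printf("K^T: %u x %u, batchStride %u\n", headDim, attLen, cacheLen * headDim);
    // att: [seqLen x attLen] with rows padded to cacheLen, so the buffer
    // never needs re-laying-out as past grows.
    std::printf("att: %u x %u, leading stride %u\n", seqLen, attLen, cacheLen);
    return 0;
}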

src/04kernel/src/utilities/cuda/cublaslt_utils.cu (+4 -4)

@@ -101,23 +101,23 @@ namespace refactor::kernel::cublas {
        MatMulDescriptor const &matmul,
        MatrixDescriptor const &a,
        MatrixDescriptor const &b,
-       MatrixDescriptor const &c) {
+       MatrixDescriptor const &c,
+       uint64_t maxWorkspace) {
 
        int device;
        CUDA_ASSERT(cudaGetDevice(&device));
        cudaDeviceProp prop;
        CUDA_ASSERT(cudaGetDeviceProperties(&prop, device));
 
-       auto workspace = std::numeric_limits<uint64_t>::max();
        uint32_t alignment = prop.textureAlignment;
 
        cublasLtMatmulPreference_t preference;
        CUBLASLT_ASSERT(cublasLtMatmulPreferenceCreate(&preference));
        CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
            preference,
            CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
-           &workspace,
-           sizeof(workspace)));
+           &maxWorkspace,
+           sizeof(maxWorkspace)));
        CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute(
            preference,
            CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES,
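
Note: previously the heuristic saw an effectively unbounded workspace (numeric_limits<uint64_t>::max()), so it could select an algorithm whose workspace exceeded the buffer actually allocated; threading maxWorkspace through tune caps the search at the real budget. A condensed sketch of the pattern (error checking and this file's CUBLASLT_ASSERT wrappers elided; handle and the descriptor wrappers are the same objects tune receives):

// Restrict cuBLASLt's algorithm search to a fixed workspace budget.
cublasLtMatmulPreference_t preference;
cublasLtMatmulPreferenceCreate(&preference);
uint64_t maxWorkspace = 40 << 20;// must not exceed the buffer later passed to cublasLtMatmul
cublasLtMatmulPreferenceSetAttribute(
    preference,
    CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
    &maxWorkspace,
    sizeof(maxWorkspace));

cublasLtMatmulHeuristicResult_t result;
int returned = 0;
cublasLtMatmulAlgoGetHeuristic(
    handle, matmul.get(),// operation descriptor
    a.get(), b.get(),    // input layouts
    c.get(), c.get(),    // C and D layouts
    preference, 1, &result, &returned);
// result.algo and result.workspaceSize now fit within maxWorkspace.
cublasLtMatmulPreferenceDestroy(preference);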

src/04kernel/src/utilities/cuda/cublaslt_utils.cuh (+2 -1)

@@ -68,7 +68,8 @@ namespace refactor::kernel::cublas {
        MatMulDescriptor const &,
        MatrixDescriptor const &,
        MatrixDescriptor const &,
-       MatrixDescriptor const &);
+       MatrixDescriptor const &,
+       uint64_t);
 
 }// namespace refactor::kernel::cublas