@@ -70,6 +70,20 @@ namespace refactor::kernel {
         }
     }
 
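+    // Appends one step's K or V pages to the persistent cache. Each thread
+    // moves a single float4 (8 halves): `tid` walks the compact input layout
+    // page by page, and `dst` re-bases the same element onto the cache's
+    // wider page stride, past the `pastOffset` tokens already stored.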
+    static __global__ void concatCache(
+        void *__restrict__ cache,
+        void const *__restrict__ value,
+        dim_t pageStrideI,
+        dim_t pageStrideO,
+        dim_t lineStride,// line granularity of the strides (not used in the index math)
+        dim_t pastOffset) {
+
+        auto tid = blockIdx.x * blockDim.x + threadIdx.x,
+             // input page index × cache page stride + cached prefix + offset in page
+             dst = tid / pageStrideI * pageStrideO + pastOffset + tid % pageStrideI;
+        reinterpret_cast<float4 *>(cache)[dst] = reinterpret_cast<float4 const *>(value)[tid];
+    }
+    constexpr uint64_t DYNAMIC_WORKSPACE_SIZE = 40 << 20;// found by experiment: 40 MiB is enough
+
     RoutineWorkspace K::lower(Resources &res) const {
         auto handle = res.fetchOrStore<CublasLtContext>()->handle;
@@ -125,8 +139,8 @@ namespace refactor::kernel {
                             .batchCount = static_cast<int32_t>(info.batch * info.nHead),
                             .batchStride = static_cast<int64_t>(info.seqLen * info.seqLen),
                         }) {
-                        auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
-                        auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
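+                        // the extra argument caps the algorithm search at the
+                        // fixed DYNAMIC_WORKSPACE_SIZE budget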
+                        auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att, DYNAMIC_WORKSPACE_SIZE);
+                        auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q, DYNAMIC_WORKSPACE_SIZE);
                         algoQK = algoQK_;
                         algoAV = algoAV_;
                         workspaceSizeQK = workspaceSizeQK_;
@@ -187,12 +201,146 @@ namespace refactor::kernel {
                             &d->algoAV,
                             workspaceAV, d->workspaceSizeAV,
                             stream);
-                        };
+                        }
                     };
 
                 return {std::move(routine), workspaceSize};
             }
+            TODO("");
         }
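+        // kv-cache decode path: scatter this step's K/V into the caches, then
+        // attend over the full attLen = past + seqLen prefix. Only the plain
+        // MHA layout (nHead == nKVHead) is implemented; the grouped-KV case
+        // falls through to TODO.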
+        if (info.concatCache && !info.resetCache) {
+            if (info.nHead == info.nKVHead) {
+
+                // RAII for closure
+                struct Descriptors {
+                    MatMulDescriptor mul;
+
+                    Descriptors(AttentionInfo info)
+                        : mul(computeTypeConvert(info.dataType),
+                              dataTypeConvert(info.dataType)) {}
+                };
+
+                auto const &context = *res.fetchOrStore<CublasLtContext>();
+                auto d = std::make_shared<Descriptors>(info);
+                auto attentionSize = info.maxAttSize();
+                auto workspaceSize = DYNAMIC_WORKSPACE_SIZE + attentionSize;
+
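+                // workspace layout: the first DYNAMIC_WORKSPACE_SIZE bytes are
+                // cublasLt scratch, followed by the att buffer (attentionSize
+                // bytes); `att` below points just past the scratch region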
+                auto routine = [d = std::move(d), info = this->info]//
+                    (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
+                        auto handle = res.fetchOrStore<CublasLtContext>()->handle;
+                        auto q = inputs[0];
+                        auto k = inputs[1];
+                        auto v = inputs[2];
+                        auto past = *reinterpret_cast<int64_t const *>(inputs[3]);
+                        auto attLen = info.attLen(past);
+                        auto o = reinterpret_cast<half *>(outputs[0]);
+                        auto kCache = reinterpret_cast<half *>(outputs[1]);
+                        auto vCache = reinterpret_cast<half *>(outputs[2]);
+                        auto att = reinterpret_cast<half *>(reinterpret_cast<uint8_t *>(workspace) + DYNAMIC_WORKSPACE_SIZE);
+                        auto stream = cudaStreamLegacy;
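+                        // append this step's K/V to the caches; every line of
+                        // headDim halves moves as headDim / 8 float4 transactions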
+                        {
+                            auto itemsPerLine = info.headDim * sizeof(half) / sizeof(float4);
+                            auto threads = info.batch * info.nHead * info.seqLen * itemsPerLine;
+                            auto blocks = (threads + 1023) / 1024;
+
+                            concatCache<<<blocks, 1024, 0, stream>>>(
+                                kCache, k,
+                                info.seqLen * itemsPerLine,
+                                info.cacheLen * itemsPerLine,
+                                itemsPerLine,
+                                past * itemsPerLine);
+                            concatCache<<<blocks, 1024, 0, stream>>>(
+                                vCache, v,
+                                info.seqLen * itemsPerLine,
+                                info.cacheLen * itemsPerLine,
+                                itemsPerLine,
+                                past * itemsPerLine);
+                        }
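+                        // K and V are consumed from the caches, so their batch
+                        // stride spans a whole cache page (cacheLen * headDim);
+                        // att has seqLen rows on a cacheLen-wide leading stride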
+                        MatrixDescriptor
+                            q_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(info.seqLen),
+                                .cols = static_cast<uint64_t>(info.headDim),
+                                .majorStride = static_cast<int64_t>(info.headDim),
+                                .order = ROW_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.seqLen * info.headDim),
+                            }),
+                            k_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(info.headDim),
+                                .cols = static_cast<uint64_t>(attLen),
+                                .majorStride = static_cast<int64_t>(info.headDim),
+                                .order = COL_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
+                            }),
+                            v_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(attLen),
+                                .cols = static_cast<uint64_t>(info.headDim),
+                                .majorStride = static_cast<int64_t>(info.headDim),
+                                .order = ROW_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.cacheLen * info.headDim),
+                            }),
+                            att_(MatrixLayout{
+                                .dataType = dataTypeConvert(info.dataType),
+                                .rows = static_cast<uint64_t>(info.seqLen),
+                                .cols = static_cast<uint64_t>(attLen),
+                                .majorStride = static_cast<int64_t>(info.cacheLen),
+                                .order = ROW_MAJOR,
+                                .batchCount = static_cast<int32_t>(info.batch * info.nHead),
+                                .batchStride = static_cast<int64_t>(info.cacheLen * info.seqLen),
+                            });
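+                        // att = Q · K^T, scaled by 1 / sqrt(headDim)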
+                        {
+                            auto [algo, workspaceSize] = tune(
+                                handle, d->mul,
+                                q_, k_, att_,
+                                DYNAMIC_WORKSPACE_SIZE);
+                            half alpha = rsqrtf(info.headDim), beta = 0;
+                            cublasLtMatmul(
+                                handle, d->mul.get(),
+                                &alpha,
+                                q, q_.get(),
+                                kCache, k_.get(),
+                                &beta,
+                                att, att_.get(),
+                                att, att_.get(),
+                                &algo,
+                                workspace, workspaceSize,
+                                stream);
+                        }
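+                        // causally masked softmax over each row's attLen scores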
+                        softmax<<<dim3(info.batch * info.nHead, info.seqLen),
+                                  std::min(1024u, attLen),
+                                  attLen * sizeof(float),
+                                  stream>>>(
+                            att, AttentionCausualMask(), attLen, info.cacheLen);
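+                        // o = att · V, written straight into the output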
+                        {
+                            auto [algo, workspaceSize] = tune(
+                                handle, d->mul,
+                                att_, v_, q_,
+                                DYNAMIC_WORKSPACE_SIZE);
+                            half alpha = 1, beta = 0;
+                            cublasLtMatmul(
+                                handle, d->mul.get(),
+                                &alpha,
+                                att, att_.get(),
+                                vCache, v_.get(),
+                                &beta,
+                                o, q_.get(),
+                                o, q_.get(),
+                                &algo,
+                                workspace, workspaceSize,
+                                stream);
+                        }
+                    };
+
+                return {std::move(routine), workspaceSize};
+            }
+            TODO("");
+        }
+
         TODO("");
     }