From cbe7c788e2ae8cccffdac8ab70615db2f8426b68 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 2 Feb 2024 12:53:06 +0800 Subject: [PATCH 01/17] =?UTF-8?q?feat:=20=E5=BC=80=E5=A7=8B=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=20attention?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../include/hardware/devices/nvidia.h | 6 + src/02hardware/src/devices/nvidia/device.cc | 6 - src/02hardware/src/devices/nvidia/memory.cc | 8 +- .../kernel/attributes/attention_info.h | 16 ++ .../include/kernel/collectors/attention.h | 3 +- src/04kernel/src/collectors/attention.cc | 55 ++++--- .../src/kernels/attention/cuda_kernel.cu | 127 +++++++++++++++ .../src/kernels/attention/cuda_kernel.hh | 8 +- .../src/utilities/cuda/cublaslt_context.cu | 33 ---- .../src/utilities/cuda/cublaslt_context.hh | 33 ---- .../src/utilities/cuda/cublaslt_utils.cu | 145 ++++++++++++++++++ .../src/utilities/cuda/cublaslt_utils.cuh | 74 +++++++++ 12 files changed, 409 insertions(+), 105 deletions(-) create mode 100644 src/04kernel/include/kernel/attributes/attention_info.h create mode 100644 src/04kernel/src/kernels/attention/cuda_kernel.cu delete mode 100644 src/04kernel/src/utilities/cuda/cublaslt_context.cu delete mode 100644 src/04kernel/src/utilities/cuda/cublaslt_context.hh create mode 100644 src/04kernel/src/utilities/cuda/cublaslt_utils.cu create mode 100644 src/04kernel/src/utilities/cuda/cublaslt_utils.cuh diff --git a/src/02hardware/include/hardware/devices/nvidia.h b/src/02hardware/include/hardware/devices/nvidia.h index d19dd315..18a4269d 100644 --- a/src/02hardware/include/hardware/devices/nvidia.h +++ b/src/02hardware/include/hardware/devices/nvidia.h @@ -3,6 +3,12 @@ #include "../device.h" +#define CUDA_ASSERT(STATUS) \ + if (auto status = (STATUS); status != cudaSuccess) { \ + RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \ + cudaGetErrorString(status), (int) status)); \ + } + namespace refactor::hardware { class Nvidia final : public Device { diff --git a/src/02hardware/src/devices/nvidia/device.cc b/src/02hardware/src/devices/nvidia/device.cc index fd10cb70..20f63c0f 100644 --- a/src/02hardware/src/devices/nvidia/device.cc +++ b/src/02hardware/src/devices/nvidia/device.cc @@ -4,12 +4,6 @@ #ifdef USE_CUDA #include "memory.hh" #include - -#define CUDA_ASSERT(STATUS) \ - if (auto status = (STATUS); status != cudaSuccess) { \ - RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \ - cudaGetErrorString(status), (int) status)); \ - } #endif namespace refactor::hardware { diff --git a/src/02hardware/src/devices/nvidia/memory.cc b/src/02hardware/src/devices/nvidia/memory.cc index 42310196..1c3be21e 100644 --- a/src/02hardware/src/devices/nvidia/memory.cc +++ b/src/02hardware/src/devices/nvidia/memory.cc @@ -1,15 +1,9 @@ #ifdef USE_CUDA #include "memory.hh" -#include "common.h" +#include "hardware/devices/nvidia.h" #include -#define CUDA_ASSERT(STATUS) \ - if (auto status = (STATUS); status != cudaSuccess) { \ - RUNTIME_ERROR(fmt::format("cuda failed on \"" #STATUS "\" with \"{}\" ({})", \ - cudaGetErrorString(status), (int) status)); \ - } - namespace refactor::hardware { using M = NvidiaMemory; diff --git a/src/04kernel/include/kernel/attributes/attention_info.h b/src/04kernel/include/kernel/attributes/attention_info.h new file mode 100644 index 00000000..16d5fb0e --- /dev/null +++ b/src/04kernel/include/kernel/attributes/attention_info.h @@ -0,0 +1,16 @@ +#ifndef KERNEL_ATTENTION_INFO_H +#define 
KERNEL_ATTENTION_INFO_H + +#include "../tensor.h" + +namespace refactor::kernel { + + struct AttentionInfo { + DataType dataType; + dim_t batch, nHead, nKVHead, seqLen, headDim, cacheLen; + bool concatCache, resetCache; + }; + +}// namespace refactor::kernel + +#endif// KERNEL_ATTENTION_INFO_H diff --git a/src/04kernel/include/kernel/collectors/attention.h b/src/04kernel/include/kernel/collectors/attention.h index 527bc63f..abf33957 100644 --- a/src/04kernel/include/kernel/collectors/attention.h +++ b/src/04kernel/include/kernel/collectors/attention.h @@ -6,9 +6,8 @@ namespace refactor::kernel { struct AttentionCollector final : public InfoCollector { - dim_t maxSeqLen; - AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept; + AttentionCollector(decltype(_target)) noexcept; std::vector filter(TensorRefs inputs, TensorRefs outputs) const final; diff --git a/src/04kernel/src/collectors/attention.cc b/src/04kernel/src/collectors/attention.cc index 3933097f..a778c128 100644 --- a/src/04kernel/src/collectors/attention.cc +++ b/src/04kernel/src/collectors/attention.cc @@ -1,38 +1,57 @@ #include "kernel/collectors/attention.h" +#include "kernel/attributes/attention_info.h" // #include "../kernels/attention/cpu_kernel.hh" #include "../kernels/attention/cuda_kernel.hh" namespace refactor::kernel { AttentionCollector::AttentionCollector( - decltype(_target) target, - decltype(maxSeqLen) maxSeqLen_) noexcept - : InfoCollector(target), - maxSeqLen(maxSeqLen_) {} + decltype(_target) target) noexcept + : InfoCollector(target) {} std::vector AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const { auto const &query = inputs[0].get(); auto const &key = inputs[1].get(); - auto pastSeqLen = inputs.size() == 3 ? 0 : *inputs[2].get().data->get(); - auto cacheLen = outputs.size() == 1 ? 
0 : outputs[1].get().shape[2]; - std::vector ans; + AttentionInfo info{ + .dataType = query.dataType, + .batch = query.shape[0], + .nHead = query.shape[1], + .nKVHead = key.shape[1], + .seqLen = query.shape[2], + .headDim = query.shape[3], + .cacheLen = 0, + .concatCache = false, + .resetCache = false, + }; + switch (outputs.size()) { + case 1: + // no kv cache + ASSERT(inputs.size() == 3, ""); + break; + case 3: + switch (inputs.size()) { + case 6: + info.resetCache = true; + case 4: + info.concatCache = true; + case 3: + info.cacheLen = outputs[1].get().shape[2]; + break; + default: + UNREACHABLE(); + } + break; + default: + UNREACHABLE(); + } + + std ::vector ans; switch (_target) { case decltype(_target)::Cpu: break; case decltype(_target)::Nvidia: { - decltype(AttentionCuda::info) info{ - .dataType = query.dataType, - .batch = query.shape[0], - .nHead = query.shape[1], - .nKVHead = key.shape[1], - .pastSeqLen = static_cast(pastSeqLen), - .seqLen = query.shape[2], - .cacheLen = cacheLen, - .headDim = query.shape[3], - .resetCache = false, - }; if (auto ptr = AttentionCuda::build(info); ptr) { ans.emplace_back(std::move(ptr)); } diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu new file mode 100644 index 00000000..a0f3f56a --- /dev/null +++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu @@ -0,0 +1,127 @@ +#include "../../utilities/cuda/cublaslt_utils.cuh" +#include "cuda_kernel.hh" +#include "hardware/functions.h" + +namespace refactor::kernel { + using K = AttentionCuda; + using namespace cublas; + + RoutineWorkspace K::lower(Resources &res) const { + auto handle = res.fetchOrStore()->handle; + + constexpr auto ROW_MAJOR = CUBLASLT_ORDER_ROW; + constexpr auto COL_MAJOR = CUBLASLT_ORDER_COL; + + if (!info.cacheLen) { + if (info.nHead == info.nKVHead) { + // RAII for closure + struct Descriptors { + MatMulDescriptor mul; + MatrixDescriptor q, k, v, att; + cublasLtMatmulAlgo_t algoQK, algoAV; + size_t attSize, workspaceSizeQK, workspaceSizeAV; + + Descriptors(CublasLtContext const &context, + cublasComputeType_t compute, + AttentionInfo info) + : mul(compute, CUDA_R_32F), + q(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(info.seqLen), + .cols = static_cast(info.headDim), + .majorStride = static_cast(info.headDim), + .order = ROW_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.seqLen * info.headDim), + }), + k(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(info.headDim), + .cols = static_cast(info.seqLen), + .majorStride = static_cast(info.headDim), + .order = COL_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.seqLen * info.headDim), + }), + v(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(info.seqLen), + .cols = static_cast(info.headDim), + .majorStride = static_cast(info.headDim), + .order = ROW_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.seqLen * info.headDim), + }), + att(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(info.seqLen), + .cols = static_cast(info.seqLen), + .majorStride = static_cast(info.seqLen), + .order = ROW_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.seqLen * info.seqLen), + }), + attSize(info.batch * info.nHead * info.seqLen * info.seqLen * info.dataType.size()) { + 
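+                        // `tune` (defined in cublaslt_utils) queries cublasLt's
+                        // heuristic for the best algorithm of each matmul
+                        // (Q x K^T and att x V) and for the workspace bytes it
+                        // needs; both results are cached in this descriptor
+                        // bundle so the search runs once per kernel lowering.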
auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att); + auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q); + algoQK = algoQK_; + algoAV = algoAV_; + workspaceSizeQK = workspaceSizeQK_; + workspaceSizeAV = workspaceSizeAV_; + } + }; + + auto const &context = *res.fetchOrStore(); + auto d = std::make_shared(context, CUBLAS_COMPUTE_32F, info); + auto workspaceSize = d->attSize; + workspaceSize = hardware::alignBytes(workspaceSize, 256); + workspaceSize += d->workspaceSizeQK; + workspaceSize = hardware::alignBytes(workspaceSize, 256); + workspaceSize += d->workspaceSizeAV; + workspaceSize = hardware::alignBytes(workspaceSize, 256); + + auto routine = [d = std::move(d), info = this->info]// + (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { + auto handle = res.fetchOrStore()->handle; + auto q = inputs[0]; + auto k = inputs[1]; + auto v = inputs[2]; + auto o = outputs[0]; + auto att = workspace; + auto workspaceQK = reinterpret_cast(workspace) + hardware::alignBytes(d->attSize, 256); + auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256); + + float alpha = 1, beta = 0; + cublasLtMatmul( + handle, d->mul.get(), + &alpha, + q, d->q.get(), + k, d->k.get(), + &beta, + att, d->att.get(), + att, d->att.get(), + &d->algoQK, + workspaceQK, d->workspaceSizeQK, + cudaStreamLegacy); + + // TODO inline mask && softmax + + cublasLtMatmul( + handle, d->mul.get(), + &alpha, + att, d->att.get(), + v, d->v.get(), + &beta, + o, d->q.get(), + o, d->q.get(), + &d->algoAV, + workspaceAV, d->workspaceSizeAV, + cudaStreamLegacy); + }; + return {std::move(routine), workspaceSize}; + } + } + TODO(""); + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.hh b/src/04kernel/src/kernels/attention/cuda_kernel.hh index 5ea19ae8..20cf9712 100644 --- a/src/04kernel/src/kernels/attention/cuda_kernel.hh +++ b/src/04kernel/src/kernels/attention/cuda_kernel.hh @@ -1,17 +1,13 @@ #ifndef KERNEL_ATTENTION_CUDA_KERNEL_HH #define KERNEL_ATTENTION_CUDA_KERNEL_HH +#include "kernel/attributes/attention_info.h" #include "kernel/kernel.h" -#include "kernel/tensor.h" namespace refactor::kernel { struct AttentionCuda final : public Kernel { - struct { - DataType dataType; - dim_t batch, nHead, nKVHead, pastSeqLen, seqLen, cacheLen, headDim; - bool resetCache; - } info; + AttentionInfo info; AttentionCuda(decltype(info)) noexcept; diff --git a/src/04kernel/src/utilities/cuda/cublaslt_context.cu b/src/04kernel/src/utilities/cuda/cublaslt_context.cu deleted file mode 100644 index 2fc8fb18..00000000 --- a/src/04kernel/src/utilities/cuda/cublaslt_context.cu +++ /dev/null @@ -1,33 +0,0 @@ -#include "common.h" -#include "cublaslt_context.hh" - -namespace refactor::kernel::cublas { - - CublasLtContext::CublasLtContext() : runtime::Resource() { - if (cublasLtCreate(&handle) != CUBLAS_STATUS_SUCCESS) { - RUNTIME_ERROR("Failed to create cublasLt handle"); - } - } - CublasLtContext::~CublasLtContext() { - if (cublasLtDestroy(handle) != CUBLAS_STATUS_SUCCESS) { - fmt::println("Failed to destroy cublasLt handle"); - abort(); - } - } - - auto CublasLtContext::typeId() noexcept -> size_t { - static uint8_t ID = 1; - return reinterpret_cast(&ID); - } - auto CublasLtContext::build() noexcept -> runtime::ResourceBox { - return std::make_unique(); - } - - auto CublasLtContext::resourceTypeId() const noexcept -> size_t { - return typeId(); - } - auto CublasLtContext::description() const noexcept -> std::string_view { 
- return "CublasLtContext"; - } - -}// namespace refactor::kernel::cublas diff --git a/src/04kernel/src/utilities/cuda/cublaslt_context.hh b/src/04kernel/src/utilities/cuda/cublaslt_context.hh deleted file mode 100644 index 84e1d2d9..00000000 --- a/src/04kernel/src/utilities/cuda/cublaslt_context.hh +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef KERNEL_CUBLASLT_CONTEXT_HH -#define KERNEL_CUBLASLT_CONTEXT_HH - -#include "runtime/resource.h" -#include - -#define CUBLAS_ASSERT(STATUS) \ - if (auto status = (STATUS); status != CUBLAS_STATUS_SUCCESS) { \ - fmt::println("cublas failed on \"" #STATUS "\" with {}", \ - (int) status); \ - abort(); \ - } - -namespace refactor::kernel::cublas { - - struct CublasLtContext final : public runtime::Resource { - cublasLtHandle_t handle; - - CublasLtContext(); - ~CublasLtContext(); - CublasLtContext(CublasLtContext const &) noexcept = delete; - CublasLtContext(CublasLtContext &&) noexcept = delete; - - static size_t typeId() noexcept; - static runtime::ResourceBox build() noexcept; - - size_t resourceTypeId() const noexcept final; - std::string_view description() const noexcept final; - }; - -}// namespace refactor::kernel::cublas - -#endif// KERNEL_CUBLASLT_CONTEXT_HH diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu new file mode 100644 index 00000000..d07af6ab --- /dev/null +++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu @@ -0,0 +1,145 @@ +#include "cublaslt_utils.cuh" +#include "hardware/devices/nvidia.h" + +namespace refactor::kernel::cublas { + + CublasLtContext::CublasLtContext() : runtime::Resource() { + if (cublasLtCreate(&handle) != CUBLAS_STATUS_SUCCESS) { + RUNTIME_ERROR("Failed to create cublasLt handle"); + } + } + CublasLtContext::~CublasLtContext() { + if (cublasLtDestroy(handle) != CUBLAS_STATUS_SUCCESS) { + fmt::println("Failed to destroy cublasLt handle"); + abort(); + } + } + + auto CublasLtContext::typeId() noexcept -> size_t { + static uint8_t ID = 1; + return reinterpret_cast(&ID); + } + auto CublasLtContext::build() noexcept -> runtime::ResourceBox { + return std::make_unique(); + } + + auto CublasLtContext::resourceTypeId() const noexcept -> size_t { + return typeId(); + } + auto CublasLtContext::description() const noexcept -> std::string_view { + return "CublasLtContext"; + } + + cudaDataType dataTypeConvert(DataType dt) { + switch (dt) { + case DataType::F32: + return CUDA_R_32F; + default: + TODO(""); + } + } + + MatMulDescriptor::MatMulDescriptor(cublasComputeType_t compute, cudaDataType data) + : _internal(nullptr) { + CUBLASLT_ASSERT(cublasLtMatmulDescCreate(&_internal, compute, data)); + } + MatMulDescriptor::~MatMulDescriptor() { + CUBLASLT_ASSERT(cublasLtMatmulDescDestroy(_internal)); + } + cublasLtMatmulDesc_t MatMulDescriptor::get() const noexcept { + return _internal; + } + + MatrixDescriptor::MatrixDescriptor(MatrixLayout layout) + : _internal(nullptr) { + CUBLASLT_ASSERT(cublasLtMatrixLayoutCreate( + &_internal, + layout.dataType, + layout.rows, + layout.cols, + layout.majorStride)); + CUBLASLT_ASSERT(cublasLtMatrixLayoutSetAttribute( + _internal, + CUBLASLT_MATRIX_LAYOUT_ORDER, + &layout.order, + sizeof(layout.order))); + CUBLASLT_ASSERT(cublasLtMatrixLayoutSetAttribute( + _internal, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &layout.batchCount, + sizeof(layout.batchCount))); + CUBLASLT_ASSERT(cublasLtMatrixLayoutSetAttribute( + _internal, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &layout.batchStride, + sizeof(layout.batchStride))); + } + 
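+    // A usage sketch for these RAII wrappers (shapes and names below are
+    // illustrative, not taken from the kernels in this patch):
+    //
+    //     MatMulDescriptor mul(CUBLAS_COMPUTE_32F, CUDA_R_32F);
+    //     MatrixDescriptor a(MatrixLayout{CUDA_R_32F, m, k, k, CUBLASLT_ORDER_ROW, 1, 0});
+    //     ... // build b and c the same way
+    //     auto [algo, workspaceSize] = tune(handle, mul, a, b, c);
+    //     // then pass &algo and a workspace of workspaceSize to cublasLtMatmul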
MatrixDescriptor::~MatrixDescriptor() { + CUBLASLT_ASSERT(cublasLtMatrixLayoutDestroy(_internal)); + } + cublasLtMatrixLayout_t MatrixDescriptor::get() const noexcept { + return _internal; + } + + std::pair + tune(cublasLtHandle_t handle, + MatMulDescriptor const &matmul, + MatrixDescriptor const &a, + MatrixDescriptor const &b, + MatrixDescriptor const &c) { + + int device; + CUDA_ASSERT(cudaGetDevice(&device)); + cudaDeviceProp prop; + CUDA_ASSERT(cudaGetDeviceProperties(&prop, device)); + + auto workspace = std::numeric_limits::max(); + auto alignment = prop.textureAlignment; + + cublasLtMatmulPreference_t preference; + CUBLASLT_ASSERT(cublasLtMatmulPreferenceCreate(&preference)); + CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace, + sizeof(workspace))); + CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, + &alignment, + sizeof(alignment))); + CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, + &alignment, + sizeof(alignment))); + CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, + &alignment, + sizeof(alignment))); + CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute( + preference, + CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, + &alignment, + sizeof(alignment))); + + cublasLtMatmulHeuristicResult_t result; + int ansN; + CUBLASLT_ASSERT(cublasLtMatmulAlgoGetHeuristic( + handle, + matmul.get(), + a.get(), + b.get(), + c.get(), + c.get(), + preference, + 1, + &result, + &ansN)); + ASSERT(ansN == 1, ""); + + return {result.algo, result.workspaceSize}; + } + +}// namespace refactor::kernel::cublas diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh new file mode 100644 index 00000000..5dd23607 --- /dev/null +++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh @@ -0,0 +1,74 @@ +#ifndef KERNEL_CUBLASLT_UTILS_CUH +#define KERNEL_CUBLASLT_UTILS_CUH + +#include "common.h" +#include "runtime/resource.h" +#include + +#define CUBLASLT_ASSERT(STATUS) \ + if (auto status = (STATUS); status != CUBLAS_STATUS_SUCCESS) { \ + fmt::println("cublasLt failed on \"" #STATUS "\" with {}", \ + (int) status); \ + abort(); \ + } + +namespace refactor::kernel::cublas { + + struct CublasLtContext final : public runtime::Resource { + cublasLtHandle_t handle; + + CublasLtContext(); + ~CublasLtContext(); + CublasLtContext(CublasLtContext const &) noexcept = delete; + CublasLtContext(CublasLtContext &&) noexcept = delete; + + static size_t typeId() noexcept; + static runtime::ResourceBox build() noexcept; + + size_t resourceTypeId() const noexcept final; + std::string_view description() const noexcept final; + }; + + cudaDataType dataTypeConvert(DataType); + + class MatMulDescriptor { + cublasLtMatmulDesc_t _internal; + + public: + MatMulDescriptor(cublasComputeType_t, cudaDataType); + ~MatMulDescriptor(); + MatMulDescriptor(MatMulDescriptor const &) noexcept = delete; + MatMulDescriptor(MatMulDescriptor &&) noexcept = delete; + cublasLtMatmulDesc_t get() const noexcept; + }; + + struct MatrixLayout { + cudaDataType dataType; + uint64_t rows, cols; + int64_t majorStride; + cublasLtOrder_t order; + int32_t batchCount; + int64_t batchStride; + }; + + class MatrixDescriptor { + cublasLtMatrixLayout_t _internal; + + public: + MatrixDescriptor(MatrixLayout layout); + ~MatrixDescriptor(); + 
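+        // copy and move are deleted below: this wrapper owns the cublasLt
+        // layout handle and destroys it in ~MatrixDescriptor, so a copy
+        // would lead to a double free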
MatrixDescriptor(MatrixDescriptor const &) noexcept = delete; + MatrixDescriptor(MatrixDescriptor &&) noexcept = delete; + cublasLtMatrixLayout_t get() const noexcept; + }; + + std::pair + tune(cublasLtHandle_t, + MatMulDescriptor const &, + MatrixDescriptor const &, + MatrixDescriptor const &, + MatrixDescriptor const &); + +}// namespace refactor::kernel::cublas + +#endif// KERNEL_CUBLASLT_UTILS_CUH From 6d4465c62567c7ec0e1e2757f87b24ed6d00404f Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 2 Feb 2024 12:53:06 +0800 Subject: [PATCH 02/17] build(kernel): try to use local cccl instead of system Signed-off-by: YdrMaster --- 3rd-party/cccl | 2 +- CMakeLists.txt | 16 +++++++++---- cmake/CPM.cmake | 24 +++++++++++++++++++ src/04kernel/CMakeLists.txt | 3 ++- src/04kernel/cuda/CMakeLists.txt | 2 +- .../cuda/include/kernel/cuda/reduce.cuh | 9 +++++++ .../src/kernels/attention/cuda_kernel.cu | 1 + 7 files changed, 50 insertions(+), 7 deletions(-) create mode 100644 cmake/CPM.cmake create mode 100644 src/04kernel/cuda/include/kernel/cuda/reduce.cuh diff --git a/3rd-party/cccl b/3rd-party/cccl index b7d4228a..d27b5896 160000 --- a/3rd-party/cccl +++ b/3rd-party/cccl @@ -1 +1 @@ -Subproject commit b7d4228ab7268ed928984cd61096079bd671d25d +Subproject commit d27b58963128f17a6c2f3f867301d54e9f4b48cd diff --git a/CMakeLists.txt b/CMakeLists.txt index 49ddcda6..2d76dd81 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,12 +12,20 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_POSITION_INDEPENDENT_CODE ON) +# Download with: +# +# mkdir -p cmake +# wget -O cmake/CPM.cmake https://github.com/cpm-cmake/CPM.cmake/releases/latest/download/get_cpm.cmake +include(cmake/CPM.cmake) + if(USE_CUDA) + CPMAddPackage(NAME CCCL SOURCE_DIR ${CMAKE_SOURCE_DIR}/3rd-party/cccl) + add_compile_definitions(USE_CUDA) enable_language(CUDA) set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER}) if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - set(CMAKE_CUDA_ARCHITECTURES 80) + set(CMAKE_CUDA_ARCHITECTURES native) endif() if(NOT DEFINED CMAKE_CUDA_STANDARD) set(CMAKE_CUDA_STANDARD 17) @@ -45,7 +53,7 @@ endif() if (USE_BANG) add_compile_definitions(USE_BANG) include_directories(src/kernels/mlu/include) - + # Neuware Evironment if ((NOT DEFINED NEUWARE_HOME) AND (NOT DEFINED ENV{NEUWARE_HOME})) message(FATAL_ERROR "NEUWARE_HOME is not defined from cmake or env") @@ -55,14 +63,14 @@ if (USE_BANG) set(NEUWARE_HOME $ENV{NEUWARE_HOME} CACHE STRING "NEUWARE_HOME directory for Cambricon Neuware development") endif() message(STATUS "NEUWARE_HOME: ${NEUWARE_HOME}") - + # cnrt cndrv cnnl include_directories("${NEUWARE_HOME}/include") find_library(CAMBRICON_CNNL libcnnl.so "${NEUWARE_HOME}/lib64") find_library(CAMBRICON_CNRT libcnrt.so "${NEUWARE_HOME}/lib64") find_library(CAMBRICON_CNDRV libcndrv.so "${NEUWARE_HOME}/lib64") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall") - + if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH})) execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE) set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH") diff --git a/cmake/CPM.cmake b/cmake/CPM.cmake new file mode 100644 index 00000000..cc25ec28 --- /dev/null +++ b/cmake/CPM.cmake @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: MIT +# +# SPDX-FileCopyrightText: Copyright (c) 2019-2023 Lars Melchior and contributors + +set(CPM_DOWNLOAD_VERSION 0.38.7) +set(CPM_HASH_SUM "83e5eb71b2bbb8b1f2ad38f1950287a057624e385c238f6087f94cdfc44af9c5") + 
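+# With the bootstrap above in place, a dependency can come from the network
+# or, as this project's top-level CMakeLists does for cccl, from a local
+# checkout:
+#
+#   CPMAddPackage(NAME CCCL SOURCE_DIR ${CMAKE_SOURCE_DIR}/3rd-party/cccl)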
+if(CPM_SOURCE_CACHE) + set(CPM_DOWNLOAD_LOCATION "${CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake") +elseif(DEFINED ENV{CPM_SOURCE_CACHE}) + set(CPM_DOWNLOAD_LOCATION "$ENV{CPM_SOURCE_CACHE}/cpm/CPM_${CPM_DOWNLOAD_VERSION}.cmake") +else() + set(CPM_DOWNLOAD_LOCATION "${CMAKE_BINARY_DIR}/cmake/CPM_${CPM_DOWNLOAD_VERSION}.cmake") +endif() + +# Expand relative path. This is important if the provided path contains a tilde (~) +get_filename_component(CPM_DOWNLOAD_LOCATION ${CPM_DOWNLOAD_LOCATION} ABSOLUTE) + +file(DOWNLOAD + https://github.com/cpm-cmake/CPM.cmake/releases/download/v${CPM_DOWNLOAD_VERSION}/CPM.cmake + ${CPM_DOWNLOAD_LOCATION} EXPECTED_HASH SHA256=${CPM_HASH_SUM} +) + +include(${CPM_DOWNLOAD_LOCATION}) diff --git a/src/04kernel/CMakeLists.txt b/src/04kernel/CMakeLists.txt index 77b655c0..efdeb0da 100644 --- a/src/04kernel/CMakeLists.txt +++ b/src/04kernel/CMakeLists.txt @@ -26,7 +26,8 @@ if(USE_CUDA) # nvrtc for cuda kernel compile # cublas for matmul # cudnn for conv and others - target_link_libraries(kernel PUBLIC cuda nvrtc cublas cublasLt cudnn kernel_cuda) + target_link_libraries(kernel PUBLIC cuda kernel_cuda) + target_link_libraries(kernel PRIVATE nvrtc cublas cublasLt cudnn) target_include_directories(kernel PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) find_package(NCCL REQUIRED) diff --git a/src/04kernel/cuda/CMakeLists.txt b/src/04kernel/cuda/CMakeLists.txt index 4c976e33..07223090 100644 --- a/src/04kernel/cuda/CMakeLists.txt +++ b/src/04kernel/cuda/CMakeLists.txt @@ -4,7 +4,7 @@ project(kernel_cuda) file(GLOB_RECURSE KERNEL_CUDA_SUB_SRC src/*.cu) add_library(kernel_cuda STATIC ${KERNEL_CUDA_SUB_SRC}) -target_link_libraries(kernel_cuda PUBLIC common) +target_link_libraries(kernel_cuda PUBLIC common CCCL::CCCL) target_include_directories(kernel_cuda PUBLIC include) file(GLOB_RECURSE KERNEL_CUDA_TEST test/*.cu) diff --git a/src/04kernel/cuda/include/kernel/cuda/reduce.cuh b/src/04kernel/cuda/include/kernel/cuda/reduce.cuh new file mode 100644 index 00000000..42739534 --- /dev/null +++ b/src/04kernel/cuda/include/kernel/cuda/reduce.cuh @@ -0,0 +1,9 @@ +#ifndef KERNEL_CUDA_REDUCE_CUH +#define KERNEL_CUDA_REDUCE_CUH + +#include + +namespace refactor::kernel::cuda { +} + +#endif// KERNEL_CUDA_REDUCE_CUH diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu index a0f3f56a..79aa6f2b 100644 --- a/src/04kernel/src/kernels/attention/cuda_kernel.cu +++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu @@ -1,6 +1,7 @@ #include "../../utilities/cuda/cublaslt_utils.cuh" #include "cuda_kernel.hh" #include "hardware/functions.h" +#include "kernel/cuda/reduce.cuh" namespace refactor::kernel { using K = AttentionCuda; From e54c7db38cecea327bb75d249ed223f2b332279c Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 2 Feb 2024 12:53:06 +0800 Subject: [PATCH 03/17] =?UTF-8?q?refactor(kernel):=20=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E4=B8=80=E7=A7=8D=E4=B8=8D=E4=BE=9D=E8=B5=96=E6=A8=A1=E6=9D=BF?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E7=9A=84=20BlockReduce=20=E5=B9=B6=E7=94=A8?= =?UTF-8?q?=E4=BA=8E=20softmax?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../cuda/include/kernel/cuda/reduce.cuh | 24 ++++++++++++++++++- .../src/kernels/softmax/cuda_kernel.cu | 12 ++++------ .../test/kernels/softmax/test_cuda.cpp | 19 ++++++++++----- 3 files changed, 41 insertions(+), 14 deletions(-) diff --git 
a/src/04kernel/cuda/include/kernel/cuda/reduce.cuh b/src/04kernel/cuda/include/kernel/cuda/reduce.cuh index 42739534..6a5be4a3 100644 --- a/src/04kernel/cuda/include/kernel/cuda/reduce.cuh +++ b/src/04kernel/cuda/include/kernel/cuda/reduce.cuh @@ -4,6 +4,28 @@ #include namespace refactor::kernel::cuda { -} + + template + __inline__ __device__ T blockReduce(T x, T init, ReductionOp op) { + using WarpReduce = cub::WarpReduce; + __shared__ typename WarpReduce::TempStorage tempStorage; + __shared__ T shared[32], ans; + + auto reduce = WarpReduce(tempStorage); + int lane = threadIdx.x % 32; + int wid = threadIdx.x / 32; + x = reduce.Reduce(x, op); + if (lane == 0) { shared[wid] = x; } + __syncthreads(); + if (wid == 0) { + x = (threadIdx.x < blockDim.x / 32) ? shared[lane] : init; + shared[lane] = reduce.Reduce(x, op); + if (lane == 0) { ans = shared[0]; } + } + __syncthreads(); + return ans;// avoid RAW hazard + } + +}// namespace refactor::kernel::cuda #endif// KERNEL_CUDA_REDUCE_CUH diff --git a/src/04kernel/src/kernels/softmax/cuda_kernel.cu b/src/04kernel/src/kernels/softmax/cuda_kernel.cu index 114cb453..c5ecb821 100644 --- a/src/04kernel/src/kernels/softmax/cuda_kernel.cu +++ b/src/04kernel/src/kernels/softmax/cuda_kernel.cu @@ -1,5 +1,5 @@ #include "cuda_kernel.hh" -#include +#include "kernel/cuda/reduce.cuh" namespace refactor::kernel { using namespace runtime; @@ -18,8 +18,8 @@ namespace refactor::kernel { template<> __device__ __forceinline__ nv_bfloat16 reciprocal(nv_bfloat16 x) { return hrcp(x); } // blockDim.x === BLOCK_DIM - template - __launch_bounds__(BLOCK_DIM) __global__ void blockSoftmaxKernel( + template + __global__ void blockSoftmaxKernel( T const *__restrict x, T *__restrict y, int mid, @@ -40,10 +40,8 @@ namespace refactor::kernel { for (int i = threadIdx.x + blockDim.x; i < mid; i += blockDim.x) { maxSumThread = MaxSum::reduce(maxSumThread, {x[id + i * stride], 1});// reduce the data to one block } - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tempStorage; __shared__ MaxSum maxSumTotal; - auto maxSumBlock = BlockReduce(tempStorage).Reduce(maxSumThread, MaxSum::reduce); + auto maxSumBlock = cuda::blockReduce(maxSumThread, {-__FLT_MAX__, 0}, MaxSum::reduce); if (threadIdx.x == 0) { maxSumTotal = maxSumBlock;// must set threadIdx.x = 0 write the output to memory } @@ -113,7 +111,7 @@ namespace refactor::kernel { auto y = reinterpret_cast(outputs[0]); int numBlocks = info.pre * info.post; if (info.mid > 1024) { - blockSoftmaxKernel<1024><<>>(x, y, info.mid, info.post); + blockSoftmaxKernel<<>>(x, y, info.mid, info.post); } else { int blockDimX, mid = static_cast(info.mid); for (blockDimX = 32; blockDimX > 4 && mid < blockDimX; blockDimX /= 2) {} diff --git a/src/04kernel/test/kernels/softmax/test_cuda.cpp b/src/04kernel/test/kernels/softmax/test_cuda.cpp index 4290e852..ce3cc1ad 100644 --- a/src/04kernel/test/kernels/softmax/test_cuda.cpp +++ b/src/04kernel/test/kernels/softmax/test_cuda.cpp @@ -4,18 +4,19 @@ #include "../../../src/kernels/softmax/cuda_kernel.hh" #include "hardware/device_manager.h" #include +#include using namespace refactor; using namespace kernel; using namespace hardware; -TEST(kernel, SoftmaxCuda) { +static void test(Shape shape, int axis) { // build routine - auto xTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4}); - auto outTensor = Tensor::share(DataType::F32, Shape{2, 3, 2, 5, 4}); - dim_t axis = 1; - auto kCpu = SoftmaxCpu::build(SoftmaxInfo(*xTensor, axis)); - auto kCuda = 
SoftmaxCuda::build(SoftmaxInfo(*xTensor, axis)); + auto xTensor = Tensor::share(DataType::F32, shape); + auto outTensor = Tensor::share(DataType::F32, shape); + SoftmaxInfo info(*xTensor, axis); + auto kCpu = SoftmaxCpu::build(info); + auto kCuda = SoftmaxCuda::build(info); ASSERT_TRUE(kCpu && kCuda); auto res = runtime::Resources(); auto rCpu = kCpu->lower(res).routine; @@ -28,6 +29,7 @@ TEST(kernel, SoftmaxCuda) { std::vector data(xTensor->elementsSize(), 0), cpuOut(outTensor->elementsSize()); + std::iota(data.begin(), data.end(), 0); gpuIn->copyFromHost(data.data(), xTensor->bytesSize()); // inference { @@ -49,4 +51,9 @@ TEST(kernel, SoftmaxCuda) { } } +TEST(kernel, SoftmaxCuda) { + test({2, 3, 2, 5, 4}, 1); + test({2, 2048, 2, 5, 4}, 1); +} + #endif From 8e1564ae2ba46a9a056cbc294532f411215f82d1 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 2 Feb 2024 12:53:06 +0800 Subject: [PATCH 04/17] =?UTF-8?q?feat(kernel):=20=E7=A1=AE=E8=AE=A4?= =?UTF-8?q?=E4=BE=9B=20attention=20=E8=B0=83=E7=94=A8=E7=9A=84=20softmax?= =?UTF-8?q?=20=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../src/kernels/attention/cuda_kernel.cu | 48 +++++++++++++++++-- .../src/utilities/cuda/cublaslt_utils.cu | 15 ++++++ .../src/utilities/cuda/cublaslt_utils.cuh | 1 + 3 files changed, 59 insertions(+), 5 deletions(-) diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu index 79aa6f2b..7d018607 100644 --- a/src/04kernel/src/kernels/attention/cuda_kernel.cu +++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu @@ -7,6 +7,44 @@ namespace refactor::kernel { using K = AttentionCuda; using namespace cublas; + static __forceinline__ __device__ bool mask(int tokid, int posid) { + return true; + } + + // gridDim.x = batch * nHead + // gridDim.y = seqLen + template + static __global__ void softmax( + T *__restrict__ attention, + Mask mask, + uint32_t seqLen, + uint32_t bufLen) { + // int offset = (blockIdx.x * len_q + blockIdx.y) * len_buf; + // SharedMemory shared; + // float *smem = shared.getPointer(); + + // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) { + // T pb = (position_bias == nullptr) ? T(0.) : position_bias[offset + i]; + // smem[i] = mask[blockIdx.y * len_buf + i] > 0 ? 
x[offset + i] * scale + pb : -Inf(); + // } + // float local_max = -1e20; + // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) { + // local_max = fmaxf(local_max, smem[i]); + // } + // local_max = functions::blockReduceMax(local_max); + + // float local_sum = 1e-20; + // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) { + // float v = expf(float(smem[i]) - local_max); + // smem[i] = v; + // local_sum += v; + // } + // local_sum = functions::blockReduceSum(local_sum); + // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) { + // x[offset + i] = float(smem[i]) / local_sum; + // } + } + RoutineWorkspace K::lower(Resources &res) const { auto handle = res.fetchOrStore()->handle; @@ -23,9 +61,9 @@ namespace refactor::kernel { size_t attSize, workspaceSizeQK, workspaceSizeAV; Descriptors(CublasLtContext const &context, - cublasComputeType_t compute, AttentionInfo info) - : mul(compute, CUDA_R_32F), + : mul(computeTypeConvert(info.dataType), + dataTypeConvert(info.dataType)), q(MatrixLayout{ .dataType = dataTypeConvert(info.dataType), .rows = static_cast(info.seqLen), @@ -73,11 +111,10 @@ namespace refactor::kernel { }; auto const &context = *res.fetchOrStore(); - auto d = std::make_shared(context, CUBLAS_COMPUTE_32F, info); + auto d = std::make_shared(context, info); auto workspaceSize = d->attSize; workspaceSize = hardware::alignBytes(workspaceSize, 256); workspaceSize += d->workspaceSizeQK; - workspaceSize = hardware::alignBytes(workspaceSize, 256); workspaceSize += d->workspaceSizeAV; workspaceSize = hardware::alignBytes(workspaceSize, 256); @@ -105,7 +142,8 @@ namespace refactor::kernel { workspaceQK, d->workspaceSizeQK, cudaStreamLegacy); - // TODO inline mask && softmax + softmax<<>>( + att, mask, info.seqLen, info.seqLen); cublasLtMatmul( handle, d->mul.get(), diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu index d07af6ab..7c34ad4c 100644 --- a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu +++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu @@ -34,6 +34,21 @@ namespace refactor::kernel::cublas { switch (dt) { case DataType::F32: return CUDA_R_32F; + case DataType::FP16: + return CUDA_R_16F; + case DataType::BF16: + return CUDA_R_16BF; + default: + TODO(""); + } + } + cublasComputeType_t computeTypeConvert(DataType dt) { + switch (dt) { + case DataType::F32: + case DataType::BF16: + return CUBLAS_COMPUTE_32F; + case DataType::FP16: + return CUBLAS_COMPUTE_16F; default: TODO(""); } diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh index 5dd23607..ccaad7ec 100644 --- a/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh +++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh @@ -30,6 +30,7 @@ namespace refactor::kernel::cublas { }; cudaDataType dataTypeConvert(DataType); + cublasComputeType_t computeTypeConvert(DataType); class MatMulDescriptor { cublasLtMatmulDesc_t _internal; From ccf293fa1396cf7b6c44074d17e40c1de683c48c Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 2 Feb 2024 12:53:06 +0800 Subject: [PATCH 05/17] =?UTF-8?q?feat(kernel):=20=E5=BC=80=E5=A7=8B?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=20softmax?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../src/kernels/attention/cuda_kernel.cu | 43 ++++++++++++------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu 
b/src/04kernel/src/kernels/attention/cuda_kernel.cu
index 7d018607..b976ac2c 100644
--- a/src/04kernel/src/kernels/attention/cuda_kernel.cu
+++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -7,26 +7,39 @@ namespace refactor::kernel {
     using K = AttentionCuda;
     using namespace cublas;
 
-    static __forceinline__ __device__ bool mask(int tokid, int posid) {
-        return true;
+    // Causal attention mask.
+    // tokenId: index of the token in this pass
+    // seqLen: number of tokens processed in this pass
+    // posId: position in the kv cache
+    // attLen = pastSeqLen + seqLen
+    static __forceinline__ __device__ bool
+    causualMask(int tokenId, int seqLen,
+                int posId, int attLen) {
+        // tokenId ↓ |<---attLen---->|
+        //         0 | * * ... *     |
+        //         1 | * * ... * *   |
+        //         2 | * * ... * * * |
+        // seqLen: 3 |---------------|
+        return attLen + tokenId >= posId + seqLen;
     }
 
     // gridDim.x = batch * nHead
     // gridDim.y = seqLen
-    template<class T, class Mask>
+    // blockDim.x = min(1024, attLen)
+    template<class T>
     static __global__ void softmax(
-        T *__restrict__ attention,
-        Mask mask,
-        uint32_t seqLen,
+        T *__restrict__ att,
+        bool (*mask)(int, int, int, int),
+        uint32_t attLen,
         uint32_t bufLen) {
-        // int offset = (blockIdx.x * len_q + blockIdx.y) * len_buf;
-        // SharedMemory shared;
-        // float *smem = shared.getPointer();
+        // locate the attention region of this thread block
+        att += (blockIdx.x * gridDim.x + gridDim.y) * bufLen;
+        // load the input into shared memory and apply cast + mask
+        extern __shared__ float shared[];// size = attLen = pastSeqLen + seqLen
+        for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
+            shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? 
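+            // masked-out positions are loaded as -FLT_MAX so they underflow
+            // to zero after exp(); e.g. with seqLen = 3 and no kv cache
+            // (attLen = 3), token 0 keeps only position 0, token 1 keeps 0-1,
+            // token 2 keeps 0-2, matching the diagram in causualMask above.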
float(att[i]) : -__FLT_MAX__; } - // float local_max = -1e20; - // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) { - // local_max = fmaxf(local_max, smem[i]); - // } - // local_max = functions::blockReduceMax(local_max); + float localMax = -1e20; + for (auto i = threadIdx.x; i < attLen; i += blockDim.x) { + localMax = cub::Max()(localMax, shared[i]); + } + localMax = cuda::blockReduce(localMax, -1e20f, cub::Max()); - // float local_sum = 1e-20; - // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) { - // float v = expf(float(smem[i]) - local_max); - // smem[i] = v; - // local_sum += v; - // } - // local_sum = functions::blockReduceSum(local_sum); - // for (int i = threadIdx.x; i < len_buf; i += blockDim.x) { - // x[offset + i] = float(smem[i]) / local_sum; - // } + float localSum = 1e-20; + for (auto i = threadIdx.x; i < attLen; i += blockDim.x) { + localSum += shared[i] = expf(shared[i] - localMax); + } + localSum = cuda::blockReduce(localSum, 1e-20f, cub::Sum()); + auto reciprocal = fdividef(1, localSum); + for (auto i = threadIdx.x; i < attLen; i += blockDim.x) { + att[i] = shared[i] * reciprocal; + } } RoutineWorkspace K::lower(Resources &res) const { @@ -141,35 +140,38 @@ namespace refactor::kernel { auto att = reinterpret_cast(workspace); auto workspaceQK = reinterpret_cast(workspace) + hardware::alignBytes(d->attSize, 256); auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256); - - float alpha = 1, beta = 0; - cublasLtMatmul( - handle, d->mul.get(), - &alpha, - q, d->q.get(), - k, d->k.get(), - &beta, - att, d->att.get(), - att, d->att.get(), - &d->algoQK, - workspaceQK, d->workspaceSizeQK, - cudaStreamLegacy); - + { + half alpha = rsqrtf(info.headDim), beta = 0; + cublasLtMatmul( + handle, d->mul.get(), + &alpha, + q, d->q.get(), + k, d->k.get(), + &beta, + att, d->att.get(), + att, d->att.get(), + &d->algoQK, + workspaceQK, d->workspaceSizeQK, + cudaStreamLegacy); + } softmax<<>>( att, causualMask, info.seqLen, info.seqLen); - - cublasLtMatmul( - handle, d->mul.get(), - &alpha, - att, d->att.get(), - v, d->v.get(), - &beta, - o, d->q.get(), - o, d->q.get(), - &d->algoAV, - workspaceAV, d->workspaceSizeAV, - cudaStreamLegacy); + { + half alpha = 1, beta = 0; + cublasLtMatmul( + handle, d->mul.get(), + &alpha, + att, d->att.get(), + v, d->v.get(), + &beta, + o, d->q.get(), + o, d->q.get(), + &d->algoAV, + workspaceAV, d->workspaceSizeAV, + cudaStreamLegacy); + }; }; + return {std::move(routine), workspaceSize}; } } From 728645916215235eb96cce68a3b36e3c2c514546 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 2 Feb 2024 12:53:06 +0800 Subject: [PATCH 07/17] =?UTF-8?q?fix:=20=E6=95=B4=E7=90=86=E5=92=8C?= =?UTF-8?q?=E6=94=B9=E6=AD=A3=20attention=20=E6=9E=84=E9=80=A0=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../include/computation/operators/attention.h | 6 +++--- src/05computation/src/operators/attention.cc | 8 ++++++++ src/08-01llm/src/operators/attention.cc | 4 ++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/05computation/include/computation/operators/attention.h b/src/05computation/include/computation/operators/attention.h index d5f37997..753df946 100644 --- a/src/05computation/include/computation/operators/attention.h +++ b/src/05computation/include/computation/operators/attention.h @@ -6,14 +6,14 @@ namespace refactor::computation { struct Attention final : public Operator { - dim_t maxSeqLen; - constexpr 
Attention(decltype(maxSeqLen) maxSeqLen_) noexcept - : Operator(), maxSeqLen(maxSeqLen_) {} + constexpr Attention() noexcept = default; static size_t typeId() noexcept; size_t opTypeId() const noexcept final; std::string_view name() const noexcept final; + kernel::CollectorBox candidateKernels(Target) const final; + std::string serialize() const noexcept final; }; }// namespace refactor::computation diff --git a/src/05computation/src/operators/attention.cc b/src/05computation/src/operators/attention.cc index 4624482a..b5788639 100644 --- a/src/05computation/src/operators/attention.cc +++ b/src/05computation/src/operators/attention.cc @@ -1,4 +1,5 @@ #include "computation/operators/attention.h" +#include "kernel/collectors/attention.h" namespace refactor::computation { using Op = Attention; @@ -9,5 +10,12 @@ namespace refactor::computation { } auto Op::opTypeId() const noexcept -> size_t { return typeId(); } auto Op::name() const noexcept -> std::string_view { return "Attention"; } + auto Op::candidateKernels(Target target) const -> kernel::CollectorBox { + using Collector_ = kernel::AttentionCollector; + return std::make_unique(target); + } + auto Op::serialize() const noexcept -> std::string { + return "Attention()"; + } }// namespace refactor::computation diff --git a/src/08-01llm/src/operators/attention.cc b/src/08-01llm/src/operators/attention.cc index 15479c6a..8993f1f3 100644 --- a/src/08-01llm/src/operators/attention.cc +++ b/src/08-01llm/src/operators/attention.cc @@ -9,7 +9,7 @@ namespace refactor::llm { : Operator(), maxSeqLen(maxSeqLen_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { - auto maxSeqLen = attributes.getOrInsert("max_seq_len", {0}).float_(); + auto maxSeqLen = attributes.getOrInsert("max_seq_len", {0}).int_(); return OpBox(std::make_unique(maxSeqLen)); } auto Op::typeId() -> size_t { @@ -129,7 +129,7 @@ namespace refactor::llm { auto Op::lower(TensorRefs) const -> computation::OpBox { using Op_ = computation::Attention; - return std::make_unique(maxSeqLen); + return std::make_unique(); } }// namespace refactor::llm From d7bbd3b051f435f27abde39a09c7c7dd57fd4995 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 2 Feb 2024 12:53:06 +0800 Subject: [PATCH 08/17] =?UTF-8?q?fix(kernel):=20=E8=AE=BE=E7=BD=AE=20share?= =?UTF-8?q?dMemory?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- src/04kernel/src/kernels/attention/cuda_kernel.cu | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu index ba768868..6f5b7d13 100644 --- a/src/04kernel/src/kernels/attention/cuda_kernel.cu +++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu @@ -26,6 +26,7 @@ namespace refactor::kernel { // gridDim.x = batch * nHead // gridDim.y = seqLen // blockDim.x = min(1024, attLen) + // sizeof(shared) = attLen * sizeof(float) template static __global__ void softmax( T *__restrict__ att, @@ -154,7 +155,9 @@ namespace refactor::kernel { workspaceQK, d->workspaceSizeQK, cudaStreamLegacy); } - softmax<<>>( + softmax<<>>( att, causualMask, info.seqLen, info.seqLen); { half alpha = 1, beta = 0; From 9f736e9a1d394bd7de401e2a8778f9cb6db66c4b Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 2 Feb 2024 12:53:06 +0800 Subject: [PATCH 09/17] =?UTF-8?q?test(llm):=20=E6=B7=BB=E5=8A=A0=20attenti?= =?UTF-8?q?on=20=E5=89=8D=E7=AB=AF=E5=8D=95=E6=B5=8B?= MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- src/08-01llm/test/test_attention.cpp | 44 ++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 src/08-01llm/test/test_attention.cpp diff --git a/src/08-01llm/test/test_attention.cpp b/src/08-01llm/test/test_attention.cpp new file mode 100644 index 00000000..fbe7d2e7 --- /dev/null +++ b/src/08-01llm/test/test_attention.cpp @@ -0,0 +1,44 @@ +#include "../src/operators/attention.hh" +#include "llm/operators.h" +#include + +using namespace refactor; +using namespace llm; + +TEST(infer, AttentionNoKvCache) { + llm::register_(); + auto batch = DimExpr("N"); + auto numHead = DimExpr(16); + auto seqLen = DimExpr(31); + auto headDim = DimExpr(64); + { + auto edges = Edges{ + {Tensor::share(DataType::FP16, Shape{batch, numHead, seqLen, headDim}, {}), ""}, + {Tensor::share(DataType::FP16, Shape{batch, numHead, seqLen, headDim}, {}), ""}, + {Tensor::share(DataType::FP16, Shape{batch, numHead, seqLen, headDim}, {}), ""}, + }; + count_t inputs[]{0, 1, 2}; + auto infered = Attention(0).infer(TensorRefs(edges, inputs), {true}); + ASSERT_TRUE(infered.isOk()); + auto outputs = std::move(infered.unwrap()); + ASSERT_EQ(outputs.size(), 1); + auto y = std::move(outputs[0]); + ASSERT_EQ(y->dataType, DataType::FP16); + ASSERT_EQ(y->shape, edges[0].tensor->shape); + } + { + auto edges = Edges{ + {Tensor::share(DataType::FP16, Shape{batch, numHead, seqLen, headDim}, {}), ""}, + {Tensor::share(DataType::FP16, Shape{batch, DimExpr(4), seqLen, headDim}, {}), ""}, + {Tensor::share(DataType::FP16, Shape{batch, DimExpr(4), seqLen, headDim}, {}), ""}, + }; + count_t inputs[]{0, 1, 2}; + auto infered = Attention(0).infer(TensorRefs(edges, inputs), {true}); + ASSERT_TRUE(infered.isOk()); + auto outputs = std::move(infered.unwrap()); + ASSERT_EQ(outputs.size(), 1); + auto y = std::move(outputs[0]); + ASSERT_EQ(y->dataType, DataType::FP16); + ASSERT_EQ(y->shape, edges[0].tensor->shape); + } +} From 4d47e5d7ae4b2811101490084ac7904cbaa4841b Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 2 Feb 2024 12:53:06 +0800 Subject: [PATCH 10/17] =?UTF-8?q?feat(kernel):=20=E5=B0=81=E8=A3=85=20attL?= =?UTF-8?q?en=20=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../include/kernel/attributes/attention_info.h | 3 +++ src/04kernel/src/attributes/attention_info.cc | 13 +++++++++++++ .../src/kernels/attention/cuda_kernel.cu | 17 +++++++++-------- 3 files changed, 25 insertions(+), 8 deletions(-) create mode 100644 src/04kernel/src/attributes/attention_info.cc diff --git a/src/04kernel/include/kernel/attributes/attention_info.h b/src/04kernel/include/kernel/attributes/attention_info.h index 16d5fb0e..9cd64a56 100644 --- a/src/04kernel/include/kernel/attributes/attention_info.h +++ b/src/04kernel/include/kernel/attributes/attention_info.h @@ -9,6 +9,9 @@ namespace refactor::kernel { DataType dataType; dim_t batch, nHead, nKVHead, seqLen, headDim, cacheLen; bool concatCache, resetCache; + + dim_t attLen(dim_t pastSeqLen) const noexcept; + size_t attSize(dim_t pastSeqLen) const noexcept; }; }// namespace refactor::kernel diff --git a/src/04kernel/src/attributes/attention_info.cc b/src/04kernel/src/attributes/attention_info.cc new file mode 100644 index 00000000..c16c59fa --- /dev/null +++ b/src/04kernel/src/attributes/attention_info.cc @@ -0,0 +1,13 @@ +#include "kernel/attributes/attention_info.h" + +namespace 
refactor::kernel { + + dim_t AttentionInfo::attLen(dim_t pastSeqLen) const noexcept { + return pastSeqLen + seqLen; + } + + size_t AttentionInfo::attSize(dim_t pastSeqLen) const noexcept { + return batch * nHead * seqLen * attLen(pastSeqLen) * dataType.size(); + } + +}// namespace refactor::kernel diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu index 6f5b7d13..bcf31dc4 100644 --- a/src/04kernel/src/kernels/attention/cuda_kernel.cu +++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu @@ -71,7 +71,7 @@ namespace refactor::kernel { MatMulDescriptor mul; MatrixDescriptor q, k, v, att; cublasLtMatmulAlgo_t algoQK, algoAV; - size_t attSize, workspaceSizeQK, workspaceSizeAV; + size_t workspaceSizeQK, workspaceSizeAV; Descriptors(CublasLtContext const &context, AttentionInfo info) @@ -112,8 +112,7 @@ namespace refactor::kernel { .order = ROW_MAJOR, .batchCount = static_cast(info.batch * info.nHead), .batchStride = static_cast(info.seqLen * info.seqLen), - }), - attSize(info.batch * info.nHead * info.seqLen * info.seqLen * info.dataType.size()) { + }) { auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att); auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q); algoQK = algoQK_; @@ -125,7 +124,7 @@ namespace refactor::kernel { auto const &context = *res.fetchOrStore(); auto d = std::make_shared(context, info); - auto workspaceSize = d->attSize; + auto workspaceSize = info.attSize(0); workspaceSize = hardware::alignBytes(workspaceSize, 256); workspaceSize += d->workspaceSizeQK; workspaceSize += d->workspaceSizeAV; @@ -139,7 +138,7 @@ namespace refactor::kernel { auto v = inputs[2]; auto o = outputs[0]; auto att = reinterpret_cast(workspace); - auto workspaceQK = reinterpret_cast(workspace) + hardware::alignBytes(d->attSize, 256); + auto workspaceQK = reinterpret_cast(workspace) + hardware::alignBytes(info.attSize(0), 256); auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256); { half alpha = rsqrtf(info.headDim), beta = 0; @@ -155,10 +154,12 @@ namespace refactor::kernel { workspaceQK, d->workspaceSizeQK, cudaStreamLegacy); } + auto attLen = info.attLen(0); + auto bufLen = attLen; softmax<<>>( - att, causualMask, info.seqLen, info.seqLen); + std::min(1024u, attLen), + attLen * sizeof(float)>>>( + att, causualMask, attLen, bufLen); { half alpha = 1, beta = 0; cublasLtMatmul( From 6211b923e1068a5b857a1bb1ffe03754ab573687 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 2 Feb 2024 12:53:06 +0800 Subject: [PATCH 11/17] =?UTF-8?q?test(kernel):=20=E6=B7=BB=E5=8A=A0=20atte?= =?UTF-8?q?ntion=20=E5=8D=95=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../src/kernels/attention/cuda_kernel.cu | 8 ++-- .../src/utilities/cuda/cublaslt_utils.cu | 2 +- .../test/kernels/attention/test_cuda.cpp | 48 +++++++++++++++++++ 3 files changed, 54 insertions(+), 4 deletions(-) create mode 100644 src/04kernel/test/kernels/attention/test_cuda.cpp diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu index bcf31dc4..89bcbe68 100644 --- a/src/04kernel/src/kernels/attention/cuda_kernel.cu +++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu @@ -140,6 +140,7 @@ namespace refactor::kernel { auto att = reinterpret_cast(workspace); auto workspaceQK = reinterpret_cast(workspace) + hardware::alignBytes(info.attSize(0), 256); auto workspaceAV = workspaceQK + 
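+                    // workspace layout, each section aligned to 256 bytes:
+                    // [ attention scores | QK matmul workspace | AV matmul workspace ]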
hardware::alignBytes(d->workspaceSizeQK, 256); + auto stream = cudaStreamLegacy; { half alpha = rsqrtf(info.headDim), beta = 0; cublasLtMatmul( @@ -152,13 +153,14 @@ namespace refactor::kernel { att, d->att.get(), &d->algoQK, workspaceQK, d->workspaceSizeQK, - cudaStreamLegacy); + stream); } auto attLen = info.attLen(0); auto bufLen = attLen; softmax<<>>( + attLen * sizeof(float), + stream>>>( att, causualMask, attLen, bufLen); { half alpha = 1, beta = 0; @@ -172,7 +174,7 @@ namespace refactor::kernel { o, d->q.get(), &d->algoAV, workspaceAV, d->workspaceSizeAV, - cudaStreamLegacy); + stream); }; }; diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu index 7c34ad4c..6fc7e717 100644 --- a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu +++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu @@ -109,7 +109,7 @@ namespace refactor::kernel::cublas { CUDA_ASSERT(cudaGetDeviceProperties(&prop, device)); auto workspace = std::numeric_limits::max(); - auto alignment = prop.textureAlignment; + uint32_t alignment = prop.textureAlignment; cublasLtMatmulPreference_t preference; CUBLASLT_ASSERT(cublasLtMatmulPreferenceCreate(&preference)); diff --git a/src/04kernel/test/kernels/attention/test_cuda.cpp b/src/04kernel/test/kernels/attention/test_cuda.cpp new file mode 100644 index 00000000..0c8e4bcb --- /dev/null +++ b/src/04kernel/test/kernels/attention/test_cuda.cpp @@ -0,0 +1,48 @@ +#ifdef USE_CUDA + +#include "../../../src/kernels/attention/cuda_kernel.hh" +#include "hardware/device_manager.h" +#include +#include + +using namespace refactor; +using namespace kernel; +using namespace hardware; + +TEST(kernel, AttentionCudaNoKvCache) { + // build routine + AttentionInfo info{ + .dataType = DataType::FP16, + .batch = 1, + .nHead = 4, + .nKVHead = 4, + .seqLen = 31, + .headDim = 256, + .cacheLen = 0, + .concatCache = false, + .resetCache = false, + }; + auto q = Tensor::share(DataType::FP16, Shape{info.batch, info.nHead, info.seqLen, info.headDim}), + k = Tensor::share(DataType::FP16, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}), + v = Tensor::share(DataType::FP16, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}), + o = q; + auto kernel = AttentionCuda::build(info); + ASSERT_TRUE(kernel); + auto res = runtime::Resources(); + auto [routine, workspaceSize] = kernel->lower(res); + // malloc + auto &dev = *device::init(Device::Type::Nvidia, 0, ""); + auto qGpu = dev.malloc(q->bytesSize()), + kGpu = dev.malloc(k->bytesSize()), + vGpu = dev.malloc(v->bytesSize()), + oGpu = dev.malloc(o->bytesSize()), + workspace = dev.malloc(workspaceSize); + // inference + { + void const *inputs[]{*qGpu, *kGpu, *vGpu}; + void *outputs[]{*oGpu}; + routine(res, *workspace, inputs, outputs); + } +} + +#endif From 170512648cb859ab04c9da4bb033978ab79d3fd4 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 2 Feb 2024 12:53:06 +0800 Subject: [PATCH 12/17] =?UTF-8?q?fix(kernel):=20=E8=A7=A3=E5=86=B3=20atten?= =?UTF-8?q?tion=20=E8=AE=BF=E5=AD=98=E9=94=99=E8=AF=AF=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: YdrMaster --- .../src/kernels/attention/cuda_kernel.cu | 28 ++++++++++--------- .../test/kernels/attention/test_cuda.cpp | 2 ++ 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu index 89bcbe68..c66c2c43 100644 --- 
a/src/04kernel/src/kernels/attention/cuda_kernel.cu
+++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -12,25 +12,27 @@ namespace refactor::kernel {
     // seqLen: number of tokens processed in this pass
     // posId: position in the kv cache
     // attLen = pastSeqLen + seqLen
-    static __forceinline__ __device__ bool
-    causualMask(int tokenId, int seqLen,
-                int posId, int attLen) {
-        // tokenId ↓ |<---attLen---->|
-        //         0 | * * ... *     |
-        //         1 | * * ... * *   |
-        //         2 | * * ... * * * |
-        // seqLen: 3 |---------------|
-        return attLen + tokenId >= posId + seqLen;
-    }
+    struct AttentionCausualMask {
+        __forceinline__ __device__ bool
+        operator()(int tokenId, int seqLen,
+                   int posId, int attLen) {
+            // tokenId ↓ |<---attLen---->|
+            //         0 | * * ... *     |
+            //         1 | * * ... * *   |
+            //         2 | * * ... * * * |
+            // seqLen: 3 |---------------|
+            return attLen + tokenId >= posId + seqLen;
+        }
+    };
 
     // gridDim.x = batch * nHead
     // gridDim.y = seqLen
     // blockDim.x = min(1024, attLen)
     // sizeof(shared) = attLen * sizeof(float)
-    template<class T>
+    template<class T, class Mask>
     static __global__ void softmax(
         T *__restrict__ att,
-        bool (*mask)(int, int, int, int),
+        Mask mask,
         uint32_t attLen,
         uint32_t bufLen) {
         // locate the attention region this thread block handles
@@ -161,7 +163,7 @@
                     std::min(1024u, attLen),
                     attLen * sizeof(float),
                     stream>>>(
-                att, causualMask, attLen, bufLen);
+                att, AttentionCausualMask(), attLen, bufLen);
             {
                 half alpha = 1, beta = 0;
                 cublasLtMatmul(
diff --git a/src/04kernel/test/kernels/attention/test_cuda.cpp b/src/04kernel/test/kernels/attention/test_cuda.cpp
index 64555f16..794ae174 100644
--- a/src/04kernel/test/kernels/attention/test_cuda.cpp
+++ b/src/04kernel/test/kernels/attention/test_cuda.cpp
@@ -2,6 +2,7 @@
 
 #include "../../../src/kernels/attention/cuda_kernel.hh"
 #include "hardware/device_manager.h"
+#include "kernel/cuda/functions.cuh"
 #include
 #include
@@ -43,6 +44,7 @@ TEST(kernel, AttentionCudaNoKvCache) {
         void *outputs[]{*oGpu};
         routine(res, *workspace, inputs, outputs);
     }
+    cuda::sync();
 }
 
 #endif
From f5d975684f78d2092a9c148073db3069ae39a26d Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Fri, 2 Feb 2024 12:53:06 +0800
Subject: [PATCH 13/17] =?UTF-8?q?fix(kernel):=20=E4=BB=8D=E7=84=B6?=
 =?UTF-8?q?=E4=BD=BF=E7=94=A8=20cub::BlockReduce=20=E5=B9=B6=E6=94=B9?=
 =?UTF-8?q?=E6=AD=A3=20Attention?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../cuda/include/kernel/cuda/reduce.cuh       | 31 -------------------
 .../src/kernels/attention/cuda_kernel.cu      | 24 +++++++++-----
 .../src/kernels/softmax/cuda_kernel.cu        | 12 ++++---
 .../test/kernels/attention/test_cuda.cpp      | 17 +++++++---
 4 files changed, 37 insertions(+), 47 deletions(-)
 delete mode 100644 src/04kernel/cuda/include/kernel/cuda/reduce.cuh

diff --git a/src/04kernel/cuda/include/kernel/cuda/reduce.cuh b/src/04kernel/cuda/include/kernel/cuda/reduce.cuh
deleted file mode 100644
index 6a5be4a3..00000000
--- a/src/04kernel/cuda/include/kernel/cuda/reduce.cuh
+++ /dev/null
@@ -1,31 +0,0 @@
-#ifndef KERNEL_CUDA_REDUCE_CUH
-#define KERNEL_CUDA_REDUCE_CUH
-
-#include
-
-namespace refactor::kernel::cuda {
-
-    template<class T, class ReductionOp>
-    __inline__ __device__ T blockReduce(T x, T init, ReductionOp op) {
-        using WarpReduce = cub::WarpReduce<T>;
-        __shared__ typename WarpReduce::TempStorage tempStorage;
-        __shared__ T shared[32], ans;
-
-        auto reduce = WarpReduce(tempStorage);
-        int lane = threadIdx.x % 32;
-        int wid = threadIdx.x / 32;
-        x = reduce.Reduce(x, op);
-        if (lane == 0) { shared[wid] = x; }
-        __syncthreads();
-        if (wid == 0) {
-            x = (threadIdx.x < 
blockDim.x / 32) ? shared[lane] : init;
-            shared[lane] = reduce.Reduce(x, op);
-            if (lane == 0) { ans = shared[0]; }
-        }
-        __syncthreads();
-        return ans;// avoid RAW hazard
-    }
-
-}// namespace refactor::kernel::cuda
-
-#endif// KERNEL_CUDA_REDUCE_CUH
diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu
index c66c2c43..89ce72f4 100644
--- a/src/04kernel/src/kernels/attention/cuda_kernel.cu
+++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -1,7 +1,8 @@
 #include "../../utilities/cuda/cublaslt_utils.cuh"
 #include "cuda_kernel.hh"
 #include "hardware/functions.h"
-#include "kernel/cuda/reduce.cuh"
+#include "kernel/cuda/functions.cuh"
+#include
 
 namespace refactor::kernel {
     using K = AttentionCuda;
@@ -27,7 +28,7 @@ namespace refactor::kernel {
 
     // gridDim.x = batch * nHead
     // gridDim.y = seqLen
-    // blockDim.x = min(1024, attLen)
+    // blockDim.x = 1024
     // sizeof(shared) = attLen * sizeof(float)
     template<class T, class Mask>
     static __global__ void softmax(
@@ -36,25 +37,34 @@ namespace refactor::kernel {
         T *__restrict__ att,
         Mask mask,
         uint32_t attLen,
         uint32_t bufLen) {
         // locate the attention region this thread block handles
-        att += (blockIdx.x * gridDim.x + gridDim.y) * bufLen;
+        att += (blockIdx.x * gridDim.y + blockIdx.y) * bufLen;
 
         // load the input into shared memory, with cast + mask
         extern __shared__ float shared[];// size = attLen = pastSeqLen + seqLen
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             shared[i] = mask(blockIdx.y, gridDim.y, i, attLen) ? float(att[i]) : -__FLT_MAX__;
         }
+        using BlockReduce = cub::BlockReduce<float, 1024>;
+        __shared__ typename BlockReduce::TempStorage tempStorage;
+        __shared__ float sharedMax, sharedSum;
+
         float localMax = -1e20;
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             localMax = cub::Max()(localMax, shared[i]);
         }
-        localMax = cuda::blockReduce(localMax, -1e20f, cub::Max());
+        localMax = BlockReduce(tempStorage).Reduce(localMax, cub::Max(), attLen);
+        if (threadIdx.x == 0) { sharedMax = localMax; }
+        __syncthreads();
 
         float localSum = 1e-20;
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
-            localSum += shared[i] = expf(shared[i] - localMax);
+            localSum += shared[i] = expf(shared[i] - sharedMax);
         }
-        localSum = cuda::blockReduce(localSum, 1e-20f, cub::Sum());
-        auto reciprocal = fdividef(1, localSum);
+        localSum = BlockReduce(tempStorage).Reduce(localSum, cub::Sum(), attLen);
+        if (threadIdx.x == 0) { sharedSum = localSum; }
+        __syncthreads();
+
+        auto reciprocal = fdividef(1, sharedSum);
         for (auto i = threadIdx.x; i < attLen; i += blockDim.x) {
             att[i] = shared[i] * reciprocal;
         }
diff --git a/src/04kernel/src/kernels/softmax/cuda_kernel.cu b/src/04kernel/src/kernels/softmax/cuda_kernel.cu
index c5ecb821..114cb453 100644
--- a/src/04kernel/src/kernels/softmax/cuda_kernel.cu
+++ b/src/04kernel/src/kernels/softmax/cuda_kernel.cu
@@ -1,5 +1,5 @@
 #include "cuda_kernel.hh"
-#include "kernel/cuda/reduce.cuh"
+#include
 
 namespace refactor::kernel {
     using namespace runtime;
@@ -18,8 +18,8 @@ namespace refactor::kernel {
     template<> __device__ __forceinline__ nv_bfloat16 reciprocal(nv_bfloat16 x) { return hrcp(x); }
 
     // blockDim.x === BLOCK_DIM
-    template<class T>
-    __global__ void blockSoftmaxKernel(
+    template<int BLOCK_DIM, class T>
+    __launch_bounds__(BLOCK_DIM) __global__ void blockSoftmaxKernel(
         T const *__restrict x,
         T *__restrict y,
         int mid,
@@ -40,8 +40,10 @@ namespace refactor::kernel {
         for (int i = threadIdx.x + blockDim.x; i < mid; i += blockDim.x) {
             maxSumThread = MaxSum::reduce(maxSumThread, {x[id + i * stride], 1});// reduce the data to one block
         }
+        using BlockReduce = cub::BlockReduce<MaxSum, BLOCK_DIM>;
+        __shared__ 
typename BlockReduce::TempStorage tempStorage;
         __shared__ MaxSum maxSumTotal;
-        auto maxSumBlock = cuda::blockReduce(maxSumThread, {-__FLT_MAX__, 0}, MaxSum::reduce);
+        auto maxSumBlock = BlockReduce(tempStorage).Reduce(maxSumThread, MaxSum::reduce);
         if (threadIdx.x == 0) {
             maxSumTotal = maxSumBlock;// only threadIdx.x == 0 writes the result to memory
         }
@@ -111,7 +113,7 @@ namespace refactor::kernel {
         auto y = reinterpret_cast(outputs[0]);
         int numBlocks = info.pre * info.post;
         if (info.mid > 1024) {
-            blockSoftmaxKernel<<<numBlocks, 1024>>>(x, y, info.mid, info.post);
+            blockSoftmaxKernel<1024><<<numBlocks, 1024>>>(x, y, info.mid, info.post);
         } else {
             int blockDimX, mid = static_cast<int>(info.mid);
             for (blockDimX = 32; blockDimX > 4 && mid < blockDimX; blockDimX /= 2) {}
diff --git a/src/04kernel/test/kernels/attention/test_cuda.cpp b/src/04kernel/test/kernels/attention/test_cuda.cpp
index 64555f16..794ae174 100644
--- a/src/04kernel/test/kernels/attention/test_cuda.cpp
+++ b/src/04kernel/test/kernels/attention/test_cuda.cpp
@@ -13,7 +13,7 @@ using namespace hardware;
 TEST(kernel, AttentionCudaNoKvCache) {
     // build routine
     AttentionInfo info{
-        .dataType = DataType::FP16,
+        .dataType = DataType::F32,
         .batch = 1,
         .nHead = 4,
         .nKVHead = 4,
@@ -23,9 +23,9 @@ TEST(kernel, AttentionCudaNoKvCache) {
         .concatCache = false,
         .resetCache = false,
     };
-    auto q = Tensor::share(DataType::FP16, Shape{info.batch, info.nHead, info.seqLen, info.headDim}),
-         k = Tensor::share(DataType::FP16, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
-         v = Tensor::share(DataType::FP16, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
+    auto q = Tensor::share(DataType::F32, Shape{info.batch, info.nHead, info.seqLen, info.headDim}),
+         k = Tensor::share(DataType::F32, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
+         v = Tensor::share(DataType::F32, Shape{info.batch, info.nKVHead, info.seqLen, info.headDim}),
          o = q;
     auto kernel = AttentionCuda::build(info);
     ASSERT_TRUE(kernel);
@@ -38,6 +38,15 @@ TEST(kernel, AttentionCudaNoKvCache) {
          vGpu = dev.malloc(v->bytesSize()),
          oGpu = dev.malloc(o->bytesSize()),
          workspace = dev.malloc(workspaceSize);
+    // put input data
+    std::vector<float>
+        q_(q->elementsSize(), 1),
+        k_(k->elementsSize(), 1),
+        v_(v->elementsSize(), 1),
+        o_(o->elementsSize());
+    qGpu->copyFromHost(q_.data());
+    kGpu->copyFromHost(k_.data());
+    vGpu->copyFromHost(v_.data());
     // inference
     {
         void const *inputs[]{*qGpu, *kGpu, *vGpu};
From e16c6791ff4cfae247f95b49de2f7d1aa8083b23 Mon Sep 17 00:00:00 2001
From: YdrMaster
Date: Fri, 2 Feb 2024 16:37:09 +0800
Subject: [PATCH 14/17] =?UTF-8?q?feat(kernel):=20=E6=94=AF=E6=8C=81=20kv?=
 =?UTF-8?q?=20cache=20attention?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: YdrMaster
---
 .../kernel/attributes/attention_info.h        |   1 +
 src/04kernel/src/attributes/attention_info.cc |   4 +
 .../src/kernels/attention/cuda_kernel.cu      | 154 +++++++++++++++++-
 .../src/utilities/cuda/cublaslt_utils.cu      |   8 +-
 .../src/utilities/cuda/cublaslt_utils.cuh     |   3 +-
 5 files changed, 162 insertions(+), 8 deletions(-)

diff --git a/src/04kernel/include/kernel/attributes/attention_info.h b/src/04kernel/include/kernel/attributes/attention_info.h
index 9cd64a56..38677681 100644
--- a/src/04kernel/include/kernel/attributes/attention_info.h
+++ b/src/04kernel/include/kernel/attributes/attention_info.h
@@ -12,6 +12,7 @@ namespace refactor::kernel {
 
         dim_t attLen(dim_t pastSeqLen) const noexcept;
         size_t attSize(dim_t pastSeqLen) const noexcept;
+        size_t 
maxAttSize() const noexcept;
     };
 
 }// namespace refactor::kernel
diff --git a/src/04kernel/src/attributes/attention_info.cc b/src/04kernel/src/attributes/attention_info.cc
index c16c59fa..a867fd3f 100644
--- a/src/04kernel/src/attributes/attention_info.cc
+++ b/src/04kernel/src/attributes/attention_info.cc
@@ -10,4 +10,8 @@ namespace refactor::kernel {
         return batch * nHead * seqLen * attLen(pastSeqLen) * dataType.size();
     }
 
+    size_t AttentionInfo::maxAttSize() const noexcept {
+        return batch * nHead * seqLen * (cacheLen ? cacheLen : seqLen) * dataType.size();
+    }
+
 }// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu
index 89ce72f4..da0a69ad 100644
--- a/src/04kernel/src/kernels/attention/cuda_kernel.cu
+++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu
@@ -70,6 +70,20 @@ namespace refactor::kernel {
         }
     }
 
+    static __global__ void concatCache(
+        void *__restrict__ cache,
+        void const *__restrict__ value,
+        dim_t pageStrideI,
+        dim_t pageStrideO,
+        dim_t lineStride,
+        dim_t pastOffset) {
+
+        auto tid = blockIdx.x * blockDim.x + threadIdx.x,
+             dst = tid / pageStrideO * pageStrideI + pastOffset + tid % pageStrideO;
+        reinterpret_cast<float4 *>(cache)[dst] = reinterpret_cast<float4 const *>(value)[tid];
+    }
+    constexpr uint64_t DYNAMIC_WORKSPACE_SIZE = 40 << 20;// found empirically: 40 MiB is enough
+
     RoutineWorkspace K::lower(Resources &res) const {
         auto handle = res.fetchOrStore<CublasLtContext>()->handle;
@@ -125,8 +139,8 @@ namespace refactor::kernel {
                           .batchCount = static_cast(info.batch * info.nHead),
                           .batchStride = static_cast(info.seqLen * info.seqLen),
                       }) {
-                auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att);
-                auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q);
+                auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att, DYNAMIC_WORKSPACE_SIZE);
+                auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q, DYNAMIC_WORKSPACE_SIZE);
                 algoQK = algoQK_;
                 algoAV = algoAV_;
                 workspaceSizeQK = workspaceSizeQK_;
@@ -187,12 +201,146 @@ namespace refactor::kernel {
                         &d->algoAV,
                         workspaceAV, d->workspaceSizeAV,
                         stream);
-                };
+                }
             };
 
             return {std::move(routine), workspaceSize};
         }
+        TODO("");
     }
+    if (info.concatCache && !info.resetCache) {
+        if (info.nHead == info.nKVHead) {
+
+            // RAII for closure
+            struct Descriptors {
+                MatMulDescriptor mul;
+
+                Descriptors(AttentionInfo info)
+                    : mul(computeTypeConvert(info.dataType),
+                          dataTypeConvert(info.dataType)) {}
+            };
+
+            auto const &context = *res.fetchOrStore<CublasLtContext>();
+            auto d = std::make_shared<Descriptors>(info);
+            auto attentionSize = info.maxAttSize();
+            auto workspaceSize = DYNAMIC_WORKSPACE_SIZE + attentionSize;
+
+            auto routine = [d = std::move(d), info = this->info]//
+                (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) {
+                    auto handle = res.fetchOrStore<CublasLtContext>()->handle;
+                    auto q = inputs[0];
+                    auto k = inputs[1];
+                    auto v = inputs[2];
+                    auto past = *reinterpret_cast<int64_t const *>(inputs[3]);
+                    auto attLen = info.attLen(past);
+                    auto o = reinterpret_cast<half *>(outputs[0]);
+                    auto kCache = reinterpret_cast<half *>(outputs[1]);
+                    auto vCache = reinterpret_cast<half *>(outputs[2]);
+                    auto att = reinterpret_cast<half *>(reinterpret_cast<uint8_t *>(workspace) + DYNAMIC_WORKSPACE_SIZE);
+                    auto stream = cudaStreamLegacy;
+                    {
+                        auto itemsPerLine = info.headDim * sizeof(half) / sizeof(float4);
+                        auto threads = info.batch * info.nHead * info.seqLen * itemsPerLine;
+                        auto blocks = (threads + 1023) / 1024;
+
+                        concatCache<<<blocks, 1024, 0, stream>>>(
+                            kCache, k,
+                            info.seqLen * itemsPerLine,
+                            info.cacheLen * itemsPerLine,
+                            
itemsPerLine, + past * itemsPerLine); + concatCache<<>>( + vCache, v, + info.seqLen * itemsPerLine, + info.cacheLen * itemsPerLine, + itemsPerLine, + past * itemsPerLine); + } + MatrixDescriptor + q_(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(info.seqLen), + .cols = static_cast(info.headDim), + .majorStride = static_cast(info.headDim), + .order = ROW_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.seqLen * info.headDim), + }), + k_(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(info.headDim), + .cols = static_cast(attLen), + .majorStride = static_cast(info.headDim), + .order = COL_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.cacheLen * info.headDim), + }), + v_(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(attLen), + .cols = static_cast(info.headDim), + .majorStride = static_cast(info.headDim), + .order = ROW_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.cacheLen * info.headDim), + }), + att_(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(info.seqLen), + .cols = static_cast(attLen), + .majorStride = static_cast(info.cacheLen), + .order = ROW_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.cacheLen * info.seqLen), + }); + { + auto [algo, workspaceSize] = tune( + handle, d->mul, + q_, k_, att_, + DYNAMIC_WORKSPACE_SIZE); + half alpha = rsqrtf(info.headDim), beta = 0; + cublasLtMatmul( + handle, d->mul.get(), + &alpha, + q, q_.get(), + kCache, k_.get(), + &beta, + att, att_.get(), + att, att_.get(), + &algo, + workspace, workspaceSize, + stream); + } + softmax<<>>( + att, AttentionCausualMask(), attLen, info.cacheLen); + { + auto [algo, workspaceSize] = tune( + handle, d->mul, + att_, v_, q_, + DYNAMIC_WORKSPACE_SIZE); + half alpha = 1, beta = 0; + cublasLtMatmul( + handle, d->mul.get(), + &alpha, + att, att_.get(), + vCache, v_.get(), + &beta, + o, q_.get(), + o, q_.get(), + &algo, + workspace, workspaceSize, + stream); + } + }; + + return {std::move(routine), workspaceSize}; + } + TODO(""); + } + TODO(""); } diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu index 6fc7e717..ab797e8f 100644 --- a/src/04kernel/src/utilities/cuda/cublaslt_utils.cu +++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cu @@ -101,14 +101,14 @@ namespace refactor::kernel::cublas { MatMulDescriptor const &matmul, MatrixDescriptor const &a, MatrixDescriptor const &b, - MatrixDescriptor const &c) { + MatrixDescriptor const &c, + uint64_t maxWorkspace) { int device; CUDA_ASSERT(cudaGetDevice(&device)); cudaDeviceProp prop; CUDA_ASSERT(cudaGetDeviceProperties(&prop, device)); - auto workspace = std::numeric_limits::max(); uint32_t alignment = prop.textureAlignment; cublasLtMatmulPreference_t preference; @@ -116,8 +116,8 @@ namespace refactor::kernel::cublas { CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, - &workspace, - sizeof(workspace))); + &maxWorkspace, + sizeof(maxWorkspace))); CUBLASLT_ASSERT(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, diff --git a/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh index ccaad7ec..33de075a 100644 --- 
a/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh +++ b/src/04kernel/src/utilities/cuda/cublaslt_utils.cuh @@ -68,7 +68,8 @@ namespace refactor::kernel::cublas { MatMulDescriptor const &, MatrixDescriptor const &, MatrixDescriptor const &, - MatrixDescriptor const &); + MatrixDescriptor const &, + uint64_t); }// namespace refactor::kernel::cublas From bfa8e9ff14de6b105ece7cead5f08a859e2dcca5 Mon Sep 17 00:00:00 2001 From: panzezhong Date: Wed, 21 Feb 2024 11:10:01 +0800 Subject: [PATCH 15/17] =?UTF-8?q?fix=20(kernel):=20=E4=BF=AE=E5=A4=8Datten?= =?UTF-8?q?tion=E7=AE=97=E5=AD=90=E4=B8=AD=E7=9A=84concat?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/kernels/attention/cuda_kernel.cu | 25 ++++++++++--------- src/08-01llm/src/operators/attention.cc | 12 ++++----- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu index da0a69ad..8ba6e987 100644 --- a/src/04kernel/src/kernels/attention/cuda_kernel.cu +++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu @@ -75,12 +75,13 @@ namespace refactor::kernel { void const *__restrict__ value, dim_t pageStrideI, dim_t pageStrideO, - dim_t lineStride, - dim_t pastOffset) { - - auto tid = blockIdx.x * blockDim.x + threadIdx.x, - dst = tid / pageStrideO * pageStrideI + pastOffset + tid % pageStrideO; - reinterpret_cast(cache)[dst] = reinterpret_cast(value)[tid]; + dim_t pastOffset, + dim_t n_items) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n_items) { + auto dst = tid / pageStrideI * pageStrideO + pastOffset + (tid % pageStrideI); + reinterpret_cast(cache)[dst] = reinterpret_cast(value)[tid]; + } } constexpr uint64_t DYNAMIC_WORKSPACE_SIZE = 40 << 20;// 试出来 40MiB 是够用的 @@ -231,7 +232,8 @@ namespace refactor::kernel { auto q = inputs[0]; auto k = inputs[1]; auto v = inputs[2]; - auto past = *reinterpret_cast(inputs[3]); + int64_t past; + cudaMemcpy(&past, inputs[3], sizeof(int64_t), cudaMemcpyDeviceToHost); auto attLen = info.attLen(past); auto o = reinterpret_cast(outputs[0]); auto kCache = reinterpret_cast(outputs[1]); @@ -242,19 +244,18 @@ namespace refactor::kernel { auto itemsPerLine = info.headDim * sizeof(half) / sizeof(float4); auto threads = info.batch * info.nHead * info.seqLen * itemsPerLine; auto blocks = (threads + 1023) / 1024; - concatCache<<>>( kCache, k, info.seqLen * itemsPerLine, info.cacheLen * itemsPerLine, - itemsPerLine, - past * itemsPerLine); + past * itemsPerLine, + threads); concatCache<<>>( vCache, v, info.seqLen * itemsPerLine, info.cacheLen * itemsPerLine, - itemsPerLine, - past * itemsPerLine); + past * itemsPerLine, + threads); } MatrixDescriptor q_(MatrixLayout{ diff --git a/src/08-01llm/src/operators/attention.cc b/src/08-01llm/src/operators/attention.cc index 8993f1f3..cc2e9ce4 100644 --- a/src/08-01llm/src/operators/attention.cc +++ b/src/08-01llm/src/operators/attention.cc @@ -80,10 +80,10 @@ namespace refactor::llm { if (pastSeqLen.dataType != DataType::I64 || pastSeqLen.shape != Shape{DimExpr(1)}) { return Err(InferError(ERROR_MSG("Past seqlen error"))); } - auto pastSeqLenVal = pastSeqLen.data->get()[0]; if (maxSeqLen <= 0) { + auto pastSeqLenVal = pastSeqLen.data->get()[0]; return outputs(pastSeqLenVal + seqlen); - } else if (maxSeqLen >= pastSeqLenVal + seqlen) { + } else if (maxSeqLen >= 1 + seqlen) { return outputs(maxSeqLen); } else { return Err(InferError(ERROR_MSG("max_seq_len must not less than seqlen"))); @@ -94,7 
+94,6 @@ namespace refactor::llm { if (pastSeqLen.dataType != DataType::I64 || pastSeqLen.shape != Shape{DimExpr(1)}) { return Err(InferError(ERROR_MSG("Past seqlen error"))); } - auto pastSeqLenVal = pastSeqLen.data->get()[0]; auto const &kCahce = inputs[4], &vCache = inputs[5]; @@ -107,15 +106,14 @@ namespace refactor::llm { kCahce.shape[3] != kvShape[3] || kCahce.shape[0] != kvShape[0] || kCahce.shape[2] != kvShape[2] || - kCahce.shape[3] != kvShape[3] || - pastSeqLenVal < kCacheSeqLen || - pastSeqLenVal < vCacheSeqLen) { + kCahce.shape[3] != kvShape[3]) { return Err(InferError(ERROR_MSG("KV cache error"))); } if (maxSeqLen <= 0) { + auto pastSeqLenVal = pastSeqLen.data->get()[0]; return outputs(pastSeqLenVal + seqlen); - } else if (maxSeqLen >= pastSeqLenVal + seqlen) { + } else if (maxSeqLen >= 1 + seqlen) { return outputs(maxSeqLen); } else { return Err(InferError(ERROR_MSG("max_seq_len must not less than seqlen"))); From a686d2f6cab91bb19914ec6a5fce9d44aad8cbe8 Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 21 Feb 2024 16:03:51 +0800 Subject: [PATCH 16/17] temp Signed-off-by: YdrMaster --- .../src/kernels/attention/cuda_kernel.cu | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu index 8ba6e987..23eecba9 100644 --- a/src/04kernel/src/kernels/attention/cuda_kernel.cu +++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu @@ -226,6 +226,52 @@ namespace refactor::kernel { auto attentionSize = info.maxAttSize(); auto workspaceSize = DYNAMIC_WORKSPACE_SIZE + attentionSize; + for (auto attLen = 0; attLen < 2048; ++attLen) { + MatrixDescriptor + q_(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(info.seqLen), + .cols = static_cast(info.headDim), + .majorStride = static_cast(info.headDim), + .order = ROW_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.seqLen * info.headDim), + }), + k_(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(info.headDim), + .cols = static_cast(attLen), + .majorStride = static_cast(info.headDim), + .order = COL_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.cacheLen * info.headDim), + }), + v_(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(attLen), + .cols = static_cast(info.headDim), + .majorStride = static_cast(info.headDim), + .order = ROW_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.cacheLen * info.headDim), + }), + att_(MatrixLayout{ + .dataType = dataTypeConvert(info.dataType), + .rows = static_cast(info.seqLen), + .cols = static_cast(attLen), + .majorStride = static_cast(info.cacheLen), + .order = ROW_MAJOR, + .batchCount = static_cast(info.batch * info.nHead), + .batchStride = static_cast(info.cacheLen * info.seqLen), + }); + tune(handle, d->mul, + q_, k_, att_, + DYNAMIC_WORKSPACE_SIZE); + tune(handle, d->mul, + att_, v_, q_, + DYNAMIC_WORKSPACE_SIZE); + } + auto routine = [d = std::move(d), info = this->info]// (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { auto handle = res.fetchOrStore()->handle; From 0d1fce4d50b808e6f5b2be9f5ba2a4262c96371b Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Wed, 21 Feb 2024 17:07:21 +0800 Subject: [PATCH 17/17] temp Signed-off-by: YdrMaster --- .../src/kernels/attention/cuda_kernel.cu | 463 +++++++++--------- 1 file 
changed, 234 insertions(+), 229 deletions(-) diff --git a/src/04kernel/src/kernels/attention/cuda_kernel.cu b/src/04kernel/src/kernels/attention/cuda_kernel.cu index 23eecba9..aca6e8b7 100644 --- a/src/04kernel/src/kernels/attention/cuda_kernel.cu +++ b/src/04kernel/src/kernels/attention/cuda_kernel.cu @@ -1,4 +1,4 @@ -#include "../../utilities/cuda/cublaslt_utils.cuh" +#include "../../utilities/cuda/cublas_context.hh" #include "cuda_kernel.hh" #include "hardware/functions.h" #include "kernel/cuda/functions.cuh" @@ -91,124 +91,125 @@ namespace refactor::kernel { constexpr auto ROW_MAJOR = CUBLASLT_ORDER_ROW; constexpr auto COL_MAJOR = CUBLASLT_ORDER_COL; - if (!info.cacheLen) { - if (info.nHead == info.nKVHead) { - // RAII for closure - struct Descriptors { - MatMulDescriptor mul; - MatrixDescriptor q, k, v, att; - cublasLtMatmulAlgo_t algoQK, algoAV; - size_t workspaceSizeQK, workspaceSizeAV; + // if (!info.cacheLen) { + // if (info.nHead == info.nKVHead) { + // // RAII for closure + // struct Descriptors { + // MatMulDescriptor mul; + // MatrixDescriptor q, k, v, att; + // cublasLtMatmulAlgo_t algoQK, algoAV; + // size_t workspaceSizeQK, workspaceSizeAV; - Descriptors(CublasLtContext const &context, - AttentionInfo info) - : mul(computeTypeConvert(info.dataType), - dataTypeConvert(info.dataType)), - q(MatrixLayout{ - .dataType = dataTypeConvert(info.dataType), - .rows = static_cast(info.seqLen), - .cols = static_cast(info.headDim), - .majorStride = static_cast(info.headDim), - .order = ROW_MAJOR, - .batchCount = static_cast(info.batch * info.nHead), - .batchStride = static_cast(info.seqLen * info.headDim), - }), - k(MatrixLayout{ - .dataType = dataTypeConvert(info.dataType), - .rows = static_cast(info.headDim), - .cols = static_cast(info.seqLen), - .majorStride = static_cast(info.headDim), - .order = COL_MAJOR, - .batchCount = static_cast(info.batch * info.nHead), - .batchStride = static_cast(info.seqLen * info.headDim), - }), - v(MatrixLayout{ - .dataType = dataTypeConvert(info.dataType), - .rows = static_cast(info.seqLen), - .cols = static_cast(info.headDim), - .majorStride = static_cast(info.headDim), - .order = ROW_MAJOR, - .batchCount = static_cast(info.batch * info.nHead), - .batchStride = static_cast(info.seqLen * info.headDim), - }), - att(MatrixLayout{ - .dataType = dataTypeConvert(info.dataType), - .rows = static_cast(info.seqLen), - .cols = static_cast(info.seqLen), - .majorStride = static_cast(info.seqLen), - .order = ROW_MAJOR, - .batchCount = static_cast(info.batch * info.nHead), - .batchStride = static_cast(info.seqLen * info.seqLen), - }) { - auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att, DYNAMIC_WORKSPACE_SIZE); - auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q, DYNAMIC_WORKSPACE_SIZE); - algoQK = algoQK_; - algoAV = algoAV_; - workspaceSizeQK = workspaceSizeQK_; - workspaceSizeAV = workspaceSizeAV_; - } - }; + // Descriptors(CublasLtContext const &context, + // AttentionInfo info) + // : mul(computeTypeConvert(info.dataType), + // dataTypeConvert(info.dataType)), + // q(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(info.seqLen), + // .cols = static_cast(info.headDim), + // .majorStride = static_cast(info.headDim), + // .order = ROW_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.seqLen * info.headDim), + // }), + // k(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(info.headDim), + // .cols 
= static_cast(info.seqLen), + // .majorStride = static_cast(info.headDim), + // .order = COL_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.seqLen * info.headDim), + // }), + // v(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(info.seqLen), + // .cols = static_cast(info.headDim), + // .majorStride = static_cast(info.headDim), + // .order = ROW_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.seqLen * info.headDim), + // }), + // att(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(info.seqLen), + // .cols = static_cast(info.seqLen), + // .majorStride = static_cast(info.seqLen), + // .order = ROW_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.seqLen * info.seqLen), + // }) { + // auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att, DYNAMIC_WORKSPACE_SIZE); + // auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q, DYNAMIC_WORKSPACE_SIZE); + // algoQK = algoQK_; + // algoAV = algoAV_; + // workspaceSizeQK = workspaceSizeQK_; + // workspaceSizeAV = workspaceSizeAV_; + // } + // }; - auto const &context = *res.fetchOrStore(); - auto d = std::make_shared(context, info); - auto workspaceSize = info.attSize(0); - workspaceSize = hardware::alignBytes(workspaceSize, 256); - workspaceSize += d->workspaceSizeQK; - workspaceSize += d->workspaceSizeAV; - workspaceSize = hardware::alignBytes(workspaceSize, 256); + // auto const &context = *res.fetchOrStore(); + // auto d = std::make_shared(context, info); + // auto workspaceSize = info.attSize(0); + // workspaceSize = hardware::alignBytes(workspaceSize, 256); + // workspaceSize += d->workspaceSizeQK; + // workspaceSize += d->workspaceSizeAV; + // workspaceSize = hardware::alignBytes(workspaceSize, 256); - auto routine = [d = std::move(d), info = this->info]// - (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { - auto handle = res.fetchOrStore()->handle; - auto q = inputs[0]; - auto k = inputs[1]; - auto v = inputs[2]; - auto o = outputs[0]; - auto att = reinterpret_cast(workspace); - auto workspaceQK = reinterpret_cast(workspace) + hardware::alignBytes(info.attSize(0), 256); - auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256); - auto stream = cudaStreamLegacy; - { - half alpha = rsqrtf(info.headDim), beta = 0; - cublasLtMatmul( - handle, d->mul.get(), - &alpha, - q, d->q.get(), - k, d->k.get(), - &beta, - att, d->att.get(), - att, d->att.get(), - &d->algoQK, - workspaceQK, d->workspaceSizeQK, - stream); - } - auto attLen = info.attLen(0); - auto bufLen = attLen; - softmax<<>>( - att, AttentionCausualMask(), attLen, bufLen); - { - half alpha = 1, beta = 0; - cublasLtMatmul( - handle, d->mul.get(), - &alpha, - att, d->att.get(), - v, d->v.get(), - &beta, - o, d->q.get(), - o, d->q.get(), - &d->algoAV, - workspaceAV, d->workspaceSizeAV, - stream); - } - }; + // auto routine = [d = std::move(d), info = this->info]// + // (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { + // auto handle = res.fetchOrStore()->handle; + // auto q = inputs[0]; + // auto k = inputs[1]; + // auto v = inputs[2]; + // auto o = outputs[0]; + // auto att = reinterpret_cast(workspace); + // auto workspaceQK = reinterpret_cast(workspace) + hardware::alignBytes(info.attSize(0), 256); + // auto workspaceAV = 
workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256); + // auto stream = cudaStreamLegacy; + // { + // half alpha = rsqrtf(info.headDim), beta = 0; + // cublasLtMatmul( + // handle, d->mul.get(), + // &alpha, + // q, d->q.get(), + // k, d->k.get(), + // &beta, + // att, d->att.get(), + // att, d->att.get(), + // &d->algoQK, + // workspaceQK, d->workspaceSizeQK, + // stream); + // } + // auto attLen = info.attLen(0); + // auto bufLen = attLen; + // softmax<<>>( + // att, AttentionCausualMask(), attLen, bufLen); + // { + // half alpha = 1, beta = 0; + // cublasLtMatmul( + // handle, d->mul.get(), + // &alpha, + // att, d->att.get(), + // v, d->v.get(), + // &beta, + // o, d->q.get(), + // o, d->q.get(), + // &d->algoAV, + // workspaceAV, d->workspaceSizeAV, + // stream); + // } + // }; + + // return {std::move(routine), workspaceSize}; + // } + // TODO(""); + // } - return {std::move(routine), workspaceSize}; - } - TODO(""); - } if (info.concatCache && !info.resetCache) { if (info.nHead == info.nKVHead) { @@ -226,55 +227,9 @@ namespace refactor::kernel { auto attentionSize = info.maxAttSize(); auto workspaceSize = DYNAMIC_WORKSPACE_SIZE + attentionSize; - for (auto attLen = 0; attLen < 2048; ++attLen) { - MatrixDescriptor - q_(MatrixLayout{ - .dataType = dataTypeConvert(info.dataType), - .rows = static_cast(info.seqLen), - .cols = static_cast(info.headDim), - .majorStride = static_cast(info.headDim), - .order = ROW_MAJOR, - .batchCount = static_cast(info.batch * info.nHead), - .batchStride = static_cast(info.seqLen * info.headDim), - }), - k_(MatrixLayout{ - .dataType = dataTypeConvert(info.dataType), - .rows = static_cast(info.headDim), - .cols = static_cast(attLen), - .majorStride = static_cast(info.headDim), - .order = COL_MAJOR, - .batchCount = static_cast(info.batch * info.nHead), - .batchStride = static_cast(info.cacheLen * info.headDim), - }), - v_(MatrixLayout{ - .dataType = dataTypeConvert(info.dataType), - .rows = static_cast(attLen), - .cols = static_cast(info.headDim), - .majorStride = static_cast(info.headDim), - .order = ROW_MAJOR, - .batchCount = static_cast(info.batch * info.nHead), - .batchStride = static_cast(info.cacheLen * info.headDim), - }), - att_(MatrixLayout{ - .dataType = dataTypeConvert(info.dataType), - .rows = static_cast(info.seqLen), - .cols = static_cast(attLen), - .majorStride = static_cast(info.cacheLen), - .order = ROW_MAJOR, - .batchCount = static_cast(info.batch * info.nHead), - .batchStride = static_cast(info.cacheLen * info.seqLen), - }); - tune(handle, d->mul, - q_, k_, att_, - DYNAMIC_WORKSPACE_SIZE); - tune(handle, d->mul, - att_, v_, q_, - DYNAMIC_WORKSPACE_SIZE); - } - auto routine = [d = std::move(d), info = this->info]// (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { - auto handle = res.fetchOrStore()->handle; + auto handle = res.fetchOrStore()->handle; auto q = inputs[0]; auto k = inputs[1]; auto v = inputs[2]; @@ -303,60 +258,85 @@ namespace refactor::kernel { past * itemsPerLine, threads); } - MatrixDescriptor - q_(MatrixLayout{ - .dataType = dataTypeConvert(info.dataType), - .rows = static_cast(info.seqLen), - .cols = static_cast(info.headDim), - .majorStride = static_cast(info.headDim), - .order = ROW_MAJOR, - .batchCount = static_cast(info.batch * info.nHead), - .batchStride = static_cast(info.seqLen * info.headDim), - }), - k_(MatrixLayout{ - .dataType = dataTypeConvert(info.dataType), - .rows = static_cast(info.headDim), - .cols = static_cast(attLen), - .majorStride = 
static_cast(info.headDim), - .order = COL_MAJOR, - .batchCount = static_cast(info.batch * info.nHead), - .batchStride = static_cast(info.cacheLen * info.headDim), - }), - v_(MatrixLayout{ - .dataType = dataTypeConvert(info.dataType), - .rows = static_cast(attLen), - .cols = static_cast(info.headDim), - .majorStride = static_cast(info.headDim), - .order = ROW_MAJOR, - .batchCount = static_cast(info.batch * info.nHead), - .batchStride = static_cast(info.cacheLen * info.headDim), - }), - att_(MatrixLayout{ - .dataType = dataTypeConvert(info.dataType), - .rows = static_cast(info.seqLen), - .cols = static_cast(attLen), - .majorStride = static_cast(info.cacheLen), - .order = ROW_MAJOR, - .batchCount = static_cast(info.batch * info.nHead), - .batchStride = static_cast(info.cacheLen * info.seqLen), - }); + // MatrixDescriptor + // q_(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(info.seqLen), + // .cols = static_cast(info.headDim), + // .majorStride = static_cast(info.headDim), + // .order = ROW_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.seqLen * info.headDim), + // }), + // k_(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(info.headDim), + // .cols = static_cast(attLen), + // .majorStride = static_cast(info.headDim), + // .order = COL_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.cacheLen * info.headDim), + // }), + // v_(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(attLen), + // .cols = static_cast(info.headDim), + // .majorStride = static_cast(info.headDim), + // .order = ROW_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.cacheLen * info.headDim), + // }), + // att_(MatrixLayout{ + // .dataType = dataTypeConvert(info.dataType), + // .rows = static_cast(info.seqLen), + // .cols = static_cast(attLen), + // .majorStride = static_cast(info.cacheLen), + // .order = ROW_MAJOR, + // .batchCount = static_cast(info.batch * info.nHead), + // .batchStride = static_cast(info.cacheLen * info.seqLen), + // }); { - auto [algo, workspaceSize] = tune( - handle, d->mul, - q_, k_, att_, - DYNAMIC_WORKSPACE_SIZE); + // auto [algo, workspaceSize] = tune( + // handle, d->mul, + // q_, k_, att_, + // DYNAMIC_WORKSPACE_SIZE); half alpha = rsqrtf(info.headDim), beta = 0; - cublasLtMatmul( - handle, d->mul.get(), - &alpha, - q, q_.get(), - kCache, k_.get(), - &beta, - att, att_.get(), - att, att_.get(), - &algo, - workspace, workspaceSize, - stream); + // cublasLtMatmul( + // handle, d->mul.get(), + // &alpha, + // q, q_.get(), + // kCache, k_.get(), + // &beta, + // att, att_.get(), + // att, att_.get(), + // &algo, + // workspace, workspaceSize, + // stream); + cublasGemmStridedBatchedEx( + handle, // handle + CUBLAS_OP_T, // trans a + CUBLAS_OP_N, // trans b + attLen, // m + info.seqLen, // n + info.headDim, // k + &alpha, // alpha + kCache, // a + CUDA_R_16F, // a type + info.headDim, // lda + info.cacheLen * info.headDim,// a stride + q, // b + CUDA_R_16F, // b type + info.headDim, // ldb + info.seqLen * info.headDim, // b stride + &beta, // beta + att, // c + CUDA_R_16F, // c type + info.cacheLen, // ldc + info.cacheLen * info.seqLen, // c stride + info.batch * info.nHead, // batch count + CUDA_R_32F, // compute type + CUBLAS_GEMM_DEFAULT // algo + ); } softmax<<>>( att, AttentionCausualMask(), attLen, info.cacheLen); { - auto 
[algo, workspaceSize] = tune(
-                            handle, d->mul,
-                            att_, v_, q_,
-                            DYNAMIC_WORKSPACE_SIZE);
+                        // auto [algo, workspaceSize] = tune(
+                        //     handle, d->mul,
+                        //     att_, v_, q_,
+                        //     DYNAMIC_WORKSPACE_SIZE);
                         half alpha = 1, beta = 0;
-                        cublasLtMatmul(
-                            handle, d->mul.get(),
-                            &alpha,
-                            att, att_.get(),
-                            vCache, v_.get(),
-                            &beta,
-                            o, q_.get(),
-                            o, q_.get(),
-                            &algo,
-                            workspace, workspaceSize,
-                            stream);
+                        // cublasLtMatmul(
+                        //     handle, d->mul.get(),
+                        //     &alpha,
+                        //     att, att_.get(),
+                        //     vCache, v_.get(),
+                        //     &beta,
+                        //     o, q_.get(),
+                        //     o, q_.get(),
+                        //     &algo,
+                        //     workspace, workspaceSize,
+                        //     stream);
+                        cublasGemmStridedBatchedEx(
+                            handle,                      // handle
+                            CUBLAS_OP_N,                 // trans a
+                            CUBLAS_OP_N,                 // trans b
+                            info.headDim,                // m
+                            info.seqLen,                 // n
+                            attLen,                      // k
+                            &alpha,                      // alpha
+                            vCache,                      // a
+                            CUDA_R_16F,                  // a type
+                            info.headDim,                // lda
+                            info.cacheLen * info.headDim,// a stride
+                            att,                         // b
+                            CUDA_R_16F,                  // b type
+                            info.cacheLen,               // ldb
+                            info.cacheLen * info.seqLen, // b stride
+                            &beta,                       // beta
+                            o,                           // c
+                            CUDA_R_16F,                  // c type
+                            info.headDim,                // ldc
+                            info.seqLen * info.headDim,  // c stride
+                            info.batch * info.nHead,     // batch count
+                            CUDA_R_32F,                  // compute type
+                            CUBLAS_GEMM_DEFAULT          // algo
+                        );
                     }
                 };
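
The two strided-batched GEMMs above compute att = scale * q * kCache^T and o = att * vCache per (batch, head) pair, with the masked softmax in between. A plain host-side reference is useful for checking those routines. The sketch below is not part of the patch series: it assumes row-major [batch, nHead, seqLen, headDim] tensors in float (the kernels above work on half), nHead == nKVHead, and reuses the causal predicate attLen + tokenId >= posId + seqLen from cuda_kernel.cu; all names are illustrative.

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Host reference for one attention step with kv cache:
    // o = softmax(mask(q · kCache^T · scale)) · vCache, per (batch, head).
    void attentionKvCacheRef(
        std::vector<float> const &q,      // [batch, nHead, seqLen, headDim]
        std::vector<float> const &kCache, // [batch, nHead, cacheLen, headDim], rows [0, attLen) valid
        std::vector<float> const &vCache, // same layout as kCache
        std::vector<float> &o,            // same layout as q
        int batch, int nHead, int seqLen, int headDim,
        int cacheLen, int pastSeqLen) {
        int attLen = pastSeqLen + seqLen;
        float scale = 1.f / std::sqrt(float(headDim));
        for (int b = 0; b < batch * nHead; ++b) {
            auto q_ = q.data() + b * seqLen * headDim;
            auto k_ = kCache.data() + b * cacheLen * headDim;
            auto v_ = vCache.data() + b * cacheLen * headDim;
            auto o_ = o.data() + b * seqLen * headDim;
            for (int tok = 0; tok < seqLen; ++tok) {
                // scaled, masked scores of this token against the whole cache
                std::vector<float> att(attLen);
                for (int pos = 0; pos < attLen; ++pos) {
                    float dot = 0;
                    for (int i = 0; i < headDim; ++i) {
                        dot += q_[tok * headDim + i] * k_[pos * headDim + i];
                    }
                    // same predicate as AttentionCausualMask above
                    att[pos] = attLen + tok >= pos + seqLen ? dot * scale : -1e30f;
                }
                // numerically stable softmax over the attLen scores
                float m = *std::max_element(att.begin(), att.end()), sum = 0;
                for (auto &a : att) { sum += a = std::exp(a - m); }
                for (auto &a : att) { a /= sum; }
                // weighted sum of the valid value rows
                for (int i = 0; i < headDim; ++i) {
                    float acc = 0;
                    for (int pos = 0; pos < attLen; ++pos) {
                        acc += att[pos] * v_[pos * headDim + i];
                    }
                    o_[tok * headDim + i] = acc;
                }
            }
        }
    }

With pastSeqLen = 0 and cacheLen = seqLen this reduces to the configuration exercised by the AttentionCudaNoKvCache test, so one helper can back both the cache-free and the kv-cache paths.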