Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit f6a78dc

Browse files
author
Theodoros Theodoridis
committed
[Cuda Codegen] Emit launch bounds
CUDA functions can be annotated with launch bounds, i.e. the maximum number of threads per block (the minimum number of blocks per multiprocessor can also be specified). nvrtc/nvcc use this information during register allocation (and probably in other phases as well).
1 parent 45ca22e commit f6a78dc

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

tc/core/polyhedral/cuda/codegen.cc

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,17 @@ void emitArgs(stringstream& ss, const Scop& scop) {
153153
void emitKernelSignature(
154154
stringstream& ss,
155155
const std::string& specializedName,
156-
const Scop& scop) {
156+
const Scop& scop,
157+
const Block& block) {
157158
TC_CHECK_NE(specializedName, "") << "name not provided";
158-
ss << "__global__ void " << specializedName << "(";
159+
auto b0 = block.view[0];
160+
b0 = b0 == 0 ? 1 : b0;
161+
auto b1 = block.view[1];
162+
b1 = b1 == 0 ? 1 : b1;
163+
auto b2 = block.view[2];
164+
b1 = b2 == 0 ? 1 : b2;
165+
ss << "__global__ __launch_bounds__(" << b0 * b1 * b2 << ") void "
166+
<< specializedName << "(";
159167
emitArgs(ss, scop);
160168
ss << ") {" << endl;
161169
}
@@ -753,7 +761,7 @@ string emitCudaKernel(
753761
}
754762

755763
stringstream ss;
756-
emitKernelSignature(ss, specializedName, scop);
764+
emitKernelSignature(ss, specializedName, scop, mscop.numThreads);
757765
emitThreadIdInit(ss, mscop);
758766
emitTensorViews(ss, scop.halide.outputs, paramValues);
759767
emitTensorViews(ss, scop.halide.inputs, paramValues);

test/test_cuda_mapper.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ def fun(float(N, N) A) -> (O)
451451
auto res = std::get<0>(mscop->codegen(specializedName));
452452

453453
string expected(
454-
R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA) {
454+
R"RES(__global__ __launch_bounds__(1) void kernel_anon(int32 N, float32* pO, const float32* pA) {
455455
int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
456456
int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
457457
float32 (*O)[N] = reinterpret_cast<float32 (*)[N]>(pO);
@@ -480,7 +480,7 @@ def fun(float(N, N) A, float(N, N) B, float(N) C) -> (O)
480480
auto res = std::get<0>(mscop->codegen(specializedName));
481481

482482
string expected =
483-
R"RES(__global__ void kernel_anon(int32 N, float32* pO, const float32* pA, const float32* pB, const float32* pC) {
483+
R"RES(__global__ __launch_bounds__(1) void kernel_anon(int32 N, float32* pO, const float32* pA, const float32* pB, const float32* pC) {
484484
int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
485485
int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
486486
float32 (*O)[512] = reinterpret_cast<float32 (*)[512]>(pO);

0 commit comments

Comments
 (0)