Skip to content

Commit e8dfd70

Browse files
authored
[MLIR][NVGPU] Use gpu.dynamic_shared_memory in tests (#133122)
Reland #133051
1 parent 54cc414 commit e8dfd70

File tree

3 files changed

+56
-49
lines changed

3 files changed

+56
-49
lines changed

mlir/test/Integration/GPU/CUDA/sm90/gemm_f32_f16_f16_128x128x128.mlir

+20-17
Original file line number | Diff line number | Diff line change
@@ -141,14 +141,19 @@ func.func @main() {
141141
%c16 = arith.constant 16 : index
142142
%c4096 = arith.constant 4096 : index
143143
%c8 = arith.constant 8 : index
144-
%txcount = arith.constant 32768 : index
144+
%txcount = arith.constant 32768 : index
145+
%c24576 = arith.constant 24576 : index
146+
%c16384 = arith.constant 16384 : index
147+
%c49152 = arith.constant 49152 : index
148+
%c57344 = arith.constant 57344 : index
149+
%c40960 = arith.constant 40960 : index
145150

146151
%tidx = gpu.thread_id x
147152
%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
148153
%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
149154
%rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<4x64x128xf16,3>
150155
%rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16,3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
151-
156+
%dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
152157
// Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
153158
%barrier = nvgpu.mbarrier.create -> !barrierType
154159
%cnd = arith.cmpi eq, %tidx, %c0 : index
@@ -161,31 +166,29 @@ func.func @main() {
161166
nvgpu.tma.prefetch.descriptor %descA : !lhsTensorMap
162167
nvgpu.tma.prefetch.descriptor %descB : !rhsTensorMap
163168

164-
// Step 4.1 [GPU] TMA Load Pipeline 1
169+
// Step 4.1 [GPU] TMA Load Pipeline 1
165170
scf.if %cnd {
166171
%pipe = arith.constant 0 : index
167-
%lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, 3>
168-
%rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 16384>, 3>
169-
%halfFirst = memref.subview %rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
170-
%halfSecond = memref.subview %rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
172+
%lhsSlice = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
173+
%halfFirst = memref.view %dynsmem[%c32768][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
174+
%halfSecond = memref.view %dynsmem[%c40960][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
171175
nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe], %txcount : !barrierType
172176
%dim = arith.muli %pipe, %c64 : index
173-
nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, 3>
174-
nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
175-
nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
177+
nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, #gpu.address_space<workgroup>>
178+
nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
179+
nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
176180
}
177181
// Step 4.2 [GPU] TMA Load Pipeline 2
178182
scf.if %cnd {
179183
%pipe = arith.constant 1 : index
180-
%lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
181-
%rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 24576>, 3>
182-
%halfFirst = memref.subview %rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
183-
%halfSecond = memref.subview %rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
184+
%lhsSlice = memref.view %dynsmem[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
185+
%halfFirst = memref.view %dynsmem[%c49152][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
186+
%halfSecond = memref.view %dynsmem[%c57344][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
184187
nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe], %txcount : !barrierType
185188
%dim = arith.muli %pipe, %c64 : index
186-
nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
187-
nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
188-
nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
189+
nvgpu.tma.async.load %descA[%dim, %c0], %barrier[%pipe] to %lhsSlice : !lhsTensorMap, !barrierType -> memref<128x64xf16, #gpu.address_space<workgroup>>
190+
nvgpu.tma.async.load %descB[%c0, %dim], %barrier[%pipe] to %halfFirst : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
191+
nvgpu.tma.async.load %descB[%c64, %dim], %barrier[%pipe] to %halfSecond : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
189192
}
190193

191194
// Step 5. [GPU] Initialize accumulator matrix

mlir/test/Integration/GPU/CUDA/sm90/gemm_pred_f32_f16_f16_128x128x128.mlir

+18-17
Original file line number | Diff line number | Diff line change
@@ -142,13 +142,18 @@ func.func @main() {
142142
%c4096 = arith.constant 4096 : index
143143
%c8 = arith.constant 8 : index
144144
%txcount = arith.constant 32768 : index
145+
%c24576 = arith.constant 24576 : index
146+
%c16384 = arith.constant 16384 : index
147+
%c49152 = arith.constant 49152 : index
148+
%c57344 = arith.constant 57344 : index
149+
%c40960 = arith.constant 40960 : index
145150

146151
%tidx = gpu.thread_id x
147152
%dynamicMem = memref.get_global @dynamicShmem : memref<0xf16, 3>
148153
%lhsShmem = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [2, 128, 64], strides: [8192, 64, 1] : memref<0xf16, 3> to memref<2x128x64xf16, 3>
149154
%rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset: [0], sizes: [4, 64, 128], strides: [8192,128,1] : memref<0xf16, 3> to memref<4x64x128xf16,3>
150155
%rhsShmem = memref.subview %rhsShmem2[2, 0, 0][2, 64, 128][1, 1, 1] : memref<4x64x128xf16,3> to memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3>
151-
156+
%dynsmem = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
152157
// Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
153158
%barrier = nvgpu.mbarrier.create -> !barrierType
154159

@@ -175,28 +180,25 @@ func.func @main() {
175180

176181
// Step 4.2 [GPU] TMA Load Pipeline 1 (predicated)
177182
%pipe1 = arith.constant 0 : index
178-
%p1lhsSlice = memref.subview %lhsShmem[0, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, 3>
179-
%p1rhsSlice = memref.subview %rhsShmem[0, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 16384>, 3>
180-
%p1halfFirst = memref.subview %p1rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
181-
%p1halfSecond = memref.subview %p1rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 16384>, 3> to memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
183+
%lhsSlice1 = memref.view %dynsmem[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
184+
%halfFirst1 = memref.view %dynsmem[%c32768][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
185+
%halfSecond1 = memref.view %dynsmem[%c40960][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
182186
nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe1], %txcount, predicate = %cnd : !barrierType
183187
%dim1 = arith.muli %pipe1, %c64 : index
184-
nvgpu.tma.async.load %descA[%dim1, %c0], %barrier[%pipe1] to %p1lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, 3>
185-
nvgpu.tma.async.load %descB[%c0, %dim1], %barrier[%pipe1] to %p1halfFirst, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 16384>, 3>
186-
nvgpu.tma.async.load %descB[%c64, %dim1], %barrier[%pipe1] to %p1halfSecond, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 20480>, 3>
188+
nvgpu.tma.async.load %descA[%dim1, %c0], %barrier[%pipe1] to %lhsSlice1, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, #gpu.address_space<workgroup>>
189+
nvgpu.tma.async.load %descB[%c0, %dim1], %barrier[%pipe1] to %halfFirst1, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
190+
nvgpu.tma.async.load %descB[%c64, %dim1], %barrier[%pipe1] to %halfSecond1, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
187191

188192
// Step 5. [GPU] TMA Load Pipeline 2 (predicated)
189193
%pipe2 = arith.constant 1 : index
190-
%p2lhsSlice = memref.subview %lhsShmem[1, 0, 0][1, 128, 64][1, 1, 1] : memref<2x128x64xf16, 3> to memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
191-
%p2rhsSlice = memref.subview %rhsShmem[1, 0, 0][1, 64, 128][1, 1, 1] : memref<2x64x128xf16, strided<[8192, 128, 1], offset: 16384>, 3> to memref<64x128xf16, strided<[128, 1], offset: 24576>, 3>
192-
%p2halfFirst = memref.subview %p2rhsSlice[0, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
193-
%p2halfSecond = memref.subview %p2rhsSlice[32, 0][64, 64][1, 1] : memref<64x128xf16, strided<[128, 1], offset: 24576>, 3> to memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
194+
%lhsSlice2 = memref.view %dynsmem[%c16384][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<128x64xf16, #gpu.address_space<workgroup>>
195+
%halfFirst2 = memref.view %dynsmem[%c49152][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
196+
%halfSecond2 = memref.view %dynsmem[%c57344][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<64x64xf16, #gpu.address_space<workgroup>>
194197
nvgpu.mbarrier.arrive.expect_tx %barrier[%pipe2], %txcount, predicate = %cnd : !barrierType
195198
%dim2 = arith.muli %pipe2, %c64 : index
196-
nvgpu.tma.async.load %descA[%dim2, %c0], %barrier[%pipe2] to %p2lhsSlice, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, strided<[64, 1], offset: 8192>, 3>
197-
nvgpu.tma.async.load %descB[%c0, %dim2], %barrier[%pipe2] to %p2halfFirst, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 24576>, 3>
198-
nvgpu.tma.async.load %descB[%c64, %dim2], %barrier[%pipe2] to %p2halfSecond, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, strided<[128, 1], offset: 28672>, 3>
199-
199+
nvgpu.tma.async.load %descA[%dim2, %c0], %barrier[%pipe2] to %lhsSlice2, predicate = %cnd : !lhsTensorMap, !barrierType -> memref<128x64xf16, #gpu.address_space<workgroup>>
200+
nvgpu.tma.async.load %descB[%c0, %dim2], %barrier[%pipe2] to %halfFirst2, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
201+
nvgpu.tma.async.load %descB[%c64, %dim2], %barrier[%pipe2] to %halfSecond2, predicate = %cnd : !rhsTensorMap, !barrierType -> memref<64x64xf16, #gpu.address_space<workgroup>>
200202
// Step 6. [GPU] Initialize accumulator matrix
201203
%14 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector<128x128xf32>>
202204

@@ -282,4 +284,3 @@ func.func @main() {
282284
return
283285
}
284286

285-

0 commit comments

Comments
 (0)