@@ -142,13 +142,18 @@ func.func @main() {
142
142
%c4096 = arith.constant 4096 : index
143
143
%c8 = arith.constant 8 : index
144
144
%txcount = arith.constant 32768 : index
145
+ %c24576 = arith.constant 24576 : index
146
+ %c16384 = arith.constant 16384 : index
147
+ %c49152 = arith.constant 49152 : index
148
+ %c57344 = arith.constant 57344 : index
149
+ %c40960 = arith.constant 40960 : index
145
150
146
151
%tidx = gpu.thread_id x
147
152
%dynamicMem = memref.get_global @dynamicShmem : memref <0 xf16 , 3 >
148
153
%lhsShmem = memref.reinterpret_cast %dynamicMem to offset : [0 ], sizes : [2 , 128 , 64 ], strides : [8192 , 64 , 1 ] : memref <0 xf16 , 3 > to memref <2 x128 x64 xf16 , 3 >
149
154
%rhsShmem2 = memref.reinterpret_cast %dynamicMem to offset : [0 ], sizes : [4 , 64 , 128 ], strides : [8192 ,128 ,1 ] : memref <0 xf16 , 3 > to memref <4 x64 x128 xf16 ,3 >
150
155
%rhsShmem = memref.subview %rhsShmem2 [2 , 0 , 0 ][2 , 64 , 128 ][1 , 1 , 1 ] : memref <4 x64 x128 xf16 ,3 > to memref <2 x64 x128 xf16 , strided <[8192 , 128 , 1 ], offset : 16384 >, 3 >
151
-
156
+ %dynsmem = gpu.dynamic_shared_memory : memref <?x i8 , #gpu.address_space < workgroup >>
152
157
// Step 1. [GPU] Create Async Transactional Barriers (mbarriers)
153
158
%barrier = nvgpu.mbarrier.create -> !barrierType
154
159
@@ -175,28 +180,25 @@ func.func @main() {
175
180
176
181
// Step 4.2 [GPU] TMA Load Pipeline 1 (predicated)
177
182
%pipe1 = arith.constant 0 : index
178
- %p1lhsSlice = memref.subview %lhsShmem [0 , 0 , 0 ][1 , 128 , 64 ][1 , 1 , 1 ] : memref <2 x128 x64 xf16 , 3 > to memref <128 x64 xf16 , 3 >
179
- %p1rhsSlice = memref.subview %rhsShmem [0 , 0 , 0 ][1 , 64 , 128 ][1 , 1 , 1 ] : memref <2 x64 x128 xf16 , strided <[8192 , 128 , 1 ], offset : 16384 >, 3 > to memref <64 x128 xf16 , strided <[128 , 1 ], offset : 16384 >, 3 >
180
- %p1halfFirst = memref.subview %p1rhsSlice [0 , 0 ][64 , 64 ][1 , 1 ] : memref <64 x128 xf16 , strided <[128 , 1 ], offset : 16384 >, 3 > to memref <64 x64 xf16 , strided <[128 , 1 ], offset : 16384 >, 3 >
181
- %p1halfSecond = memref.subview %p1rhsSlice [32 , 0 ][64 , 64 ][1 , 1 ] : memref <64 x128 xf16 , strided <[128 , 1 ], offset : 16384 >, 3 > to memref <64 x64 xf16 , strided <[128 , 1 ], offset : 20480 >, 3 >
183
+ %lhsSlice1 = memref.view %dynsmem [%c0 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <128 x64 xf16 , #gpu.address_space <workgroup >>
184
+ %halfFirst1 = memref.view %dynsmem [%c32768 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <64 x64 xf16 , #gpu.address_space <workgroup >>
185
+ %halfSecond1 = memref.view %dynsmem [%c40960 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <64 x64 xf16 , #gpu.address_space <workgroup >>
182
186
nvgpu.mbarrier.arrive.expect_tx %barrier [%pipe1 ], %txcount , predicate = %cnd : !barrierType
183
187
%dim1 = arith.muli %pipe1 , %c64 : index
184
- nvgpu.tma.async.load %descA [%dim1 , %c0 ], %barrier [%pipe1 ] to %p1lhsSlice , predicate = %cnd : !lhsTensorMap , !barrierType -> memref <128 x64 xf16 , 3 >
185
- nvgpu.tma.async.load %descB [%c0 , %dim1 ], %barrier [%pipe1 ] to %p1halfFirst , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , strided <[ 128 , 1 ], offset : 16384 >, 3 >
186
- nvgpu.tma.async.load %descB [%c64 , %dim1 ], %barrier [%pipe1 ] to %p1halfSecond , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , strided <[ 128 , 1 ], offset : 20480 >, 3 >
188
+ nvgpu.tma.async.load %descA [%dim1 , %c0 ], %barrier [%pipe1 ] to %lhsSlice1 , predicate = %cnd : !lhsTensorMap , !barrierType -> memref <128 x64 xf16 , #gpu.address_space < workgroup > >
189
+ nvgpu.tma.async.load %descB [%c0 , %dim1 ], %barrier [%pipe1 ] to %halfFirst1 , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , #gpu.address_space < workgroup > >
190
+ nvgpu.tma.async.load %descB [%c64 , %dim1 ], %barrier [%pipe1 ] to %halfSecond1 , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , #gpu.address_space < workgroup > >
187
191
188
192
// Step 5. [GPU] TMA Load Pipeline 2 (predicated)
189
193
%pipe2 = arith.constant 1 : index
190
- %p2lhsSlice = memref.subview %lhsShmem [1 , 0 , 0 ][1 , 128 , 64 ][1 , 1 , 1 ] : memref <2 x128 x64 xf16 , 3 > to memref <128 x64 xf16 , strided <[64 , 1 ], offset : 8192 >, 3 >
191
- %p2rhsSlice = memref.subview %rhsShmem [1 , 0 , 0 ][1 , 64 , 128 ][1 , 1 , 1 ] : memref <2 x64 x128 xf16 , strided <[8192 , 128 , 1 ], offset : 16384 >, 3 > to memref <64 x128 xf16 , strided <[128 , 1 ], offset : 24576 >, 3 >
192
- %p2halfFirst = memref.subview %p2rhsSlice [0 , 0 ][64 , 64 ][1 , 1 ] : memref <64 x128 xf16 , strided <[128 , 1 ], offset : 24576 >, 3 > to memref <64 x64 xf16 , strided <[128 , 1 ], offset : 24576 >, 3 >
193
- %p2halfSecond = memref.subview %p2rhsSlice [32 , 0 ][64 , 64 ][1 , 1 ] : memref <64 x128 xf16 , strided <[128 , 1 ], offset : 24576 >, 3 > to memref <64 x64 xf16 , strided <[128 , 1 ], offset : 28672 >, 3 >
194
+ %lhsSlice2 = memref.view %dynsmem [%c16384 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <128 x64 xf16 , #gpu.address_space <workgroup >>
195
+ %halfFirst2 = memref.view %dynsmem [%c49152 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <64 x64 xf16 , #gpu.address_space <workgroup >>
196
+ %halfSecond2 = memref.view %dynsmem [%c57344 ][] : memref <?xi8 , #gpu.address_space <workgroup >> to memref <64 x64 xf16 , #gpu.address_space <workgroup >>
194
197
nvgpu.mbarrier.arrive.expect_tx %barrier [%pipe2 ], %txcount , predicate = %cnd : !barrierType
195
198
%dim2 = arith.muli %pipe2 , %c64 : index
196
- nvgpu.tma.async.load %descA [%dim2 , %c0 ], %barrier [%pipe2 ] to %p2lhsSlice , predicate = %cnd : !lhsTensorMap , !barrierType -> memref <128 x64 xf16 , strided <[64 , 1 ], offset : 8192 >, 3 >
197
- nvgpu.tma.async.load %descB [%c0 , %dim2 ], %barrier [%pipe2 ] to %p2halfFirst , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , strided <[128 , 1 ], offset : 24576 >, 3 >
198
- nvgpu.tma.async.load %descB [%c64 , %dim2 ], %barrier [%pipe2 ] to %p2halfSecond , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , strided <[128 , 1 ], offset : 28672 >, 3 >
199
-
199
+ nvgpu.tma.async.load %descA [%dim2 , %c0 ], %barrier [%pipe2 ] to %lhsSlice2 , predicate = %cnd : !lhsTensorMap , !barrierType -> memref <128 x64 xf16 , #gpu.address_space <workgroup >>
200
+ nvgpu.tma.async.load %descB [%c0 , %dim2 ], %barrier [%pipe2 ] to %halfFirst2 , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , #gpu.address_space <workgroup >>
201
+ nvgpu.tma.async.load %descB [%c64 , %dim2 ], %barrier [%pipe2 ] to %halfSecond2 , predicate = %cnd : !rhsTensorMap , !barrierType -> memref <64 x64 xf16 , #gpu.address_space <workgroup >>
200
202
// Step 6. [GPU] Initiliaze accumulator matrix
201
203
%14 = nvgpu.warpgroup.mma.init.accumulator -> <fragmented = vector <128 x128 xf32 >>
202
204
@@ -282,4 +284,3 @@ func.func @main() {
282
284
return
283
285
}
284
286
285
-
0 commit comments