Skip to content

Commit 0abf227

Browse files
authored
[mlir][amdgpu] Add amdgpu.swizzle_bitmode op (llvm#135513)
High-level wrapper on top of `rocdl.ds_swizzle`. Also includes some DPP op cleanup while I'm at it. Lowering will be done in a separate PR.
1 parent 20cd74a commit 0abf227

File tree

3 files changed

+50
-8
lines changed

3 files changed

+50
-8
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ def AMDGPU_Dialect : Dialect {
3535
let useDefaultAttributePrinterParser = 1;
3636
}
3737

38+
def AnyIntegerOrFloat : AnyTypeOf<[AnySignlessInteger, AnyFloat], "Integer or Float">;
39+
40+
def AnyIntegerOrFloatOr1DVector :
41+
AnyTypeOf<[AnyIntegerOrFloat, VectorOfRankAndType<[1], [AnyIntegerOrFloat]>]>;
42+
3843
//===----------------------------------------------------------------------===//
3944
// AMDGPU general attribute definitions
4045
//===----------------------------------------------------------------------===//
@@ -533,14 +538,15 @@ def AMDGPU_DPPPerm : I32EnumAttr<"DPPPerm",
533538
def AMDGPU_DPPPermAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_DPPPerm,
534539
"dpp_perm">;
535540

536-
def AMDGPU_DPPOp : AMDGPU_Op<"dpp", [SameTypeOperands, AllTypesMatch<["result", "old", "src"]>]>,
541+
def AMDGPU_DPPOp : AMDGPU_Op<"dpp",
542+
[Pure, SameTypeOperands, AllTypesMatch<["result", "old", "src"]>]>,
537543
Arguments<(ins AnyType:$old,
538-
AnyType:$src,
539-
AMDGPU_DPPPermAttr:$kind,
540-
OptionalAttr<AnyAttrOf<[I32Attr, ArrayAttr, UnitAttr]>>:$permArgument,
541-
DefaultValuedAttr<I32Attr, "0xf">:$row_mask,
542-
DefaultValuedAttr<I32Attr, "0xf">:$bank_mask,
543-
DefaultValuedAttr<BoolAttr, "false">:$bound_ctrl)> {
544+
AnyType:$src,
545+
AMDGPU_DPPPermAttr:$kind,
546+
OptionalAttr<AnyAttrOf<[I32Attr, ArrayAttr, UnitAttr]>>:$permArgument,
547+
DefaultValuedAttr<I32Attr, "0xf">:$row_mask,
548+
DefaultValuedAttr<I32Attr, "0xf">:$bank_mask,
549+
DefaultValuedAttr<BoolAttr, "false">:$bound_ctrl)> {
544550
let summary = "AMDGPU DPP operation";
545551
let description = [{
546552
This operation represents DPP functionality in a GPU program.
@@ -565,6 +571,27 @@ def AMDGPU_DPPOp : AMDGPU_Op<"dpp", [SameTypeOperands, AllTypesMatch<["result",
565571
let hasVerifier = 1;
566572
}
567573

574+
def AMDGPU_SwizzleBitModeOp : AMDGPU_Op<"swizzle_bitmode",
575+
[Pure, AllTypesMatch<["result", "src"]>]>,
576+
Arguments<(ins AnyIntegerOrFloatOr1DVector:$src,
577+
I32Attr:$and_mask,
578+
I32Attr:$or_mask,
579+
I32Attr:$xor_mask
580+
)> {
581+
let summary = "AMDGPU ds_swizzle op, bitmode variant";
582+
let description = [{
583+
High-level wrapper on the bitmode `rocdl.ds_swizzle` op; masks are represented
584+
as separate fields so the user won't need to do manual bitpacking.
585+
586+
Supports arbitrary int/float/vector types, which will be repacked to i32 and
587+
one or more `rocdl.ds_swizzle` ops during lowering.
588+
}];
589+
let results = (outs AnyIntegerOrFloatOr1DVector:$result);
590+
let assemblyFormat = [{
591+
$src $and_mask $or_mask $xor_mask attr-dict `:` type($result)
592+
}];
593+
}
594+
568595
def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
569596
let summary = "Barrier that includes a wait for LDS memory operations.";
570597
let description = [{
@@ -794,7 +821,7 @@ def AMDGPU_GatherToLDSOp :
794821

795822
The `$dst`, along with its indices, points to the memory location the subgroup of this thread
796823
will write to.
797-
824+
798825
Note: only enabled for gfx942 and later.
799826
}];
800827
let assemblyFormat = [{

mlir/test/Dialect/AMDGPU/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,11 @@ func.func @fat_raw_buffer_cast_stripping_offset_affine_map(%m: memref<8xi32, aff
150150
%ret = amdgpu.fat_raw_buffer_cast %m resetOffset : memref<8xi32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
151151
func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
152152
}
153+
154+
// -----
155+
156+
func.func @swizzle_invalid_type(%arg0 : si32) -> si32 {
157+
// expected-error@+1 {{amdgpu.swizzle_bitmode' op operand #0 must be Integer or Float or vector of Integer or Float values of ranks 1}}
158+
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : si32
159+
func.return %0 : si32
160+
}

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,10 @@ func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xf16>) -> vector<8xf16>
157157
%0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf16>
158158
func.return %0 : vector<8xf16>
159159
}
160+
161+
// CHECK-LABEL: func @swizzle_bitmode
162+
func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
163+
// CHECK: amdgpu.swizzle_bitmode
164+
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32
165+
func.return %0 : f32
166+
}

0 commit comments

Comments
 (0)