Skip to content

[mlir][Vector] Infer mask and pass_thru types for maskedload/store #131482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1840,7 +1840,16 @@ def Vector_StoreOp : Vector_Op<"store"> {
}

def Vector_MaskedLoadOp :
Vector_Op<"maskedload">,
Vector_Op<"maskedload", [
AllTypesMatch<["result", "pass_thru"]>,
TypesMatchWith<"mask shape should match result shape",
"result",
"mask",
"VectorType::get(::llvm::cast<VectorType>($_self).getShape(),"
"IntegerType::get($_ctxt, 1),"
"::llvm::cast<VectorType>($_self).getScalableDims())">,
AllElementTypesMatch<["result", "base"]>
]>,
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -1875,10 +1884,10 @@ def Vector_MaskedLoadOp :

```mlir
%0 = vector.maskedload %base[%i], %mask, %pass_thru
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
: memref<?xf32>, vector<8xf32>

%1 = vector.maskedload %base[%i, %j], %mask, %pass_thru
: memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
: memref<?x?xf32>, vector<16xf32>
```
}];
let extraClassDeclaration = [{
Expand All @@ -1896,14 +1905,22 @@ def Vector_MaskedLoadOp :
}
}];
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
"type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)";
"type($base) `,` type($result)";
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}

def Vector_MaskedStoreOp :
Vector_Op<"maskedstore">,
Vector_Op<"maskedstore", [
TypesMatchWith<"mask shape should match result shape",
"valueToStore",
"mask",
"VectorType::get(::llvm::cast<VectorType>($_self).getShape(),"
"IntegerType::get($_ctxt, 1),"
"::llvm::cast<VectorType>($_self).getScalableDims())">,
AllElementTypesMatch<["valueToStore", "base"]>
]>,
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfNonZeroRankOf<[I1]>:$mask,
Expand Down Expand Up @@ -1937,10 +1954,10 @@ def Vector_MaskedStoreOp :

```mlir
vector.maskedstore %base[%i], %mask, %value
: memref<?xf32>, vector<8xi1>, vector<8xf32>
: memref<?xf32>, vector<8xf32>

vector.maskedstore %base[%i, %j], %mask, %value
: memref<?x?xf32>, vector<16xi1>, vector<16xf32>
: memref<?x?xf32>, vector<16xf32>
```
}];
let extraClassDeclaration = [{
Expand All @@ -1956,7 +1973,7 @@ def Vector_MaskedStoreOp :
}];
let assemblyFormat =
"$base `[` $indices `]` `,` $mask `,` $valueToStore "
"attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)";
"attr-dict `:` type($base) `,` type($valueToStore)";
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
Expand Down
29 changes: 6 additions & 23 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5127,19 +5127,9 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor,
//===----------------------------------------------------------------------===//

LogicalResult MaskedLoadOp::verify() {
  // Mask/pass_thru/result type agreement is enforced declaratively by the
  // op's ODS traits (AllTypesMatch, TypesMatchWith, AllElementTypesMatch),
  // so the only remaining check is that one index is supplied per memref
  // dimension. The flattened diff had left the superseded type checks and a
  // duplicate rank check in place; only the updated check is kept.
  int64_t memRank = getMemRefType().getRank();
  if (llvm::size(getIndices()) != memRank)
    return emitOpError("requires ") << memRank << " indices";
  return success();
}

Expand Down Expand Up @@ -5181,16 +5171,9 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) {
//===----------------------------------------------------------------------===//

LogicalResult MaskedStoreOp::verify() {
  // Mask/valueToStore type agreement is enforced declaratively by the op's
  // ODS traits (TypesMatchWith, AllElementTypesMatch); verify only that one
  // index is supplied per memref dimension. The flattened diff had left the
  // superseded type checks and a duplicate rank check in place; only the
  // updated check is kept.
  int64_t memRank = getMemRefType().getRank();
  if (llvm::size(getIndices()) != memRank)
    return emitOpError("requires ") << memRank << " indices";
  return success();
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
// CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1>
// CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
// CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi32>
// CHECK: %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32>
// CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32>
func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1891,7 +1891,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {

func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0: index
  // Mask and pass_thru types are inferred from the result type.
  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xf32>
  return %0 : vector<16xf32>
}

Expand All @@ -1906,7 +1906,7 @@ func.func @masked_load(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector

func.func @masked_load_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) -> vector<[16]xf32> {
  %c0 = arith.constant 0: index
  // Mask and pass_thru types (including scalable dims) are inferred from the result type.
  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xf32>
  return %0 : vector<[16]xf32>
}

Expand All @@ -1921,7 +1921,7 @@ func.func @masked_load_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %a

func.func @masked_load_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) -> vector<16xindex> {
  %c0 = arith.constant 0: index
  // Mask and pass_thru types are inferred from the result type.
  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xindex>
  return %0 : vector<16xindex>
}
// CHECK-LABEL: func @masked_load_index
Expand All @@ -1931,7 +1931,7 @@ func.func @masked_load_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2

func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) -> vector<[16]xindex> {
  %c0 = arith.constant 0: index
  // Mask and pass_thru types (including scalable dims) are inferred from the result type.
  %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xindex>
  return %0 : vector<[16]xindex>
}
// CHECK-LABEL: func @masked_load_index_scalable
Expand All @@ -1945,7 +1945,7 @@ func.func @masked_load_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]

func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vector<16xf32>) {
  %c0 = arith.constant 0: index
  // Mask type is inferred from the valueToStore type.
  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<16xf32>
  return
}

Expand All @@ -1959,7 +1959,7 @@ func.func @masked_store(%arg0: memref<?xf32>, %arg1: vector<16xi1>, %arg2: vecto

func.func @masked_store_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) {
  %c0 = arith.constant 0: index
  // Mask type (including scalable dims) is inferred from the valueToStore type.
  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xf32>, vector<[16]xf32>
  return
}

Expand All @@ -1973,7 +1973,7 @@ func.func @masked_store_scalable(%arg0: memref<?xf32>, %arg1: vector<[16]xi1>, %

func.func @masked_store_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg2: vector<16xindex>) {
  %c0 = arith.constant 0: index
  // Mask type is inferred from the valueToStore type.
  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<16xindex>
  return
}
// CHECK-LABEL: func @masked_store_index
Expand All @@ -1983,7 +1983,7 @@ func.func @masked_store_index(%arg0: memref<?xindex>, %arg1: vector<16xi1>, %arg

func.func @masked_store_index_scalable(%arg0: memref<?xindex>, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) {
  %c0 = arith.constant 0: index
  // Mask type (including scalable dims) is inferred from the valueToStore type.
  vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref<?xindex>, vector<[16]xindex>
  return
}
// CHECK-LABEL: func @masked_store_index_scalable
Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ func.func @fold_vector_load_subview(
func.func @fold_vector_maskedload_subview(
    %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> {
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  // Mask and pass_thru types are inferred from the result type.
  %1 = vector.maskedload %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xf32>
  return %1 : vector<32xf32>
}

Expand All @@ -847,7 +847,7 @@ func.func @fold_vector_maskedload_subview(
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32>
// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xf32>

// -----

Expand All @@ -871,7 +871,7 @@ func.func @fold_vector_store_subview(
func.func @fold_vector_maskedstore_subview(
    %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () {
  %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref<f32, strided<[], offset: ?>>
  // Mask type is inferred from the valueToStore type.
  vector.maskedstore %0[], %arg3, %arg4 : memref<f32, strided<[], offset: ?>>, vector<32xf32>
  return
}

Expand All @@ -881,7 +881,7 @@ func.func @fold_vector_maskedstore_subview(
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32>
// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32>
// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xf32>
// CHECK: return

// -----
Expand All @@ -907,7 +907,7 @@ func.func @fold_vector_maskedload_expand_shape(
%arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
%c0 = arith.constant 0 : index
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
%1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
%1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xf32>
return %1 : vector<8xf32>
}

Expand Down Expand Up @@ -943,7 +943,7 @@ func.func @fold_vector_maskedstore_expand_shape(
%arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
%c0 = arith.constant 0 : index
%0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32>
vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32>
vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xf32>
return
}

Expand Down Expand Up @@ -979,7 +979,7 @@ func.func @fold_vector_load_collapse_shape(
func.func @fold_vector_maskedload_collapse_shape(
    %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> {
  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
  // Mask and pass_thru types are inferred from the result type.
  %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xf32>
  return %1 : vector<8xf32>
}

Expand Down Expand Up @@ -1017,7 +1017,7 @@ func.func @fold_vector_store_collapse_shape(
func.func @fold_vector_maskedstore_collapse_shape(
    %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) {
  %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32>
  // Mask type is inferred from the valueToStore type.
  vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xf32>
  return
}

Expand Down
Loading