From 3cdf6e781e997c7483840ad1dec932d1ac5a27de Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Sun, 16 Mar 2025 00:17:37 +0000 Subject: [PATCH 1/2] [mlir][ODS] Fix default inferReturnTypes generation for variadic operands --- mlir/test/mlir-tblgen/op-result.td | 18 +++----- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 51 ++++++++++----------- 2 files changed, 30 insertions(+), 39 deletions(-) diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index a4f7af6dbcf1c..334ca118e31c0 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -136,9 +136,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint", // CHECK-LABEL: LogicalResult OpL1::inferReturnTypes // CHECK-NOT: } -// CHECK: if (operands.size() <= 0) -// CHECK-NEXT: return ::mlir::failure(); -// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType(); +// CHECK: OpL1::Adaptor adaptor +// CHECK: ::mlir::Type odsInferredType0 = adaptor.getA().getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; def OpL2 : NS_Op<"op_with_all_types_constraint", @@ -149,11 +148,9 @@ def OpL2 : NS_Op<"op_with_all_types_constraint", // CHECK-LABEL: LogicalResult OpL2::inferReturnTypes // CHECK-NOT: } -// CHECK: if (operands.size() <= 2) -// CHECK-NEXT: return ::mlir::failure(); -// CHECK-NOT: if (operands.size() <= 0) -// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType(); -// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType(); +// CHECK: OpL2::Adaptor adaptor +// CHECK: ::mlir::Type odsInferredType0 = adaptor.getC().getType(); +// CHECK: ::mlir::Type odsInferredType1 = adaptor.getA().getType(); // CHECK: inferredReturnTypes[0] = odsInferredType0; // CHECK: inferredReturnTypes[1] = odsInferredType1; @@ -177,9 +174,8 @@ def OpL4 : NS_Op<"two_inference_edges", [ } // CHECK-LABEL: LogicalResult OpL4::inferReturnTypes -// CHECK: if (operands.size() <= 0) -// CHECK-NEXT: return ::mlir::failure(); -// CHECK: odsInferredType0 = fromInput(operands[0].getType()) +// CHECK: OpL4::Adaptor adaptor +// CHECK: odsInferredType0 = fromInput(adaptor.getInput().getType()) // CHECK: odsInferredType1 = infer0(odsInferredType0) // CHECK: odsInferredType2 = infer1(odsInferredType1) // CHECK: inferredReturnTypes[0] = odsInferredType0 diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index b957c8ee9f8ab..8288e77b8f653 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2641,8 +2641,7 @@ void OpEmitter::genSeparateArgParamBuilder() { // Avoid emitting "resultTypes.size() >= 0u" which is always true. if (!hasVariadicResult || numNonVariadicResults != 0) - body << " " - << "assert(resultTypes.size() " + body << " " << "assert(resultTypes.size() " << (hasVariadicResult ? ">=" : "==") << " " << numNonVariadicResults << "u && \"mismatched number of results\");\n"; @@ -3751,29 +3750,24 @@ void OpEmitter::genTypeInterfaceMethods() { fctx.addSubst("_ctxt", "context"); body << " ::mlir::Builder odsBuilder(context);\n"; - // Preprocessing stage to verify all accesses to operands are valid. 
-  int maxAccessedIndex = -1;
-  for (int i = 0, e = op.getNumResults(); i != e; ++i) {
-    const InferredResultType &infer = op.getInferredResultType(i);
-    if (!infer.isArg())
-      continue;
-    Operator::OperandOrAttribute arg =
-        op.getArgToOperandOrAttribute(infer.getIndex());
-    if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
-      maxAccessedIndex =
-          std::max(maxAccessedIndex, arg.operandOrAttributeIndex());
-    }
-  }
-  if (maxAccessedIndex != -1) {
-    body << "  if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";
-    body << "    return ::mlir::failure();\n";
-  }
+  // Emit an adaptor to access the right ranges for ODS operands.
+  body << "  " << op.getCppClassName()
+       << "::Adaptor adaptor(operands, attributes, properties, regions);\n";
 
-  // Process the type inference graph in topological order, starting from types
-  // that are always fully-inferred: operands and results with constructible
-  // types. The type inference graph here will always be a DAG, so this gives
-  // us the correct order for generating the types. -1 is a placeholder to
-  // indicate the type for a result has not been generated.
+  // TODO: Ideally, we should be doing some sort of verification here. This
+  // is, however, problematic for two reasons:
+  //
+  // 1. Adaptor::verify only verifies attributes. It really should also verify
+  //    that the number of given attributes is right.
+  // 2. PDL passes empty properties to inferReturnTypes, which does not verify.
+  //    Without properties, it is not really possible to verify the number of
+  //    operands, as we do not know the variadic operand segment sizes.
+
+  // Process the type inference graph in topological order, starting from
+  // types that are always fully-inferred: operands and results with
+  // constructible types. The type inference graph here will always be a
+  // DAG, so this gives us the correct order for generating the types. -1 is
+  // a placeholder to indicate the type for a result has not been generated.
   SmallVector<int> constructedIndices(op.getNumResults(), -1);
   int inferredTypeIdx = 0;
   for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) {
@@ -3788,10 +3782,11 @@ void OpEmitter::genTypeInterfaceMethods() {
       Operator::OperandOrAttribute arg =
           op.getArgToOperandOrAttribute(infer.getIndex());
       if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
-        typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
-                   "].getType()")
-                      .str();
-
+        std::string getter =
+            "adaptor." +
+            op.getGetterName(
+                op.getOperand(arg.operandOrAttributeIndex()).name);
+        typeStr = (getter + "().getType()");
         // If this is an attribute, index into the attribute dictionary.
} else { auto *attr = From 51273370675c4e01e20987a1beb53fd8704243ad Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Sun, 16 Mar 2025 00:18:30 +0000 Subject: [PATCH 2/2] [mlir][Vector] Infer mask and pass_thru types for maskedload/store --- .../mlir/Dialect/Vector/IR/VectorOps.td | 33 +++++++--- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 29 ++------- .../ArmSMEToSCF/arm-sme-to-scf.mlir | 2 +- .../vector-to-llvm-interface.mlir | 16 ++--- .../Dialect/MemRef/fold-memref-alias-ops.mlir | 16 ++--- .../Dialect/SparseTensor/sparse_vector.mlir | 32 +++++----- .../SparseTensor/sparse_vector_chain.mlir | 4 +- .../SparseTensor/sparse_vector_index.mlir | 6 +- .../SparseTensor/sparse_vector_peeled.mlir | 4 +- .../SparseTensor/vectorize_reduction.mlir | 14 ++-- mlir/test/Dialect/Vector/canonicalize.mlir | 10 +-- .../emulate-narrow-type-unsupported.mlir | 6 +- mlir/test/Dialect/Vector/invalid.mlir | 27 ++++---- mlir/test/Dialect/Vector/ops.mlir | 16 ++--- .../vector-emulate-masked-load-store.mlir | 4 +- .../vector-emulate-narrow-type-unaligned.mlir | 16 ++--- .../Vector/vector-emulate-narrow-type.mlir | 64 +++++++++---------- .../Dialect/Vector/vector-mem-transforms.mlir | 12 ++-- 18 files changed, 157 insertions(+), 154 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index fbbf817ecff98..fd77249402934 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1840,7 +1840,16 @@ def Vector_StoreOp : Vector_Op<"store"> { } def Vector_MaskedLoadOp : - Vector_Op<"maskedload">, + Vector_Op<"maskedload", [ + AllTypesMatch<["result", "pass_thru"]>, + TypesMatchWith<"mask shape should match result shape", + "result", + "mask", + "VectorType::get(::llvm::cast($_self).getShape()," + "IntegerType::get($_ctxt, 1)," + "::llvm::cast($_self).getScalableDims())">, + AllElementTypesMatch<["result", "base"]> + ]>, Arguments<(ins Arg:$base, Variadic:$indices, VectorOfNonZeroRankOf<[I1]>:$mask, @@ -1875,10 +1884,10 @@ def Vector_MaskedLoadOp : ```mlir %0 = vector.maskedload %base[%i], %mask, %pass_thru - : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> + : memref, vector<8xf32> %1 = vector.maskedload %base[%i, %j], %mask, %pass_thru - : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + : memref, vector<16xf32> ``` }]; let extraClassDeclaration = [{ @@ -1896,14 +1905,22 @@ def Vector_MaskedLoadOp : } }]; let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` " - "type($base) `,` type($mask) `,` type($pass_thru) `into` type($result)"; + "type($base) `,` type($result)"; let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; } def Vector_MaskedStoreOp : - Vector_Op<"maskedstore">, + Vector_Op<"maskedstore", [ + TypesMatchWith<"mask shape should match result shape", + "valueToStore", + "mask", + "VectorType::get(::llvm::cast($_self).getShape()," + "IntegerType::get($_ctxt, 1)," + "::llvm::cast($_self).getScalableDims())">, + AllElementTypesMatch<["valueToStore", "base"]> + ]>, Arguments<(ins Arg:$base, Variadic:$indices, VectorOfNonZeroRankOf<[I1]>:$mask, @@ -1937,10 +1954,10 @@ def Vector_MaskedStoreOp : ```mlir vector.maskedstore %base[%i], %mask, %value - : memref, vector<8xi1>, vector<8xf32> + : memref, vector<8xf32> vector.maskedstore %base[%i, %j], %mask, %value - : memref, vector<16xi1>, vector<16xf32> + : memref, vector<16xf32> ``` }]; let extraClassDeclaration = [{ @@ -1956,7 +1973,7 @@ def Vector_MaskedStoreOp : }]; let 
assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $valueToStore " - "attr-dict `:` type($base) `,` type($mask) `,` type($valueToStore)"; + "attr-dict `:` type($base) `,` type($valueToStore)"; let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 8e0e723cf4ed3..83b962e54110a 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5127,19 +5127,9 @@ LogicalResult StoreOp::fold(FoldAdaptor adaptor, //===----------------------------------------------------------------------===// LogicalResult MaskedLoadOp::verify() { - VectorType maskVType = getMaskVectorType(); - VectorType passVType = getPassThruVectorType(); - VectorType resVType = getVectorType(); - MemRefType memType = getMemRefType(); - - if (resVType.getElementType() != memType.getElementType()) - return emitOpError("base and result element type should match"); - if (llvm::size(getIndices()) != memType.getRank()) - return emitOpError("requires ") << memType.getRank() << " indices"; - if (resVType.getShape() != maskVType.getShape()) - return emitOpError("expected result shape to match mask shape"); - if (resVType != passVType) - return emitOpError("expected pass_thru of same type as result type"); + int64_t memRank = getMemRefType().getRank(); + if (llvm::size(getIndices()) != memRank) + return emitOpError("requires ") << memRank << " indices"; return success(); } @@ -5181,16 +5171,9 @@ OpFoldResult MaskedLoadOp::fold(FoldAdaptor) { //===----------------------------------------------------------------------===// LogicalResult MaskedStoreOp::verify() { - VectorType maskVType = getMaskVectorType(); - VectorType valueVType = getVectorType(); - MemRefType memType = getMemRefType(); - - if (valueVType.getElementType() != memType.getElementType()) - return emitOpError("base and valueToStore element type should match"); - if (llvm::size(getIndices()) != memType.getRank()) - return emitOpError("requires ") << memType.getRank() << " indices"; - if (valueVType.getShape() != maskVType.getShape()) - return emitOpError("expected valueToStore shape to match mask shape"); + int64_t memRank = getMemRefType().getRank(); + if (llvm::size(getIndices()) != memRank) + return emitOpError("requires ") << memRank << " indices"; return success(); } diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir index 4ae710aa29113..10224aec95d48 100644 --- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir +++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir @@ -88,7 +88,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref) // CHECK-NEXT: %[[MASK_1D:.*]] = vector.create_mask %[[MASK_INDEX]] : vector<[4]xi1> // CHECK-NEXT: %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index // CHECK: %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32> -// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32> +// CHECK: %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref, vector<[4]xi32> // CHECK: %[[TILE_UPDATE:.*]] = arm_sme.insert_tile_slice %[[LOAD_SLICE]], %[[CURRENT_TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xi32> into vector<[4]x[4]xi32> // CHECK-NEXT: scf.yield %[[TILE_UPDATE]] : vector<[4]x[4]xi32> func.func 
@arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref, %pad : i32) { diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index c3f06dd4d5dd1..6dcae67abda57 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -1891,7 +1891,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) { func.func @masked_load(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> { %c0 = arith.constant 0: index - %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<16xf32> return %0 : vector<16xf32> } @@ -1906,7 +1906,7 @@ func.func @masked_load(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector func.func @masked_load_scalable(%arg0: memref, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) -> vector<[16]xf32> { %c0 = arith.constant 0: index - %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<[16]xi1>, vector<[16]xf32> into vector<[16]xf32> + %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<[16]xf32> return %0 : vector<[16]xf32> } @@ -1921,7 +1921,7 @@ func.func @masked_load_scalable(%arg0: memref, %arg1: vector<[16]xi1>, %a func.func @masked_load_index(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xindex>) -> vector<16xindex> { %c0 = arith.constant 0: index - %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xindex> into vector<16xindex> + %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<16xindex> return %0 : vector<16xindex> } // CHECK-LABEL: func @masked_load_index @@ -1931,7 +1931,7 @@ func.func @masked_load_index(%arg0: memref, %arg1: vector<16xi1>, %arg2 func.func @masked_load_index_scalable(%arg0: memref, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) -> vector<[16]xindex> { %c0 = arith.constant 0: index - %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<[16]xi1>, vector<[16]xindex> into vector<[16]xindex> + %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<[16]xindex> return %0 : vector<[16]xindex> } // CHECK-LABEL: func @masked_load_index_scalable @@ -1945,7 +1945,7 @@ func.func @masked_load_index_scalable(%arg0: memref, %arg1: vector<[16] func.func @masked_store(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>) { %c0 = arith.constant 0: index - vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xf32> + vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<16xf32> return } @@ -1959,7 +1959,7 @@ func.func @masked_store(%arg0: memref, %arg1: vector<16xi1>, %arg2: vecto func.func @masked_store_scalable(%arg0: memref, %arg1: vector<[16]xi1>, %arg2: vector<[16]xf32>) { %c0 = arith.constant 0: index - vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<[16]xi1>, vector<[16]xf32> + vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<[16]xf32> return } @@ -1973,7 +1973,7 @@ func.func @masked_store_scalable(%arg0: memref, %arg1: vector<[16]xi1>, % func.func @masked_store_index(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xindex>) { %c0 = arith.constant 0: index - vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xindex> + vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<16xindex> return } // CHECK-LABEL: func @masked_store_index 
@@ -1983,7 +1983,7 @@ func.func @masked_store_index(%arg0: memref, %arg1: vector<16xi1>, %arg func.func @masked_store_index_scalable(%arg0: memref, %arg1: vector<[16]xi1>, %arg2: vector<[16]xindex>) { %c0 = arith.constant 0: index - vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<[16]xi1>, vector<[16]xindex> + vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<[16]xindex> return } // CHECK-LABEL: func @masked_store_index_scalable diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir index 327cacf7d9a20..7246ae4884a19 100644 --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -837,7 +837,7 @@ func.func @fold_vector_load_subview( func.func @fold_vector_maskedload_subview( %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> vector<32xf32> { %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref> - %1 = vector.maskedload %0[], %arg3, %arg4 : memref>, vector<32xi1>, vector<32xf32> into vector<32xf32> + %1 = vector.maskedload %0[], %arg3, %arg4 : memref>, vector<32xf32> return %1 : vector<32xf32> } @@ -847,7 +847,7 @@ func.func @fold_vector_maskedload_subview( // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1> // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32> -// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> into vector<32xf32> +// CHECK: vector.maskedload %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xf32> // ----- @@ -871,7 +871,7 @@ func.func @fold_vector_store_subview( func.func @fold_vector_maskedstore_subview( %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3: vector<32xi1>, %arg4: vector<32xf32>) -> () { %0 = memref.subview %arg0[%arg1, %arg2][1, 1][1, 1] : memref<12x32xf32> to memref> - vector.maskedstore %0[], %arg3, %arg4 : memref>, vector<32xi1>, vector<32xf32> + vector.maskedstore %0[], %arg3, %arg4 : memref>, vector<32xf32> return } @@ -881,7 +881,7 @@ func.func @fold_vector_maskedstore_subview( // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index // CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: vector<32xi1> // CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: vector<32xf32> -// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xi1>, vector<32xf32> +// CHECK: vector.maskedstore %[[ARG0]][%[[ARG1]], %[[ARG2]]], %[[ARG3]], %[[ARG4]] : memref<12x32xf32>, vector<32xf32> // CHECK: return // ----- @@ -907,7 +907,7 @@ func.func @fold_vector_maskedload_expand_shape( %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> { %c0 = arith.constant 0 : index %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32> - %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32> + %1 = vector.maskedload %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xf32> return %1 : vector<8xf32> } @@ -943,7 +943,7 @@ func.func @fold_vector_maskedstore_expand_shape( %arg0 : memref<32xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) { %c0 = arith.constant 0 : index %0 = memref.expand_shape %arg0 [[0, 1]] output_shape [4, 8] : memref<32xf32> into memref<4x8xf32> - vector.maskedstore %0[%arg1, %c0], %arg3, 
%arg4 : memref<4x8xf32>, vector<8xi1>, vector<8xf32> + vector.maskedstore %0[%arg1, %c0], %arg3, %arg4 : memref<4x8xf32>, vector<8xf32> return } @@ -979,7 +979,7 @@ func.func @fold_vector_load_collapse_shape( func.func @fold_vector_maskedload_collapse_shape( %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) -> vector<8xf32> { %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32> - %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32> + %1 = vector.maskedload %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xf32> return %1 : vector<8xf32> } @@ -1017,7 +1017,7 @@ func.func @fold_vector_store_collapse_shape( func.func @fold_vector_maskedstore_collapse_shape( %arg0 : memref<4x8xf32>, %arg1 : index, %arg3: vector<8xi1>, %arg4: vector<8xf32>) { %0 = memref.collapse_shape %arg0 [[0, 1]] : memref<4x8xf32> into memref<32xf32> - vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xi1>, vector<8xf32> + vector.maskedstore %0[%arg1], %arg3, %arg4 : memref<32xf32>, vector<8xf32> return } diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir index 364ba6e71ff3b..c50d44f7faa1e 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector.mlir @@ -65,10 +65,10 @@ // CHECK-VEC4-SVE: scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] { // CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]] // CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1> -// CHECK-VEC4-SVE: %[[val:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> +// CHECK-VEC4-SVE: %[[val:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref, vector<[4]xf32> // CHECK-VEC4-SVE: %[[scalev:.*]] = vector.broadcast %{{.*}} : f32 to vector<[4]xf32> // CHECK-VEC4-SVE: %[[scaled:.*]] = arith.mulf %[[val]], %[[scalev]] : vector<[4]xf32> -// CHECK-VEC4-SVE: vector.maskedstore %{{.*}}[%[[i]]], %[[mask]], %[[scaled]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32> +// CHECK-VEC4-SVE: vector.maskedstore %{{.*}}[%[[i]]], %[[mask]], %[[scaled]] : memref<1024xf32>, vector<[4]xf32> // CHECK-VEC4-SVE: } // CHECK-VEC4-SVE: return // @@ -136,9 +136,9 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor // CHECK-VEC16: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] { // CHECK-VEC16: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]] // CHECK-VEC16: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> -// CHECK-VEC16: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> +// CHECK-VEC16: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi32> // CHECK-VEC16: %[[zi:.*]] = arith.extui %[[li]] : vector<16xi32> to vector<16xi64> -// CHECK-VEC16: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC16: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xf32> // CHECK-VEC16: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32> // CHECK-VEC16: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> // 
CHECK-VEC16: vector.scatter %{{.*}}[%[[c0]]] [%[[zi]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> @@ -159,8 +159,8 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor // CHECK-VEC16-IDX32: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[c16]] { // CHECK-VEC16-IDX32: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[c16]]] // CHECK-VEC16-IDX32: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> -// CHECK-VEC16-IDX32: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> -// CHECK-VEC16-IDX32: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC16-IDX32: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xi32> +// CHECK-VEC16-IDX32: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %{{.*}} : memref, vector<16xf32> // CHECK-VEC16-IDX32: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %{{.*}} : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> // CHECK-VEC16-IDX32: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> // CHECK-VEC16-IDX32: vector.scatter %{{.*}}[%[[c0]]] [%[[li]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> @@ -185,9 +185,9 @@ func.func @scale_d(%arga: tensor<1024xf32, #DenseVector>, %b: f32, %argx: tensor // CHECK-VEC4-SVE: scf.for %[[i:.*]] = %[[q]] to %[[s]] step %[[step]] { // CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[i]])[%[[step]]] // CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1> -// CHECK-VEC4-SVE: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0i]] : memref, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32> +// CHECK-VEC4-SVE: %[[li:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0i]] : memref, vector<[4]xi32> // CHECK-VEC4-SVE: %[[lii64:.*]] = arith.extui %[[li]] : vector<[4]xi32> to vector<[4]xi64> -// CHECK-VEC4-SVE: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0f]] : memref, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> +// CHECK-VEC4-SVE: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0f]] : memref, vector<[4]xf32> // CHECK-VEC4-SVE: %[[lb:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[lii64]]], %[[mask]], %[[v0f]] : memref<1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> // CHECK-VEC4-SVE: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32> // CHECK-VEC4-SVE: vector.scatter %{{.*}}[%[[c0]]] [%[[lii64]]], %[[mask]], %[[m]] : memref<1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> @@ -282,8 +282,8 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>, // CHECK-VEC4-SVE: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) { // CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]] // CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1> -// CHECK-VEC4-SVE: %[[la:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> -// CHECK-VEC4-SVE: %[[lb:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<1024xf32>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> +// CHECK-VEC4-SVE: %[[la:.*]] = vector.maskedload 
%{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref, vector<[4]xf32> +// CHECK-VEC4-SVE: %[[lb:.*]] = vector.maskedload %{{.*}}[%[[i]]], %[[mask]], %[[v0]] : memref<1024xf32>, vector<[4]xf32> // CHECK-VEC4-SVE: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32> // CHECK-VEC4-SVE: %[[a:.*]] = arith.addf %[[red_in]], %[[m]] : vector<[4]xf32> // CHECK-VEC4-SVE: %[[sa:.*]] = arith.select %[[mask]], %[[a]], %[[red_in]] : vector<[4]xi1>, vector<[4]xf32> @@ -366,9 +366,9 @@ func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, // CHECK-VEC16: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] { // CHECK-VEC16: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]] // CHECK-VEC16: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> -// CHECK-VEC16: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> +// CHECK-VEC16: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi32> // CHECK-VEC16: %[[zj:.*]] = arith.extui %[[lj]] : vector<16xi32> to vector<16xi64> -// CHECK-VEC16: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC16: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xf32> // CHECK-VEC16: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32> // CHECK-VEC16: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> // CHECK-VEC16: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[zj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> @@ -393,8 +393,8 @@ func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, // CHECK-VEC16-IDX32: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[c16]] { // CHECK-VEC16-IDX32: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[c16]]] // CHECK-VEC16-IDX32: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<16xi1> -// CHECK-VEC16-IDX32: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> -// CHECK-VEC16-IDX32: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK-VEC16-IDX32: %[[lj:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xi32> +// CHECK-VEC16-IDX32: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %{{.*}} : memref, vector<16xf32> // CHECK-VEC16-IDX32: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %{{.*}} : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> // CHECK-VEC16-IDX32: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<16xf32> // CHECK-VEC16-IDX32: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> @@ -423,9 +423,9 @@ func.func @reduction_d(%arga: tensor<1024xf32, #DenseVector>, // CHECK-VEC4-SVE: scf.for %[[j:.*]] = %[[q]] to %[[s]] step %[[step]] { // CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[s]], %[[j]])[%[[step]]] // CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1> -// CHECK-VEC4-SVE: %[[lji32:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0i]] : memref, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32> +// CHECK-VEC4-SVE: %[[lji32:.*]] = 
vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0i]] : memref, vector<[4]xi32> // CHECK-VEC4-SVE: %[[lj:.*]] = arith.extui %[[lji32]] : vector<[4]xi32> to vector<[4]xi64> -// CHECK-VEC4-SVE: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0f]] : memref, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> +// CHECK-VEC4-SVE: %[[la:.*]] = vector.maskedload %{{.*}}[%[[j]]], %[[mask]], %[[v0f]] : memref, vector<[4]xf32> // CHECK-VEC4-SVE: %[[lb:.*]] = vector.gather %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[v0f]] : memref<512x1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> into vector<[4]xf32> // CHECK-VEC4-SVE: %[[m:.*]] = arith.mulf %[[la]], %[[lb]] : vector<[4]xf32> // CHECK-VEC4-SVE: vector.scatter %{{.*}}[%[[i]], %[[c0]]] [%[[lj]]], %[[mask]], %[[m]] : memref<512x1024xf32>, vector<[4]xi64>, vector<[4]xi1>, vector<[4]xf32> diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir index c99d5d25f7b4a..aaa83c9707329 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir @@ -86,7 +86,7 @@ // CHECK: %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) { // CHECK: %[[VAL_64:.*]] = affine.min #map(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]] // CHECK: %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1> -// CHECK: %[[VAL_66:.*]] = vector.maskedload %[[VAL_10]]{{\[}}%[[VAL_62]]], %[[VAL_65]], %[[VAL_4]] : memref, vector<8xi1>, vector<8xf64> into vector<8xf64> +// CHECK: %[[VAL_66:.*]] = vector.maskedload %[[VAL_10]]{{\[}}%[[VAL_62]]], %[[VAL_65]], %[[VAL_4]] : memref, vector<8xf64> // CHECK: %[[VAL_67:.*]] = arith.addf %[[VAL_63]], %[[VAL_66]] : vector<8xf64> // CHECK: %[[VAL_68:.*]] = arith.select %[[VAL_65]], %[[VAL_67]], %[[VAL_63]] : vector<8xi1>, vector<8xf64> // CHECK: scf.yield %[[VAL_68]] : vector<8xf64> @@ -94,7 +94,7 @@ // CHECK: %[[VAL_69:.*]] = scf.for %[[VAL_70:.*]] = %[[VAL_60]]#1 to %[[VAL_23]] step %[[VAL_3]] iter_args(%[[VAL_71:.*]] = %[[VAL_61]]) -> (vector<8xf64>) { // CHECK: %[[VAL_73:.*]] = affine.min #map(%[[VAL_23]], %[[VAL_70]]){{\[}}%[[VAL_3]]] // CHECK: %[[VAL_74:.*]] = vector.create_mask %[[VAL_73]] : vector<8xi1> -// CHECK: %[[VAL_75:.*]] = vector.maskedload %[[VAL_13]]{{\[}}%[[VAL_70]]], %[[VAL_74]], %[[VAL_4]] : memref, vector<8xi1>, vector<8xf64> into vector<8xf64> +// CHECK: %[[VAL_75:.*]] = vector.maskedload %[[VAL_13]]{{\[}}%[[VAL_70]]], %[[VAL_74]], %[[VAL_4]] : memref, vector<8xf64> // CHECK: %[[VAL_76:.*]] = arith.addf %[[VAL_71]], %[[VAL_75]] : vector<8xf64> // CHECK: %[[VAL_77:.*]] = arith.select %[[VAL_74]], %[[VAL_76]], %[[VAL_71]] : vector<8xi1>, vector<8xf64> // CHECK: scf.yield %[[VAL_77]] : vector<8xf64> diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir index d88372276989d..37de1cb2bb0ff 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector_index.mlir @@ -35,8 +35,8 @@ // CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_12]] to %[[VAL_13]] step %[[VAL_1]] { // CHECK: %[[VAL_15:.*]] = affine.min #map(%[[VAL_13]], %[[VAL_14]]){{\[}}%[[VAL_1]]] // CHECK: %[[VAL_16:.*]] = vector.create_mask %[[VAL_15]] : vector<8xi1> -// CHECK: %[[VAL_17:.*]] = vector.maskedload %[[VAL_9]]{{\[}}%[[VAL_14]]], %[[VAL_16]], %[[VAL_3]] : memref, vector<8xi1>, 
vector<8xindex> into vector<8xindex> -// CHECK: %[[VAL_18:.*]] = vector.maskedload %[[VAL_10]]{{\[}}%[[VAL_14]]], %[[VAL_16]], %[[VAL_2]] : memref, vector<8xi1>, vector<8xi64> into vector<8xi64> +// CHECK: %[[VAL_17:.*]] = vector.maskedload %[[VAL_9]]{{\[}}%[[VAL_14]]], %[[VAL_16]], %[[VAL_3]] : memref, vector<8xindex> +// CHECK: %[[VAL_18:.*]] = vector.maskedload %[[VAL_10]]{{\[}}%[[VAL_14]]], %[[VAL_16]], %[[VAL_2]] : memref, vector<8xi64> // CHECK: %[[VAL_19:.*]] = arith.index_cast %[[VAL_17]] : vector<8xindex> to vector<8xi64> // CHECK: %[[VAL_20:.*]] = arith.muli %[[VAL_18]], %[[VAL_19]] : vector<8xi64> // CHECK: vector.scatter %[[VAL_11]]{{\[}}%[[VAL_5]]] {{\[}}%[[VAL_17]]], %[[VAL_16]], %[[VAL_20]] : memref<8xi64>, vector<8xindex>, vector<8xi1>, vector<8xi64> @@ -104,7 +104,7 @@ func.func @sparse_index_1d_conj(%arga: tensor<8xi64, #SparseVector>) -> tensor<8 // CHECK: %[[VAL_33:.*]] = vector.broadcast %[[VAL_29]] : index to vector<8xindex> // CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_2]] : vector<8xindex> // CHECK: %[[VAL_35:.*]] = arith.index_cast %[[VAL_34]] : vector<8xindex> to vector<8xi64> -// CHECK: vector.maskedstore %[[VAL_11]]{{\[}}%[[VAL_29]]], %[[VAL_32]], %[[VAL_35]] : memref<8xi64>, vector<8xi1>, vector<8xi64> +// CHECK: vector.maskedstore %[[VAL_11]]{{\[}}%[[VAL_29]]], %[[VAL_32]], %[[VAL_35]] : memref<8xi64>, vector<8xi64> // CHECK: } {"Emitted from" = "linalg.generic"} // CHECK: %[[VAL_36:.*]] = bufferization.to_tensor %[[VAL_11]] : memref<8xi64> // CHECK: return %[[VAL_36]] : tensor<8xi64> diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir index 99d6a3dc390e0..cbb90ee9394ac 100644 --- a/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_vector_peeled.mlir @@ -42,9 +42,9 @@ // CHECK: scf.for %[[i2:.*]] = %[[boundary]] to %[[s]] step %[[c16]] { // CHECK: %[[sub:.*]] = affine.apply #[[$map1]](%[[i2]])[%[[s]]] // CHECK: %[[mask2:.*]] = vector.create_mask %[[sub]] : vector<16xi1> -// CHECK: %[[li2:.*]] = vector.maskedload %{{.*}}[%[[i2]]], %[[mask2]], %{{.*}} : memref, vector<16xi1>, vector<16xi32> into vector<16xi32> +// CHECK: %[[li2:.*]] = vector.maskedload %{{.*}}[%[[i2]]], %[[mask2]], %{{.*}} : memref, vector<16xi32> // CHECK: %[[zi2:.*]] = arith.extui %[[li2]] : vector<16xi32> to vector<16xi64> -// CHECK: %[[la2:.*]] = vector.maskedload %{{.*}}[%[[i2]]], %[[mask2]], %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK: %[[la2:.*]] = vector.maskedload %{{.*}}[%[[i2]]], %[[mask2]], %{{.*}} : memref, vector<16xf32> // CHECK: %[[lb2:.*]] = vector.gather %{{.*}}[%[[c0]]] [%[[zi2]]], %[[mask2]], %{{.*}} : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> into vector<16xf32> // CHECK: %[[m2:.*]] = arith.mulf %[[la2]], %[[lb2]] : vector<16xf32> // CHECK: vector.scatter %{{.*}}[%[[c0]]] [%[[zi2]]], %[[mask2]], %[[m2]] : memref<1024xf32>, vector<16xi64>, vector<16xi1>, vector<16xf32> diff --git a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir index 15228c6a5f79a..4f170b5cfa5f7 100644 --- a/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir +++ b/mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir @@ -24,7 +24,7 @@ // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi13>) { // CHECK-ON: %[[VAL_16:.*]] = affine.min 
#map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]] // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1> -// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_3]] : memref, vector<8xi1>, vector<8xi13> into vector<8xi13> +// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_3]] : memref, vector<8xi13> // CHECK-ON: %[[VAL_19:.*]] = arith.ori %[[VAL_15]], %[[VAL_18]] : vector<8xi13> // CHECK-ON: %[[VAL_20:.*]] = arith.select %[[VAL_17]], %[[VAL_19]], %[[VAL_15]] : vector<8xi1>, vector<8xi13> // CHECK-ON: scf.yield %[[VAL_20]] : vector<8xi13> @@ -101,7 +101,7 @@ func.func @sparse_reduction_ori(%argx: tensor, // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi13>) { // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]] // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1> -// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_3]] : memref, vector<8xi1>, vector<8xi13> into vector<8xi13> +// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_3]] : memref, vector<8xi13> // CHECK-ON: %[[VAL_19:.*]] = arith.ori %[[VAL_18]], %[[VAL_15]] : vector<8xi13> // CHECK-ON: %[[VAL_20:.*]] = arith.select %[[VAL_17]], %[[VAL_19]], %[[VAL_15]] : vector<8xi1>, vector<8xi13> // CHECK-ON: scf.yield %[[VAL_20]] : vector<8xi13> @@ -176,7 +176,7 @@ func.func @sparse_reduction_ori_accumulator_on_rhs(%argx: tensor, // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) { // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]] // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1> -// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_4]] : memref, vector<8xi1>, vector<8xi32> into vector<8xi32> +// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_4]] : memref, vector<8xi32> // CHECK-ON: %[[VAL_19:.*]] = arith.subi %[[VAL_15]], %[[VAL_18]] : vector<8xi32> // CHECK-ON: %[[VAL_20:.*]] = arith.select %[[VAL_17]], %[[VAL_19]], %[[VAL_15]] : vector<8xi1>, vector<8xi32> // CHECK-ON: scf.yield %[[VAL_20]] : vector<8xi32> @@ -251,7 +251,7 @@ func.func @sparse_reduction_subi(%argx: tensor, // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) { // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]] // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1> -// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_3]] : memref, vector<8xi1>, vector<8xi32> into vector<8xi32> +// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_3]] : memref, vector<8xi32> // CHECK-ON: %[[VAL_19:.*]] = arith.xori %[[VAL_15]], %[[VAL_18]] : vector<8xi32> // CHECK-ON: %[[VAL_20:.*]] = arith.select %[[VAL_17]], %[[VAL_19]], %[[VAL_15]] : vector<8xi1>, vector<8xi32> // CHECK-ON: scf.yield %[[VAL_20]] : vector<8xi32> @@ -327,7 +327,7 @@ func.func @sparse_reduction_xor(%argx: tensor, // CHECK-ON: %[[VAL_13:.*]] = scf.for 
%[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) { // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]] // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1> -// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_3]] : memref, vector<8xi1>, vector<8xi32> into vector<8xi32> +// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_3]] : memref, vector<8xi32> // CHECK-ON: %[[VAL_19:.*]] = arith.addi %[[VAL_15]], %[[VAL_18]] : vector<8xi32> // CHECK-ON: %[[VAL_20:.*]] = arith.select %[[VAL_17]], %[[VAL_19]], %[[VAL_15]] : vector<8xi1>, vector<8xi32> // CHECK-ON: scf.yield %[[VAL_20]] : vector<8xi32> @@ -403,7 +403,7 @@ func.func @sparse_reduction_addi(%argx: tensor, // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) { // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]] // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1> -// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_3]] : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_3]] : memref, vector<8xf32> // CHECK-ON: %[[VAL_19:.*]] = arith.subf %[[VAL_15]], %[[VAL_18]] : vector<8xf32> // CHECK-ON: %[[VAL_20:.*]] = arith.select %[[VAL_17]], %[[VAL_19]], %[[VAL_15]] : vector<8xi1>, vector<8xf32> // CHECK-ON: scf.yield %[[VAL_20]] : vector<8xf32> @@ -479,7 +479,7 @@ func.func @sparse_reduction_subf(%argx: tensor, // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) { // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]] // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1> -// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_3]] : memref, vector<8xi1>, vector<8xf32> into vector<8xf32> +// CHECK-ON: %[[VAL_18:.*]] = vector.maskedload %[[VAL_7]]{{\[}}%[[VAL_14]]], %[[VAL_17]], %[[VAL_3]] : memref, vector<8xf32> // CHECK-ON: %[[VAL_19:.*]] = arith.addf %[[VAL_15]], %[[VAL_18]] : vector<8xf32> // CHECK-ON: %[[VAL_20:.*]] = arith.select %[[VAL_17]], %[[VAL_19]], %[[VAL_15]] : vector<8xi1>, vector<8xf32> // CHECK-ON: scf.yield %[[VAL_20]] : vector<8xf32> diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index bf755b466c7eb..99873cee7af5e 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1277,7 +1277,7 @@ func.func @dead_load(%base: memref, %indices: vector<16xi32>, %mask: vector<16xi1>, %passthru: vector<16xf32>) { %c0 = arith.constant 0 : index %0 = vector.maskedload %base[%c0], %mask, %passthru : - memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + memref, vector<16xf32> %1 = vector.gather %base[%c0][%indices], %mask, %passthru : memref, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32> %2 = vector.expandload %base[%c0], %mask, %passthru : @@ -3072,7 +3072,7 @@ func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_s // CHECK-LABEL: @contiguous_gather // CHECK-SAME: 
(%[[BASE:.*]]: memref, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>) // CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref, vector<16xf32> // CHECK: return %[[R]] func.func @contiguous_gather(%base: memref, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> { @@ -3136,7 +3136,7 @@ func.func @contiguous_gather_const_mask(%base: memref, // CHECK-LABEL: @contiguous_gather_step // CHECK-SAME: (%[[BASE:.*]]: memref, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>) // CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> +// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref, vector<16xf32> // CHECK: return %[[R]] func.func @contiguous_gather_step(%base: memref, %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> { @@ -3167,7 +3167,7 @@ func.func @gather_broadcast(%base: memref, // CHECK-LABEL: @contiguous_scatter // CHECK-SAME: (%[[BASE:.*]]: memref, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>) // CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref, vector<16xi1>, vector<16xf32> +// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref, vector<16xf32> func.func @contiguous_scatter(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) { %c0 = arith.constant 0 : index @@ -3198,7 +3198,7 @@ func.func @contiguous_scatter_const_mask(%base: memref, // CHECK-LABEL: @contiguous_scatter_step // CHECK-SAME: (%[[BASE:.*]]: memref, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>) // CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref, vector<16xi1>, vector<16xf32> +// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref, vector<16xf32> func.func @contiguous_scatter_step(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) { %c0 = arith.constant 0 : index diff --git a/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir index a5a6fc4acfe10..1e8d7a3927e0c 100644 --- a/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir +++ b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir @@ -45,7 +45,7 @@ func.func @vector_maskedload_2d_i8_negative(%arg1: index, %arg2: index, %arg3: i %0 = memref.alloc() : memref<3x4xi8> %mask = vector.create_mask %arg3, %arg3 : vector<2x4xi1> %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru : - memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8> + memref<3x4xi8>, vector<2x4xi8> return %1 : vector<2x4xi8> } @@ -68,7 +68,7 @@ func.func @vector_extract_maskedload_2d_i4_negative(%arg1: index) -> vector<8x8x %cst_2 = arith.constant dense<0> : vector<8x16xi4> %27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1> %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1> - %50 = vector.maskedload %0[%c0, %c0, %c0], %48, %cst_2 : memref<8x8x16xi4>, vector<8x16xi1>, vector<8x16xi4> into vector<8x16xi4> + %50 = vector.maskedload %0[%c0, %c0, %c0], %48, %cst_2 : 
memref<8x8x16xi4>, vector<8x16xi4> %63 = vector.insert %50, %cst_1 [0] : vector<8x16xi4> into vector<8x8x16xi4> return %63 : vector<8x8x16xi4> } @@ -102,7 +102,7 @@ func.func @vector_store_2d_i8_negative(%arg0: vector<2x8xi8>, %arg1: index, %arg func.func @vector_maskedstore_2d_i8_negative(%arg0: index, %arg1: index, %arg2: index, %value: vector<2x8xi8>) { %0 = memref.alloc() : memref<3x8xi8> %mask = vector.create_mask %arg2, %arg2 : vector<2x8xi1> - vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<2x8xi1>, vector<2x8xi8> + vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<2x8xi8> return } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 57e348c7d5991..b01ece32859ac 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1348,47 +1348,50 @@ func.func @store_memref_index_mismatch(%base : memref, %value : vector<16 func.func @maskedload_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %pass: vector<16xf32>) { %c0 = arith.constant 0 : index - // expected-error@+1 {{'vector.maskedload' op base and result element type should match}} - %0 = vector.maskedload %base[%c0], %mask, %pass : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + // expected-error@+1 {{'vector.maskedload' op failed to verify that all of {result, base} have same element type}} + %0 = vector.maskedload %base[%c0], %mask, %pass : memref, vector<16xf32> } // ----- + // expected-note@+1 {{prior use here}} func.func @maskedload_dim_mask_mismatch(%base: memref, %mask: vector<15xi1>, %pass: vector<16xf32>) { %c0 = arith.constant 0 : index - // expected-error@+1 {{'vector.maskedload' op expected result shape to match mask shape}} - %0 = vector.maskedload %base[%c0], %mask, %pass : memref, vector<15xi1>, vector<16xf32> into vector<16xf32> + // expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<16xi1>' vs 'vector<15xi1>'}} + %0 = vector.maskedload %base[%c0], %mask, %pass : memref, vector<16xf32> } // ----- + // expected-note@+1 {{prior use here}} func.func @maskedload_pass_thru_type_mask_mismatch(%base: memref, %mask: vector<16xi1>, %pass: vector<16xi32>) { %c0 = arith.constant 0 : index - // expected-error@+1 {{'vector.maskedload' op expected pass_thru of same type as result type}} - %0 = vector.maskedload %base[%c0], %mask, %pass : memref, vector<16xi1>, vector<16xi32> into vector<16xf32> + // expected-error@+1 {{use of value '%pass' expects different type than prior uses: 'vector<16xf32>' vs 'vector<16xi32>'}} + %0 = vector.maskedload %base[%c0], %mask, %pass : memref, vector<16xf32> } // ----- func.func @maskedload_memref_mismatch(%base: memref, %mask: vector<16xi1>, %pass: vector<16xf32>) { // expected-error@+1 {{'vector.maskedload' op requires 1 indices}} - %0 = vector.maskedload %base[], %mask, %pass : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> + %0 = vector.maskedload %base[], %mask, %pass : memref, vector<16xf32> } // ----- func.func @maskedstore_base_type_mismatch(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) { %c0 = arith.constant 0 : index - // expected-error@+1 {{'vector.maskedstore' op base and valueToStore element type should match}} - vector.maskedstore %base[%c0], %mask, %value : memref, vector<16xi1>, vector<16xf32> + // expected-error@+1 {{vector.maskedstore' op failed to verify that all of {valueToStore, base} have same element type}} + vector.maskedstore %base[%c0], %mask, %value : memref, 
vector<16xf32> } // ----- +// expected-note@+1 {{prior use here}} func.func @maskedstore_dim_mask_mismatch(%base: memref, %mask: vector<15xi1>, %value: vector<16xf32>) { %c0 = arith.constant 0 : index - // expected-error@+1 {{'vector.maskedstore' op expected valueToStore shape to match mask shape}} - vector.maskedstore %base[%c0], %mask, %value : memref, vector<15xi1>, vector<16xf32> + // expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<16xi1>' vs 'vector<15xi1>'}} + vector.maskedstore %base[%c0], %mask, %value : memref, vector<16xf32> } // ----- @@ -1396,7 +1399,7 @@ func.func @maskedstore_dim_mask_mismatch(%base: memref, %mask: vector<15x func.func @maskedstore_memref_mismatch(%base: memref, %mask: vector<16xi1>, %value: vector<16xf32>) { %c0 = arith.constant 0 : index // expected-error@+1 {{'vector.maskedstore' op requires 1 indices}} - vector.maskedstore %base[%c0, %c0], %mask, %value : memref, vector<16xi1>, vector<16xf32> + vector.maskedstore %base[%c0, %c0], %mask, %value : memref, vector<16xf32> } // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 67484e06f456d..05bdb63721abf 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -845,20 +845,20 @@ func.func @vector_load_and_store_2d_vector_memref(%memref : memref<200x100xvecto // CHECK-LABEL: @masked_load_and_store func.func @masked_load_and_store(%base: memref, %mask: vector<16xi1>, %passthru: vector<16xf32>) { %c0 = arith.constant 0 : index - // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> - %0 = vector.maskedload %base[%c0], %mask, %passthru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> - // CHECK: vector.maskedstore %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xi1>, vector<16xf32> - vector.maskedstore %base[%c0], %mask, %0 : memref, vector<16xi1>, vector<16xf32> + // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}[%{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xf32> + %0 = vector.maskedload %base[%c0], %mask, %passthru : memref, vector<16xf32> + // CHECK: vector.maskedstore %{{.*}}[%{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xf32> + vector.maskedstore %base[%c0], %mask, %0 : memref, vector<16xf32> return } // CHECK-LABEL: @masked_load_and_store2d func.func @masked_load_and_store2d(%base: memref, %mask: vector<16xi1>, %passthru: vector<16xf32>) { %c0 = arith.constant 0 : index - // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> - %0 = vector.maskedload %base[%c0, %c0], %mask, %passthru : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> - // CHECK: vector.maskedstore %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xi1>, vector<16xf32> - vector.maskedstore %base[%c0, %c0], %mask, %0 : memref, vector<16xi1>, vector<16xf32> + // CHECK: %[[X:.*]] = vector.maskedload %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %{{.*}} : memref, vector<16xf32> + %0 = vector.maskedload %base[%c0, %c0], %mask, %passthru : memref, vector<16xf32> + // CHECK: vector.maskedstore %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}}, %[[X]] : memref, vector<16xf32> + vector.maskedstore %base[%c0, %c0], %mask, %0 : memref, vector<16xf32> return } diff --git a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir index 3867f075af8e4..e9ff08c537778 100644 --- 
a/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-masked-load-store.mlir @@ -50,7 +50,7 @@ func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> { %mask = vector.create_mask %idx_1 : vector<4xi1> %s = arith.constant 0.0 : f32 %pass_thru = vector.splat %s : vector<4xf32> - %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32> + %0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xf32> return %0: vector<4xf32> } @@ -90,6 +90,6 @@ func.func @vector_maskedstore(%arg0 : memref<4x5xf32>, %arg1 : vector<4xf32>) { %idx_1 = arith.constant 1 : index %idx_4 = arith.constant 4 : index %mask = vector.create_mask %idx_1 : vector<4xi1> - vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32>, vector<4xi1>, vector<4xf32> + vector.maskedstore %arg0[%idx_0, %idx_4], %mask, %arg1 : memref<4x5xf32>, vector<4xf32> return } diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir index 6fc974200c6f3..d3604c37a65c5 100644 --- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir +++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir @@ -48,7 +48,7 @@ func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index %1 = vector.maskedload %0[%c2, %c0], %mask, %passthru : - memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2> + memref<3x5xi2>, vector<5xi2> return %1 : vector<5xi2> } // CHECK-LABEL: func @vector_constant_mask_maskedload_i2( @@ -62,7 +62,7 @@ func.func @vector_constant_mask_maskedload_i2(%passthru: vector<5xi2>) -> vector // CHECK: %[[BITCAST1:.+]] = vector.bitcast %[[INSERT1]] : vector<8xi2> to vector<2xi8> // CHECK: %[[C2:.+]] = arith.constant 2 : index // CHECK: %[[MASKEDLOAD:.+]] = vector.maskedload %alloc[%[[C2]]], %[[NEWMASK:.+]], %[[BITCAST1]] -// CHECK-SAME: : memref<4xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8> +// CHECK-SAME: : memref<4xi8>, vector<2xi8> // CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[MASKEDLOAD]] : vector<2xi8> to vector<8xi2> // CHECK: %[[CST2:.+]] = arith.constant dense : vector<8xi1> // CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[ORIGINMASK]], %[[CST2]] @@ -82,7 +82,7 @@ func.func @unaligned_create_mask_dynamic_i2(%m : index, %passthru: vector<5xi2>) %c1 = arith.constant 1 : index %mask = vector.create_mask %m : vector<5xi1> %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru : - memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2> + memref<3x5xi2>, vector<5xi2> return %1 : vector<5xi2> } @@ -107,7 +107,7 @@ func.func @check_unaligned_create_mask_static_i2(%passthru: vector<7xi2>) -> vec %c3 = arith.constant 3 : index %mask = vector.create_mask %c3 : vector<7xi1> %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru : - memref<3x7xi2>, vector<7xi1>, vector<7xi2> into vector<7xi2> + memref<3x7xi2>, vector<7xi2> return %1 : vector<7xi2> } @@ -129,7 +129,7 @@ func.func @vector_constant_mask_maskedload_i2_multidim(%passthru: vector<5xi2>) %c0 = arith.constant 0 : index %c2 = arith.constant 2 : index %1 = vector.maskedload %0[%c2, %c0, %c0], %ext_mask, %passthru : - memref<4x3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2> + memref<4x3x5xi2>, vector<5xi2> return %1 : vector<5xi2> } @@ -261,7 +261,7 @@ func.func 
   %c2 = arith.constant 2 : index
   %mask = vector.constant_mask [3] : vector<3xi1>
   %1 = vector.maskedload %0[%idx, %c2], %mask, %passthru :
-    memref<3x3xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+    memref<3x3xi2>, vector<3xi2>
   return %1 : vector<3xi2>
 }

@@ -293,7 +293,7 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,

 // Use the emulated i8 vector for masked load from the source memory
 // CHECK: %[[SOURCE:.+]] = vector.maskedload %[[ALLOC]][%[[LINEAR1]]], %[[ONE]], %[[BCAST_PASSTHRU]]
-// CHECK-SAME: memref<3xi8>, vector<2xi1>, vector<2xi8> into vector<2xi8>
+// CHECK-SAME: memref<3xi8>, vector<2xi8>

 // Bitcast back to i2 vector
 // CHECK: %[[BCAST_MASKLOAD:.+]] = vector.bitcast %[[SOURCE]] : vector<2xi8> to vector<8xi2>
@@ -328,7 +328,7 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %1 = vector.maskedload %0[%c1, %c0], %mask, %passthru :
-    memref<3x5xi2>, vector<5xi1>, vector<5xi2> into vector<5xi2>
+    memref<3x5xi2>, vector<5xi2>
   return %1 : vector<5xi2>
 }

diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 9dc3eb6989c6c..05c7e635bde3c 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -127,7 +127,7 @@ func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passt
   %0 = memref.alloc() : memref<3x4xi8>
   %mask = vector.create_mask %arg3 : vector<4xi1>
   %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
-    memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+    memref<3x4xi8>, vector<4xi8>
   return %1 : vector<4xi8>
 }
 // Expect no conversions, i8 is supported.
@@ -137,7 +137,7 @@ func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passt
 // CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
 // CHECK-NEXT: %[[MASK:.+]] = vector.create_mask %[[ARG2]] : vector<4xi1>
 // CHECK-NEXT: [[L:%.+]] = vector.maskedload %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[ARG3]] :
-// CHECK-SAME: memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK-SAME: memref<3x4xi8>, vector<4xi8>
 // CHECK-NEXT: return

 // CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
@@ -152,7 +152,7 @@ func.func @vector_maskedload_i8(%arg1: index, %arg2: index, %arg3: index, %passt
 // CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
 // CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<4xi8> to vector<1xi32>
 // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
-// CHECK32-SAME: memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32-SAME: memref<3xi32>, vector<1xi32>
 // CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<4xi8>
 // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<4xi1>, vector<4xi8>
 // CHECK32: return %[[SELECT]]
@@ -164,7 +164,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
   %cst = arith.constant dense<0> : vector<3x8xi4>
   %mask = vector.create_mask %arg3 : vector<8xi1>
   %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
-    memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+    memref<3x8xi4>, vector<8xi4>
   %2 = vector.insert %1, %cst [0] : vector<8xi4> into vector<3x8xi4>
   return %2 : vector<3x8xi4>
 }
@@ -180,7 +180,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
 // CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
 // CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<8xi4> to vector<4xi8>
 // CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
-// CHECK-SAME: memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK-SAME: memref<12xi8>, vector<4xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
 // CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<8xi1>, vector<8xi4>

@@ -196,7 +196,7 @@ func.func @vector_maskedload_i4(%arg1: index, %arg2: index, %arg3: index, %passt
 // CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
 // CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<8xi4> to vector<1xi32>
 // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
-// CHECK32-SAME: memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32-SAME: memref<3xi32>, vector<1xi32>
 // CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
 // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<8xi1>, vector<8xi4>

@@ -206,7 +206,7 @@ func.func @vector_maskedload_i8_constant_mask(%arg1: index, %arg2: index, %passt
   %0 = memref.alloc() : memref<3x4xi8>
   %mask = vector.constant_mask [2] : vector<4xi1>
   %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
-    memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+    memref<3x4xi8>, vector<4xi8>
   return %1 : vector<4xi8>
 }
 // Expect no conversions, i8 is supported.
@@ -216,7 +216,7 @@ func.func @vector_maskedload_i8_constant_mask(%arg1: index, %arg2: index, %passt
 // CHECK-NEXT: %[[ALLOC:.+]] = memref.alloc() : memref<3x4xi8>
 // CHECK-NEXT: %[[MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
 // CHECK-NEXT: [[L:%.+]] = vector.maskedload %[[ALLOC]][%[[ARG0]], %[[ARG1]]], %[[MASK]], %[[ARG2]] :
-// CHECK-SAME: memref<3x4xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK-SAME: memref<3x4xi8>, vector<4xi8>
 // CHECK-NEXT: return

 // CHECK32-DAG: #[[LOAD_IDX_MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 4)>
@@ -229,7 +229,7 @@ func.func @vector_maskedload_i8_constant_mask(%arg1: index, %arg2: index, %passt
 // CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
 // CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG3]] : vector<4xi8> to vector<1xi32>
 // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
-// CHECK32-SAME: memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32-SAME: memref<3xi32>, vector<1xi32>
 // CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<4xi8>
 // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG3]] : vector<4xi1>, vector<4xi8>
 // CHECK32: return %[[SELECT]]
@@ -241,7 +241,7 @@ func.func @vector_maskedload_i4_constant_mask(%arg1: index, %arg2: index, %passt
   %cst = arith.constant dense<0> : vector<3x8xi4>
   %mask = vector.constant_mask [4] : vector<8xi1>
   %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
-    memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+    memref<3x8xi4>, vector<8xi4>
   %2 = vector.insert %1, %cst [0] : vector<8xi4> into vector<3x8xi4>
   return %2 : vector<3x8xi4>
 }
@@ -255,7 +255,7 @@ func.func @vector_maskedload_i4_constant_mask(%arg1: index, %arg2: index, %passt
 // CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
 // CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG2]] : vector<8xi4> to vector<4xi8>
 // CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
-// CHECK-SAME: memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK-SAME: memref<12xi8>, vector<4xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
 // CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>

@@ -269,7 +269,7 @@ func.func @vector_maskedload_i4_constant_mask(%arg1: index, %arg2: index, %passt
 // CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
 // CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[ARG2]] : vector<8xi4> to vector<1xi32>
 // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LD_IDX]]], %[[NEW_MASK]], %[[NEW_PASSTHRU]] :
-// CHECK32-SAME: memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32-SAME: memref<3xi32>, vector<1xi32>
 // CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
 // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[BITCAST]], %[[ARG2]] : vector<8xi1>, vector<8xi4>

@@ -281,7 +281,7 @@ func.func @vector_maskedload_i4_arith_constant(%passthru: vector<8xi4>) -> vecto
   %mask = arith.constant dense<[false, true, true, true, true, false, false, false]> : vector<8xi1>
   %c0 = arith.constant 0 : index
   %1 = vector.maskedload %0[%c0, %c0], %mask, %passthru :
-    memref<3x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
+    memref<3x8xi4>, vector<8xi4>
   return %1 : vector<8xi4>
 }

@@ -313,7 +313,7 @@ func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
   %27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1>
   %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
   %49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
-  %50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
+  %50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi4>
   %63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
   return %63 : vector<8x8x16xi4>
 }
@@ -328,7 +328,7 @@ func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
 // CHECK: %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<8xi1>
 // CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<8xi8>
 // CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
-// CHECK-SAME: memref<512xi8>, vector<8xi1>, vector<8xi8> into vector<8xi8>
+// CHECK-SAME: memref<512xi8>, vector<8xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<8xi8> to vector<16xi4>
 // CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>

@@ -343,7 +343,7 @@ func.func @vector_extract_maskedload_i4(%arg1: index) -> vector<8x8x16xi4> {
 // CHECK32: %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<2xi1>
 // CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<2xi32>
 // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
-// CHECK32-SAME: memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
+// CHECK32-SAME: memref<128xi32>, vector<2xi32>
 // CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<16xi4>
 // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>

@@ -357,7 +357,7 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
   %27 = vector.constant_mask [8, 4, 16] : vector<8x8x16xi1>
   %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
   %49 = vector.extract %48[0] : vector<16xi1> from vector<8x16xi1>
-  %50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi1>, vector<16xi4> into vector<16xi4>
+  %50 = vector.maskedload %0[%c0, %c0, %c0], %49, %cst_2 : memref<8x8x16xi4>, vector<16xi4>
   %63 = vector.insert %50, %cst_1 [0, 0] : vector<16xi4> into vector<8x8x16xi4>
   return %63 : vector<8x8x16xi4>
 }
@@ -372,7 +372,7 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
 // CHECK: %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<8xi1>
 // CHECK: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<8xi8>
 // CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
-// CHECK-SAME: memref<512xi8>, vector<8xi1>, vector<8xi8> into vector<8xi8>
+// CHECK-SAME: memref<512xi8>, vector<8xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<8xi8> to vector<16xi4>
 // CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>

@@ -387,7 +387,7 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
 // CHECK32: %[[NEW_EXT2:.+]] = vector.extract %[[NEW_EXT1]][0] : vector<2xi1>
 // CHECK32: %[[NEW_PASSTHRU:.+]] = vector.bitcast %[[PASSTHRU]] : vector<16xi4> to vector<2xi32>
 // CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%c0], %[[NEW_EXT2]], %[[NEW_PASSTHRU]] :
-// CHECK32-SAME: memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
+// CHECK32-SAME: memref<128xi32>, vector<2xi32>
 // CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<16xi4>
 // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>

@@ -482,7 +482,7 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
 func.func @vector_maskedstore_i8(%arg0: index, %arg1: index, %arg2: index, %value: vector<8xi8>) {
   %0 = memref.alloc() : memref<3x8xi8>
   %mask = vector.create_mask %arg2 : vector<8xi1>
-  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi8>
   return
 }
 // Expect no conversions, i8 is supported.
@@ -526,7 +526,7 @@ func.func @vector_maskedstore_i4(
   %0 = memref.alloc() : memref<3x8xi4>
   %mask = vector.create_mask %num_elements_to_store : vector<8xi1>
   vector.maskedstore %0[%idx1, %idx2], %mask, %value :
-    memref<3x8xi4>, vector<8xi1>, vector<8xi4>
+    memref<3x8xi4>, vector<8xi4>
   return
 }
 // CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
@@ -543,11 +543,11 @@ func.func @vector_maskedstore_i4(
 // CHECK: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_11]]()[%[[NUM_EL_TO_STORE]]]
 // CHECK: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<4xi1>
 // CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
-// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
 // CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
 // CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
-// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi8>

 // CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
 // CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
@@ -563,18 +563,18 @@ func.func @vector_maskedstore_i4(
 // CHECK32: %[[MASK_IDX:.+]] = affine.apply #[[$ATTR_18]]()[%[[NUM_EL_TO_STORE]]]
 // CHECK32: %[[NEW_MASK:.+]] = vector.create_mask %[[MASK_IDX]] : vector<1xi1>
 // CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
-// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi32>
 // CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
 // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
 // CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
-// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
+// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi32>

 // -----

 func.func @vector_maskedstore_i8_constant_mask(%arg0: index, %arg1: index, %value: vector<8xi8>) {
   %0 = memref.alloc() : memref<3x8xi8>
   %mask = vector.constant_mask [4] : vector<8xi1>
-  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi1>, vector<8xi8>
+  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<8xi8>
   return
 }
 // Expect no conversions, i8 is supported.
@@ -613,7 +613,7 @@ func.func @vector_maskedstore_i4_constant_mask(
   %0 = memref.alloc() : memref<3x8xi4>
   %mask = vector.constant_mask [4] : vector<8xi1>
   vector.maskedstore %0[%idx_1, %idx_2], %mask, %val_to_store :
-    memref<3x8xi4>, vector<8xi1>, vector<8xi4>
+    memref<3x8xi4>, vector<8xi4>
   return
 }

@@ -627,11 +627,11 @@ func.func @vector_maskedstore_i4_constant_mask(
 // CHECK: %[[LIDX:.+]] = affine.apply #[[$ATTR_12]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK: %[[NEW_MASK:.+]] = vector.constant_mask [2] : vector<4xi1>
 // CHECK: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<4xi8>
-// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi1>, vector<4xi8> into vector<4xi8>
+// CHECK: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<12xi8>, vector<4xi8>
 // CHECK: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<4xi8> to vector<8xi4>
 // CHECK: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
 // CHECK: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<4xi8>
-// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
+// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi8>

 // CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
 // CHECK32-LABEL: func.func @vector_maskedstore_i4_constant_mask(
@@ -643,11 +643,11 @@ func.func @vector_maskedstore_i4_constant_mask(
 // CHECK32: %[[LIDX:.+]] = affine.apply #[[$ATTR_20]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK32: %[[NEW_MASK:.+]] = vector.constant_mask [1] : vector<1xi1>
 // CHECK32: %[[PASS_THRU:.+]] = arith.constant dense<0> : vector<1xi32>
-// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi1>, vector<1xi32> into vector<1xi32>
+// CHECK32: %[[LOAD:.+]] = vector.maskedload %[[ALLOC]][%[[LIDX]]], %[[NEW_MASK]], %[[PASS_THRU]] : memref<3xi32>, vector<1xi32>
 // CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi32> to vector<8xi4>
 // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_MASK]], %[[VAL_TO_STORE]], %[[BITCAST]] : vector<8xi1>, vector<8xi4>
 // CHECK32: %[[NEW_VAL:.+]] = vector.bitcast %[[SELECT]] : vector<8xi4> to vector<1xi32>
-// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi1>, vector<1xi32>
+// CHECK32: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<3xi32>, vector<1xi32>

 // -----

@@ -658,7 +658,7 @@ func.func @vector_maskedstore_i4_arith_constant(%val_to_store: vector<8xi4>) {
   %c0 = arith.constant 0 : index
   %c3 = arith.constant 3 : index
   vector.maskedstore %0[%c3, %c0], %mask, %val_to_store :
-    memref<5x8xi4>, vector<8xi1>, vector<8xi4>
+    memref<5x8xi4>, vector<8xi4>
   return
 }

diff --git a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
index 8cb25c7578495..7a55ff3664433 100644
--- a/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-mem-transforms.mlir
@@ -10,7 +10,7 @@ func.func @maskedload0(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vect
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
   %ld = vector.maskedload %base[%c0], %mask, %pass_thru
-    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+    : memref<?xf32>, vector<16xf32>
   return %ld : vector<16xf32>
 }

@@ -24,7 +24,7 @@ func.func @maskedload1(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vec
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
   %ld = vector.maskedload %base[%c0], %mask, %pass_thru
-    : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+    : memref<16xf32>, vector<16xf32>
   return %ld : vector<16xf32>
 }

@@ -36,7 +36,7 @@ func.func @maskedload2(%base: memref<16xf32>, %pass_thru: vector<16xf32>) -> vec
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [0] : vector<16xi1>
   %ld = vector.maskedload %base[%c0], %mask, %pass_thru
-    : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+    : memref<16xf32>, vector<16xf32>
   return %ld : vector<16xf32>
 }

@@ -50,7 +50,7 @@ func.func @maskedload3(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vect
   %c8 = arith.constant 8 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
   %ld = vector.maskedload %base[%c8], %mask, %pass_thru
-    : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+    : memref<?xf32>, vector<16xf32>
   return %ld : vector<16xf32>
 }

@@ -63,7 +63,7 @@ func.func @maskedload3(%base: memref<?xf32>, %pass_thru: vector<16xf32>) -> vect
 func.func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [16] : vector<16xi1>
-  vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
+  vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xf32>
   return
 }

@@ -74,7 +74,7 @@ func.func @maskedstore1(%base: memref<16xf32>, %value: vector<16xf32>) {
 func.func @maskedstore2(%base: memref<16xf32>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
   %mask = vector.constant_mask [0] : vector<16xi1>
-  vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xi1>, vector<16xf32>
+  vector.maskedstore %base[%c0], %mask, %value : memref<16xf32>, vector<16xf32>
   return
 }