diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 5debebd3218ed..4fc8fce27ce21 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -593,10 +593,19 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto origElements = valueToStore.getType().getNumElements();
     // Note, per-element-alignment was already verified above.
     bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
+    // Do the trailing dims of the source and destination match? If yes, then
+    // the corresponding index must be 0.
+    // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
+    // However, that makes some tests fail, so we need to audit first.
+    auto trailingDim = op.getBase().getType().getShape().back();
+    bool trailingDimsMatch =
+        ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
+    // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
+    // non-zero. As such, this is not needed.
     OpFoldResult linearizedIndices;
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
@@ -608,8 +617,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedNumFrontPadElems =
-        isDivisibleInSize ? 0
-                          : getConstantIntValue(linearizedInfo.intraDataOffset);
+        (isDivisibleInSize && trailingDimsMatch)
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
@@ -619,15 +629,38 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto memrefBase = cast<MemRefValue>(adaptor.getBase());
 
-    // Conditions when atomic RMWs are not needed:
+    // RMWs are not needed when:
+    // * no _partial_ stores are required.
+    //   A partial store is defined as a store in which only a part of the
+    //   container element is overwritten, e.g.
+    //
+    //      Dest before (8 bits)
+    //        +----------+
+    //        | 11000000 |
+    //        +----------+
+    //
+    //      Dest after storing 0xF at offset 4 (in bits)
+    //        +----------+
+    //        | 11001111 |
+    //        +----------+
+    //
+    // At a higher level, this translates to:
     // 1. The source vector size (in bits) is a multiple of byte size.
-    // 2. The address of the store is aligned to the emulated width boundary.
+    // 2. The address of the store is aligned to the container type width
+    //    boundary.
+    //
+    //    EXAMPLE 1:
+    //      Requires a partial store:
+    //      vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
     //
-    // For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
-    // need unaligned emulation because the store address is aligned and the
-    // source is a whole byte.
-    bool emulationRequiresPartialStores =
-        !isDivisibleInSize || *foldedNumFrontPadElems != 0;
+    //    EXAMPLE 2:
+    //      Does not require a partial store:
+    //      vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
+    //
+    // TODO: Take linearizedInfo.linearizedOffset into account. This is
+    //       currently not needed/used/exercised as all our tests set offset to 0.
+    bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
+
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
       auto numElements = origElements / emulatedPerContainerElem;
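
For reference, the arithmetic behind EXAMPLE 1 and EXAMPLE 2 in the comment above can
be sketched in plain C++. This snippet is not part of the patch; the constants and
names (emulatedPerContainerElem, frontPadElems) simply mirror the
memref<13xi2>/vector<4xi2> examples and the pattern's front-padding check:

// Standalone sketch: with 2-bit elements packed four-per-byte, a store of
// vector<4xi2> needs read-modify-write only when it starts mid-byte.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t emulatedPerContainerElem = 8 / 2; // i2 elements per i8 byte
  for (int64_t index : {3, 4}) {
    // Number of i2 elements preceding `index` within its containing byte,
    // i.e. the "front padding" that a partial store must preserve.
    int64_t frontPadElems = index % emulatedPerContainerElem;
    bool requiresPartialStores = frontPadElems != 0;
    std::printf("index %lld: front pad = %lld -> %s\n", (long long)index,
                (long long)frontPadElems,
                requiresPartialStores ? "partial stores (RMW)"
                                      : "full-byte store");
  }
  return 0;
}
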
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 6fc974200c6f3..21f073efc49b2 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -361,6 +361,74 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>
 /// vector.store
 ///----------------------------------------------------------------------------------------
 
+// -----
+
+// Most basic example to demonstrate where partial stores are not needed.
+
+func.func @vector_store_i2_const_index_no_partial_store(%arg0: vector<4xi2>) {
+  %0 = memref.alloc() : memref<13xi2>
+  %c4 = arith.constant 4 : index
+  vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
+  return
+}
+// CHECK-LABEL: func.func @vector_store_i2_const_index_no_partial_store(
+// CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
+// CHECK-NOT: memref.generic_atomic_rmw
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: vector.store %[[UPCAST]], %[[ALLOC]]{{\[}}%[[C1]]] : memref<4xi8>, vector<1xi8>
+
+// -----
+
+// Small modification of the example above to demonstrate where partial stores
+// are needed.
+
+func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<4xi2>) {
+  %0 = memref.alloc() : memref<13xi2>
+  %c3 = arith.constant 3 : index
+  vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
+  return
+}
+
+// CHECK-LABEL: func.func @vector_store_i2_const_index_two_partial_stores(
+// CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<4xi8>
+
+// First atomic RMW:
+// CHECK: %[[IDX_1:.*]] = arith.constant 0 : index
+// CHECK: %[[MASK_1:.*]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[SLICE_1:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xi2> to vector<1xi2>
+// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[SLICE_1]], %[[INIT]] {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_1]]] : memref<4xi8> {
+// CHECK: ^bb0(%[[VAL_8:.*]]: i8):
+// CHECK:   %[[VAL_9:.*]] = vector.from_elements %[[VAL_8]] : vector<1xi8>
+// CHECK:   %[[DOWNCAST_1:.*]] = vector.bitcast %[[VAL_9]] : vector<1xi8> to vector<4xi2>
+// CHECK:   %[[SELECT_1:.*]] = arith.select %[[MASK_1]], %[[V1]], %[[DOWNCAST_1]] : vector<4xi1>, vector<4xi2>
+// CHECK:   %[[UPCAST_1:.*]] = vector.bitcast %[[SELECT_1]] : vector<4xi2> to vector<1xi8>
+// CHECK:   %[[RES_1:.*]] = vector.extract %[[UPCAST_1]][0] : i8 from vector<1xi8>
+// CHECK:   memref.atomic_yield %[[RES_1]] : i8
+// CHECK: }
+
+// Second atomic RMW:
+// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK: %[[IDX_2:.*]] = arith.addi %[[IDX_1]], %[[VAL_14]] : index
+// CHECK: %[[VAL_16:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
+// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[INIT]] {offsets = [0], strides = [1]} : vector<3xi2> into vector<4xi2>
+// CHECK: %[[MASK_2:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_2]]] : memref<4xi8> {
+// CHECK: ^bb0(%[[VAL_20:.*]]: i8):
+// CHECK:   %[[VAL_21:.*]] = vector.from_elements %[[VAL_20]] : vector<1xi8>
+// CHECK:   %[[DONWCAST_2:.*]] = vector.bitcast %[[VAL_21]] : vector<1xi8> to vector<4xi2>
+// CHECK:   %[[SELECT_2:.*]] = arith.select %[[MASK_2]], %[[V2]], %[[DONWCAST_2]] : vector<4xi1>, vector<4xi2>
+// CHECK:   %[[UPCAST_2:.*]] = vector.bitcast %[[SELECT_2]] : vector<4xi2> to vector<1xi8>
+// CHECK:   %[[RES_2:.*]] = vector.extract %[[UPCAST_2]][0] : i8 from vector<1xi8>
+// CHECK:   memref.atomic_yield %[[RES_2]] : i8
+// CHECK: }
+
+// -----
+
 func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
   %src = memref.alloc() : memref<3x3xi2>
   %c0 = arith.constant 0 : index
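
The two constant masks in the CHECK lines above ([false, false, false, true] and
[true, true, true, false]) follow from splitting the 4-element source across the two
bytes it straddles. Below is a standalone C++ sketch of that split; it is not part of
the patch, and the constants (numSrcElems, elemsPerByte, startIdx) simply mirror the
vector<4xi2>-at-index-3 test:

#include <array>
#include <cstdio>

int main() {
  const int numSrcElems = 4;  // vector<4xi2>
  const int elemsPerByte = 4; // four i2 elements per i8 container
  const int startIdx = 3;     // %c3

  // Elements already occupying the front of byte 0 that must be preserved.
  int frontPad = startIdx % elemsPerByte; // = 3

  // First RMW: only the tail of byte 0 is overwritten.
  std::array<bool, 4> mask1{};
  for (int i = frontPad; i < elemsPerByte; ++i)
    mask1[i] = true; // -> [false, false, false, true], cf. %[[MASK_1]]

  // Second RMW: the remaining source elements spill into the front of byte 1.
  int remaining = numSrcElems - (elemsPerByte - frontPad); // = 3
  std::array<bool, 4> mask2{};
  for (int i = 0; i < remaining; ++i)
    mask2[i] = true; // -> [true, true, true, false], cf. %[[MASK_2]]

  std::printf("mask1 = [%d, %d, %d, %d]\n", mask1[0], mask1[1], mask1[2], mask1[3]);
  std::printf("mask2 = [%d, %d, %d, %d]\n", mask2[0], mask2[1], mask2[2], mask2[3]);
  return 0;
}
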
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 9dc3eb6989c6c..9e2d131f421b7 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -439,6 +439,11 @@ func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {
 
 // -----
 
+// FIXME: This example assumes that the store happens at a byte boundary, but
+// that's not guaranteed. Below is a counter-example with specific dimensions:
+//    vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>
+// TODO: Revisit post #136797
+
 func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
   %0 = memref.alloc(%arg1, %arg2) : memref<?x?xi4>
   vector.store %arg0, %0[%arg3, %arg4] : memref<?x?xi4>, vector<8xi4>
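
To make the FIXME above concrete: for the counter-example
vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>, the linearized start
element is 0 * 13 + 3 = 3, and with two i4 elements per byte that store begins
mid-byte, so assuming byte alignment in the dynamic-shape case is unsound. A
standalone C++ sketch of that check (not part of the patch; names are illustrative):

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t shape[2] = {2, 13};   // memref<2x13xi4>
  const int64_t indices[2] = {0, 3};  // %0[0, 3]
  const int64_t elemsPerByte = 8 / 4; // two i4 elements per i8 container

  // Row-major linearization of the starting element index.
  int64_t linearIdx = indices[0] * shape[1] + indices[1]; // = 3
  int64_t frontPadElems = linearIdx % elemsPerByte;       // = 1 -> mid-byte

  std::printf("linear index = %lld, front pad = %lld (%s)\n",
              (long long)linearIdx, (long long)frontPadElems,
              frontPadElems == 0 ? "byte-aligned" : "not byte-aligned");
  return 0;
}
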