[mlir][vector] Fix emulation of "narrow" type vector.store #133231

Merged · 2 commits · Apr 24, 2025
51 changes: 42 additions & 9 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -593,10 +593,19 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto origElements = valueToStore.getType().getNumElements();
     // Note, per-element-alignment was already verified above.
     bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
+    // Do the trailing dim for source and destination match? If yes, then the
+    // corresponding index must be 0.
+    // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
+    // However, that makes some tests fail, so we need to audit first.
+    auto trailingDim = op.getBase().getType().getShape().back();
+    bool trailingDimsMatch =
+        ShapedType::isDynamic(trailingDim) || trailingDim == origElements;

     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

+    // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
+    // non-zero. As such, this is not needed.
     OpFoldResult linearizedIndices;
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
@@ -608,8 +617,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             getAsOpFoldResult(adaptor.getIndices()));

     std::optional<int64_t> foldedNumFrontPadElems =
-        isDivisibleInSize ? 0
-                          : getConstantIntValue(linearizedInfo.intraDataOffset);
+        (isDivisibleInSize && trailingDimsMatch)
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);

     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
@@ -619,15 +629,38 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

     auto memrefBase = cast<MemRefValue>(adaptor.getBase());

-    // Conditions when atomic RMWs are not needed:
+    // RMWs are not needed when:
+    //  * no _partial_ stores are required.
+    // A partial store is defined as a store in which only a part of the
+    // container element is overwritten, e.g.
+    //
+    //    Dest before (8 bits)
+    //        +----------+
+    //        | 11000000 |
+    //        +----------+
+    //
+    //    Dest after storing 0xF at offset 4 (in bits)
+    //        +----------+
+    //        | 11001111 |
+    //        +----------+
+    //
+    // At a higher level, this translates to:
     // 1. The source vector size (in bits) is a multiple of byte size.
-    // 2. The address of the store is aligned to the emulated width boundary.
+    // 2. The address of the store is aligned to the container type width
+    //    boundary.
+    //
+    // EXAMPLE 1:
+    //  Requires partial store:
+    //    vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
+    //
-    // For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
-    // need unaligned emulation because the store address is aligned and the
-    // source is a whole byte.
-    bool emulationRequiresPartialStores =
-        !isDivisibleInSize || *foldedNumFrontPadElems != 0;
+    // EXAMPLE 2:
+    //  Does not require a partial store:
+    //    vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
+    //
+    // TODO: Take linearizedInfo.linearizedOffset into account. This is
+    // currently not needed/used/exercised as all our tests set offset to 0.
+    bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;

     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
       auto numElements = origElements / emulatedPerContainerElem;
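
For readers working through the new logic, the decision reduces to modular arithmetic once all shapes and indices are static. The standalone C++ sketch below is illustrative only: the names `perContainer`, `frontPad` and the hard-coded shapes are not from the patch, they simply replay the two examples named in the comment block.

#include <cstdint>
#include <iostream>

int main() {
  // Emulating memref<13xi2> with i8 containers: 8 / 2 = 4 elements per byte.
  const int64_t perContainer = 8 / 2;
  const int64_t numElements = 4;   // vector<4xi2> being stored
  const int64_t trailingDim = 13;  // trailing dim of memref<13xi2>

  for (int64_t index : {3, 4}) {
    // Mirrors linearizedInfo.intraDataOffset for a static 1-D case.
    int64_t frontPad = index % perContainer;
    // The new trailingDimsMatch guard: 13 != 4, so the front padding cannot
    // be folded to 0 even though 4 % 4 == 0.
    bool divisible = numElements % perContainer == 0;
    bool trailingDimsMatch = trailingDim == numElements;
    int64_t foldedFrontPad = (divisible && trailingDimsMatch) ? 0 : frontPad;
    bool requiresPartialStores = foldedFrontPad != 0;
    std::cout << "index " << index << ": "
              << (requiresPartialStores ? "atomic RMWs" : "full-byte store")
              << "\n";
  }
  return 0;
}

This prints "atomic RMWs" for index 3 and "full-byte store" for index 4. Under the previous `isDivisibleInSize ? 0 : ...` fold, the `%c3` case would also have taken the full-byte path, which is the behaviour the new tests below pin down.
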
68 changes: 68 additions & 0 deletions mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -361,6 +361,74 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
/// vector.store
///----------------------------------------------------------------------------------------

// -----

// Most basic example to demonstrate where partial stores are not needed.

func.func @vector_store_i2_const_index_no_partial_store(%arg0: vector<4xi2>) {
%0 = memref.alloc() : memref<13xi2>
%c4 = arith.constant 4 : index
vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
return
}
// CHECK-LABEL: func.func @vector_store_i2_const_index_no_partial_store(
// CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
// CHECK-NOT: memref.generic_atomic_rmw
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi8>
// CHECK: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4xi2> to vector<1xi8>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: vector.store %[[UPCAST]], %[[ALLOC]]{{\[}}%[[C1]]] : memref<4xi8>, vector<1xi8>

// -----

// Small modification of the example above to demonstrate where partial stores
// are needed.

func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<4xi2>) {
%0 = memref.alloc() : memref<13xi2>
%c3 = arith.constant 3 : index
vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
return
}

// CHECK-LABEL: func.func @vector_store_i2_const_index_two_partial_stores(
// CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<4xi8>

// First atomic RMW:
// CHECK: %[[IDX_1:.*]] = arith.constant 0 : index
// CHECK: %[[MASK_1:.*]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
// CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[SLICE_1:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xi2> to vector<1xi2>
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[SLICE_1]], %[[INIT]] {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
// CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_1]]] : memref<4xi8> {
// CHECK: ^bb0(%[[VAL_8:.*]]: i8):
// CHECK: %[[VAL_9:.*]] = vector.from_elements %[[VAL_8]] : vector<1xi8>
// CHECK: %[[DOWNCAST_1:.*]] = vector.bitcast %[[VAL_9]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT_1:.*]] = arith.select %[[MASK_1]], %[[V1]], %[[DOWNCAST_1]] : vector<4xi1>, vector<4xi2>
// CHECK: %[[UPCAST_1:.*]] = vector.bitcast %[[SELECT_1]] : vector<4xi2> to vector<1xi8>
// CHECK: %[[RES_1:.*]] = vector.extract %[[UPCAST_1]][0] : i8 from vector<1xi8>
// CHECK: memref.atomic_yield %[[RES_1]] : i8
// CHECK: }

// Second atomic RMW:
// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
// CHECK: %[[IDX_2:.*]] = arith.addi %[[IDX_1]], %[[VAL_14]] : index
// CHECK: %[[VAL_16:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[INIT]] {offsets = [0], strides = [1]} : vector<3xi2> into vector<4xi2>
// CHECK: %[[MASK_2:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
// CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_2]]] : memref<4xi8> {
// CHECK: ^bb0(%[[VAL_20:.*]]: i8):
// CHECK: %[[VAL_21:.*]] = vector.from_elements %[[VAL_20]] : vector<1xi8>
// CHECK: %[[DOWNCAST_2:.*]] = vector.bitcast %[[VAL_21]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT_2:.*]] = arith.select %[[MASK_2]], %[[V2]], %[[DOWNCAST_2]] : vector<4xi1>, vector<4xi2>
// CHECK: %[[UPCAST_2:.*]] = vector.bitcast %[[SELECT_2]] : vector<4xi2> to vector<1xi8>
// CHECK: %[[RES_2:.*]] = vector.extract %[[UPCAST_2]][0] : i8 from vector<1xi8>
// CHECK: memref.atomic_yield %[[RES_2]] : i8
// CHECK: }

// -----

func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
%src = memref.alloc() : memref<3x3xi2>
%c0 = arith.constant 0 : index
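
The `memref.generic_atomic_rmw` regions checked in the new tests above implement a masked read-modify-write: the 2-bit lanes covered by the mask take the incoming data, all other lanes keep whatever is already in memory, and the update is atomic so concurrent writers to neighbouring lanes in the same byte are not clobbered. A minimal C++ model of one such update follows; `maskedRmw` and the byte values are made up for illustration, and the exact lane-to-bit mapping is defined by `vector.bitcast`, not by this sketch.

#include <cstdint>
#include <cstdio>

// One masked read-modify-write on a single i8 container byte: bits selected
// by laneMask come from newBits, the remaining bits are preserved.
uint8_t maskedRmw(uint8_t oldByte, uint8_t newBits, uint8_t laneMask) {
  return static_cast<uint8_t>((oldByte & ~laneMask) | (newBits & laneMask));
}

int main() {
  // Overwrite a single 2-bit lane (mask covers bits 0-1 here).
  uint8_t byte0 = maskedRmw(/*oldByte=*/0xD2, /*newBits=*/0x03,
                            /*laneMask=*/0x03);
  // Overwrite three 2-bit lanes (mask covers bits 2-7 here).
  uint8_t byte1 = maskedRmw(/*oldByte=*/0xD2, /*newBits=*/0xA8,
                            /*laneMask=*/0xFC);
  std::printf("byte0 = 0x%02X, byte1 = 0x%02X\n", byte0, byte1);
  return 0;
}
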
5 changes: 5 additions & 0 deletions mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -439,6 +439,11 @@ func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {

// -----

// FIXME: This example assumes that the store happens at a byte boundary, but
// that's not guaranteed. Below is a counter-example with specific dimensions:
// vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>

Review comment (Member): side note: I am trying to add a pass to "linearize" all memrefs: #136797. I think there is already a linearizer for vectors. So in the future we only need to deal with 1-d memrefs.

// TODO: Revisit post #136797

func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
%0 = memref.alloc(%arg1, %arg2) : memref<?x?xi4>
vector.store %arg0, %0[%arg3, %arg4] : memref<?x?xi4>, vector<8xi4>
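
Spelling out the counter-example from the FIXME above: for `vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>` the row-major linearized element index is 0 * 13 + 3 = 3, and with two i4 elements per byte that leaves a non-zero intra-byte offset, so the store is not byte-aligned even though the stored vector itself spans whole bytes. The short C++ sketch below is only a restatement of that arithmetic, not code from the patch.

#include <cstdint>
#include <iostream>

int main() {
  // Counter-example from the FIXME: memref<2x13xi4>, store of vector<8xi4>
  // at indices [0, 3].
  const int64_t elemsPerByte = 8 / 4;      // two i4 elements per i8 byte
  const int64_t linearIndex = 0 * 13 + 3;  // row-major linearization
  const int64_t intraByteOffset = linearIndex % elemsPerByte;  // == 1
  std::cout << "intra-byte offset = " << intraByteOffset << "\n";
  // A non-zero intra-byte offset means the first container byte would need a
  // read-modify-write, so a lowering that assumes a byte-aligned start would
  // be wrong for such an index.
  return 0;
}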