Commit 2de936b

[mlir][vector] Fix emulation of "narrow" type vector.store (#133231)
Below are two examples of "narrow" `vector.store`s. The first example does not require partial stores and hence no RMW stores. This is currently emulated correctly.

```mlir
func.func @example_1(%arg0: vector<4xi2>) {
  %0 = memref.alloc() : memref<13xi2>
  %c4 = arith.constant 4 : index
  vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
  return
}
```

The second example requires a partial (and hence RMW) store because the offset (`%c3`) points outside the emulated type boundary.

```mlir
func.func @example_2(%arg0: vector<4xi2>) {
  %0 = memref.alloc() : memref<13xi2>
  %c3 = arith.constant 3 : index
  vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
  return
}
```

This is currently emulated incorrectly as a single "full" store (note that the offset is wrong) instead of partial stores:

```mlir
func.func @example_2(%arg0: vector<4xi2>) {
  %alloc = memref.alloc() : memref<4xi8>
  %0 = vector.bitcast %arg0 : vector<4xi2> to vector<1xi8>
  %c0 = arith.constant 0 : index
  vector.store %0, %alloc[%c0] : memref<4xi8>, vector<1xi8>
  return
}
```

The incorrect emulation stems from this simplified (i.e. incomplete) calculation of the front padding:

```cpp
std::optional<int64_t> foldedNumFrontPadElems =
    isDivisibleInSize ? 0
                      : getConstantIntValue(linearizedInfo.intraDataOffset);
```

Since `isDivisibleInSize` is `true` (i8 / i2 = 4):

* the front padding is set to `0` and, as a result,
* the input offset (`%c3`) is ignored, and
* we incorrectly assume that partial stores won't be needed.

Note that in both examples we are storing `vector<4xi2>` into `memref<13xi2>` (note the _different_ trailing dims), and hence partial stores might in fact be required. The condition above is updated to:

```cpp
std::optional<int64_t> foldedNumFrontPadElems =
    (isDivisibleInSize && trailingDimsMatch)
        ? 0
        : getConstantIntValue(linearizedInfo.intraDataOffset);
```

This change ensures that the input offset is properly taken into account, which fixes the issue. It doesn't affect `@example_1`. Additional comments are added to clarify the current logic.
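The condition can be sanity-checked in isolation. Below is a minimal, self-contained C++ sketch (not the pass itself; `elemsPerContainer` and `frontPadElems` are illustrative names) of how the intra-byte offset distinguishes the two examples, assuming an i8 container holding four i2 elements:

```cpp
// Standalone sketch of the front-padding arithmetic described above. The pass
// derives the equivalent quantity from linearizedInfo.intraDataOffset.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t elemsPerContainer = 4; // i8 / i2 = 4

  for (int64_t offset : {4, 3}) { // @example_1 uses %c4, @example_2 uses %c3
    // Offset of the first stored element within its container byte.
    int64_t frontPadElems = offset % elemsPerContainer;
    // A non-zero front padding means the first byte is only partially
    // overwritten, so partial (RMW) stores are required.
    bool requiresPartialStores = frontPadElems != 0;
    std::printf("offset %lld: front padding = %lld -> %s\n",
                static_cast<long long>(offset),
                static_cast<long long>(frontPadElems),
                requiresPartialStores ? "partial (RMW) stores"
                                      : "single full store");
  }
  return 0;
}
```

For `@example_1` (offset 4) the front padding is 0 and a single full-byte store suffices; for `@example_2` (offset 3) it is 3, so the store must be emulated with partial RMW stores.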
1 parent e78b763 · commit 2de936b

3 files changed (+115, -9 lines)

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+42, -9)
```diff
@@ -593,10 +593,19 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto origElements = valueToStore.getType().getNumElements();
     // Note, per-element-alignment was already verified above.
     bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
+    // Do the trailing dim for source and destination match? If yes, then the
+    // corresponding index must be 0.
+    // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
+    // However, that makes some tests fail, so we need to audit first.
+    auto trailingDim = op.getBase().getType().getShape().back();
+    bool trailingDimsMatch =
+        ShapedType::isDynamic(trailingDim) || trailingDim == origElements;

     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());

+    // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
+    // non-zero. As such, this is not needed.
     OpFoldResult linearizedIndices;
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
@@ -608,8 +617,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
         getAsOpFoldResult(adaptor.getIndices()));

     std::optional<int64_t> foldedNumFrontPadElems =
-        isDivisibleInSize ? 0
-                          : getConstantIntValue(linearizedInfo.intraDataOffset);
+        (isDivisibleInSize && trailingDimsMatch)
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);

     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
@@ -619,15 +629,38 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

     auto memrefBase = cast<MemRefValue>(adaptor.getBase());

-    // Conditions when atomic RMWs are not needed:
+    // RMWs are not needed when:
+    // * no _partial_ stores are required.
+    //   A partial store is defined as a store in which only a part of the
+    //   container element is overwritten, e.g.
+    //
+    //   Dest before (8 bits)
+    //     +----------+
+    //     | 11000000 |
+    //     +----------+
+    //
+    //   Dest after storing 0xF at offset 4 (in bits)
+    //     +----------+
+    //     | 11001111 |
+    //     +----------+
+    //
+    // At a higher level, this translates to:
     // 1. The source vector size (in bits) is a multiple of byte size.
-    // 2. The address of the store is aligned to the emulated width boundary.
+    // 2. The address of the store is aligned to the container type width
+    //    boundary.
+    //
+    // EXAMPLE 1:
+    //   Requires partial store:
+    //     vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
     //
-    // For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
-    // need unaligned emulation because the store address is aligned and the
-    // source is a whole byte.
-    bool emulationRequiresPartialStores =
-        !isDivisibleInSize || *foldedNumFrontPadElems != 0;
+    // EXAMPLE 2:
+    //   Does not require a partial store:
+    //     vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
+    //
+    // TODO: Take linearizedInfo.linearizedOffset into account. This is
+    // currently not needed/used/exercised as all our tests set offset to 0.
+    bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
+
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
       auto numElements = origElements / emulatedPerContainerElem;
```
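As a companion to the comments added above, here is a hedged sketch (an illustrative helper, not code from this file) of how a narrow store is split across container bytes when the front padding is non-zero. For `vector<4xi2>` stored at offset 3, the split is {1, 3}, which matches the two atomic RMWs in the test added below.

```cpp
// Illustrative helper (not from VectorEmulateNarrowType.cpp): split a narrow
// store into per-container (per-byte) chunks given a non-byte-aligned offset.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int64_t> splitIntoContainerChunks(int64_t offset, int64_t numElems,
                                              int64_t elemsPerContainer) {
  std::vector<int64_t> chunks;
  int64_t frontPad = offset % elemsPerContainer;
  int64_t remaining = numElems;
  // The first chunk only fills the tail of the partially occupied byte.
  if (frontPad != 0) {
    int64_t first = std::min(elemsPerContainer - frontPad, remaining);
    chunks.push_back(first);
    remaining -= first;
  }
  // The remaining elements are written byte by byte.
  while (remaining > 0) {
    int64_t next = std::min(elemsPerContainer, remaining);
    chunks.push_back(next);
    remaining -= next;
  }
  return chunks;
}

int main() {
  // vector<4xi2> stored into memref<13xi2> at offset 3 (an i8 container holds
  // four i2 elements): prints "1 3".
  for (int64_t chunk : splitIntoContainerChunks(/*offset=*/3, /*numElems=*/4,
                                                /*elemsPerContainer=*/4))
    std::printf("%lld ", static_cast<long long>(chunk));
  std::printf("\n");
  return 0;
}
```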

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+68)
```diff
@@ -361,6 +361,74 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
 /// vector.store
 ///----------------------------------------------------------------------------------------

+// -----
+
+// Most basic example to demonstrate where partial stores are not needed.
+
+func.func @vector_store_i2_const_index_no_partial_store(%arg0: vector<4xi2>) {
+  %0 = memref.alloc() : memref<13xi2>
+  %c4 = arith.constant 4 : index
+  vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
+  return
+}
+// CHECK-LABEL: func.func @vector_store_i2_const_index_no_partial_store(
+// CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
+// CHECK-NOT: memref.generic_atomic_rmw
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<4xi8>
+// CHECK: %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: vector.store %[[UPCAST]], %[[ALLOC]]{{\[}}%[[C1]]] : memref<4xi8>, vector<1xi8>
+
+// -----
+
+// Small modification of the example above to demonstrate where partial stores
+// are needed.
+
+func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<4xi2>) {
+  %0 = memref.alloc() : memref<13xi2>
+  %c3 = arith.constant 3 : index
+  vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
+  return
+}
+
+// CHECK-LABEL: func.func @vector_store_i2_const_index_two_partial_stores(
+// CHECK-SAME: %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<4xi8>
+
+// First atomic RMW:
+// CHECK: %[[IDX_1:.*]] = arith.constant 0 : index
+// CHECK: %[[MASK_1:.*]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[INIT:.*]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[SLICE_1:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xi2> to vector<1xi2>
+// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[SLICE_1]], %[[INIT]] {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_1]]] : memref<4xi8> {
+// CHECK: ^bb0(%[[VAL_8:.*]]: i8):
+// CHECK: %[[VAL_9:.*]] = vector.from_elements %[[VAL_8]] : vector<1xi8>
+// CHECK: %[[DOWNCAST_1:.*]] = vector.bitcast %[[VAL_9]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT_1:.*]] = arith.select %[[MASK_1]], %[[V1]], %[[DOWNCAST_1]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[UPCAST_1:.*]] = vector.bitcast %[[SELECT_1]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[RES_1:.*]] = vector.extract %[[UPCAST_1]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[RES_1]] : i8
+// CHECK: }
+
+// Second atomic RMW:
+// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
+// CHECK: %[[IDX_2:.*]] = arith.addi %[[IDX_1]], %[[VAL_14]] : index
+// CHECK: %[[VAL_16:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
+// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[INIT]] {offsets = [0], strides = [1]} : vector<3xi2> into vector<4xi2>
+// CHECK: %[[MASK_2:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
+// CHECK: memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_2]]] : memref<4xi8> {
+// CHECK: ^bb0(%[[VAL_20:.*]]: i8):
+// CHECK: %[[VAL_21:.*]] = vector.from_elements %[[VAL_20]] : vector<1xi8>
+// CHECK: %[[DONWCAST_2:.*]] = vector.bitcast %[[VAL_21]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT_2:.*]] = arith.select %[[MASK_2]], %[[V2]], %[[DONWCAST_2]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[UPCAST_2:.*]] = vector.bitcast %[[SELECT_2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[RES_2:.*]] = vector.extract %[[UPCAST_2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[RES_2]] : i8
+// CHECK: }
+
+// -----
+
 func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
   %src = memref.alloc() : memref<3x3xi2>
   %c0 = arith.constant 0 : index
```

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir (+5)
```diff
@@ -439,6 +439,11 @@ func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {

 // -----

+// FIXME: This example assumes that the store happens at a byte boundary, but
+// that's not guaranteed. Below is a counter-example with specific dimensions:
+//    vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>
+// TODO: Revisit post #136797
+
 func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
   %0 = memref.alloc(%arg1, %arg2) : memref<?x?xi4>
   vector.store %arg0, %0[%arg3, %arg4] : memref<?x?xi4>, vector<8xi4>
```
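A quick check of the arithmetic behind the FIXME's counter-example (a hedged sketch; the constants simply restate the dimensions quoted above): index `[0, 3]` into `memref<2x13xi4>` linearizes to 3 i4 elements, i.e. 12 bits, which is not a multiple of 8, so the store is not byte-aligned.

```cpp
// Sketch of the byte-alignment check for the counter-example in the FIXME:
//   vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t bitWidth = 4;     // i4 elements
  constexpr int64_t trailingDim = 13; // memref<2x13xi4>
  constexpr int64_t row = 0, col = 3; // index [0, 3]

  int64_t linearElemOffset = row * trailingDim + col; // 3 elements
  int64_t bitOffset = linearElemOffset * bitWidth;    // 12 bits
  std::printf("offset = %lld bits, byte aligned: %s\n",
              static_cast<long long>(bitOffset),
              bitOffset % 8 == 0 ? "yes" : "no");
  return 0;
}
```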
