Commit 2ce50fe

Revert "[mlir][vector] Fix emulation of "narrow" type vector.store (llvm#133231)"
This reverts commit 2de936b.
1 parent 4cce326 commit 2ce50fe

3 files changed: +9, -115 lines

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 9 additions & 42 deletions
@@ -593,19 +593,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto origElements = valueToStore.getType().getNumElements();
     // Note, per-element-alignment was already verified above.
     bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0;
-    // Do the trailing dim for source and destination match? If yes, then the
-    // corresponding index must be 0.
-    // FIXME: There's no way to tell for dynamic shapes, so we should bail out.
-    // However, that makes some tests fail, so we need to audit first.
-    auto trailingDim = op.getBase().getType().getShape().back();
-    bool trailingDimsMatch =
-        ShapedType::isDynamic(trailingDim) || trailingDim == origElements;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
-    // FIXME: ATM, we do not test cases where offsets, sizes, or strides are
-    // non-zero. As such, this is not needed.
     OpFoldResult linearizedIndices;
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
@@ -617,9 +608,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedNumFrontPadElems =
-        (isDivisibleInSize && trailingDimsMatch)
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isDivisibleInSize ? 0
+                          : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
@@ -629,38 +619,15 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto memrefBase = cast<MemRefValue>(adaptor.getBase());
 
-    // RMWs are not needed when:
-    // * no _partial_ stores are required.
-    //   A partial store is defined as a store in which only a part of the
-    //   container element is overwritten, e.g.
-    //
-    //    Dest before (8 bits)
-    //      +----------+
-    //      | 11000000 |
-    //      +----------+
-    //
-    //    Dest after storing 0xF at offset 4 (in bits)
-    //      +----------+
-    //      | 11001111 |
-    //      +----------+
-    //
-    // At a higher level, this translats to:
+    // Conditions when atomic RMWs are not needed:
     // 1. The source vector size (in bits) is a multiple of byte size.
-    // 2. The address of the store is aligned to the container type width
-    //    boundary.
-    //
-    // EXAMPLE 1:
-    //    Requires partial store:
-    //    vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
+    // 2. The address of the store is aligned to the emulated width boundary.
     //
-    // EXAMPLE 2:
-    //    Does not require a partial store:
-    //    vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
-    //
-    // TODO: Take linearizedInfo.linearizedOffset into account. This is
-    // currently not needed/used/exercised as all our tests set offset to 0.
-    bool emulationRequiresPartialStores = *foldedNumFrontPadElems != 0;
-
+    // For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
+    // need unaligned emulation because the store address is aligned and the
+    // source is a whole byte.
+    bool emulationRequiresPartialStores =
+        !isDivisibleInSize || *foldedNumFrontPadElems != 0;
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
       auto numElements = origElements / emulatedPerContainerElem;
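
Note on the restored logic: the re-added comment above (and the tests deleted below) contrast a byte-aligned i2 store with one that straddles a byte. A minimal sketch of the two cases, restating those examples; the function names are illustrative and not part of the patch:

func.func @store_i2_byte_aligned(%v: vector<4xi2>) {
  %0 = memref.alloc() : memref<13xi2>
  // Offset 4 x i2 = 8 bits: the store starts on a byte boundary and writes a
  // whole i8, so the emulation can use a plain vector.store of vector<1xi8>.
  %c4 = arith.constant 4 : index
  vector.store %v, %0[%c4] : memref<13xi2>, vector<4xi2>
  return
}

func.func @store_i2_straddling_bytes(%v: vector<4xi2>) {
  %0 = memref.alloc() : memref<13xi2>
  // Offset 3 x i2 = 6 bits: the store begins mid-byte, so the partial bytes
  // must be merged via read-modify-write (memref.generic_atomic_rmw).
  %c3 = arith.constant 3 : index
  vector.store %v, %0[%c3] : memref<13xi2>, vector<4xi2>
  return
}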

mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir

Lines changed: 0 additions & 68 deletions
@@ -361,74 +361,6 @@ func.func @vector_maskedload_i2_constant_mask_unaligned(%passthru: vector<5xi2>)
 /// vector.store
 ///----------------------------------------------------------------------------------------
 
-// -----
-
-// Most basic example to demonstrate where partial stores are not needed.
-
-func.func @vector_store_i2_const_index_no_partial_store(%arg0: vector<4xi2>) {
-  %0 = memref.alloc() : memref<13xi2>
-  %c4 = arith.constant 4 : index
-  vector.store %arg0, %0[%c4] : memref<13xi2>, vector<4xi2>
-  return
-}
-// CHECK-LABEL: func.func @vector_store_i2_const_index_no_partial_store(
-// CHECK-SAME:      %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
-// CHECK-NOT:       memref.generic_atomic_rmw
-// CHECK:           %[[ALLOC:.*]] = memref.alloc() : memref<4xi8>
-// CHECK:           %[[UPCAST:.*]] = vector.bitcast %[[ARG_0]] : vector<4xi2> to vector<1xi8>
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
-// CHECK:           vector.store %[[UPCAST]], %[[ALLOC]]{{\[}}%[[C1]]] : memref<4xi8>, vector<1xi8>
-
-// -----
-
-// Small modification of the example above to demonstrate where partial stores
-// are needed.
-
-func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<4xi2>) {
-  %0 = memref.alloc() : memref<13xi2>
-  %c3 = arith.constant 3 : index
-  vector.store %arg0, %0[%c3] : memref<13xi2>, vector<4xi2>
-  return
-}
-
-// CHECK-LABEL: func.func @vector_store_i2_const_index_two_partial_stores(
-// CHECK-SAME:      %[[ARG_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: vector<4xi2>) {
-// CHECK:           %[[VAL_1:.*]] = memref.alloc() : memref<4xi8>
-
-// First atomic RMW:
-// CHECK:           %[[IDX_1:.*]] = arith.constant 0 : index
-// CHECK:           %[[MASK_1:.*]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
-// CHECK:           %[[INIT:.*]] = arith.constant dense<0> : vector<4xi2>
-// CHECK:           %[[SLICE_1:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xi2> to vector<1xi2>
-// CHECK:           %[[V1:.*]] = vector.insert_strided_slice %[[SLICE_1]], %[[INIT]] {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
-// CHECK:           memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_1]]] : memref<4xi8> {
-// CHECK:           ^bb0(%[[VAL_8:.*]]: i8):
-// CHECK:             %[[VAL_9:.*]] = vector.from_elements %[[VAL_8]] : vector<1xi8>
-// CHECK:             %[[DOWNCAST_1:.*]] = vector.bitcast %[[VAL_9]] : vector<1xi8> to vector<4xi2>
-// CHECK:             %[[SELECT_1:.*]] = arith.select %[[MASK_1]], %[[V1]], %[[DOWNCAST_1]] : vector<4xi1>, vector<4xi2>
-// CHECK:             %[[UPCAST_1:.*]] = vector.bitcast %[[SELECT_1]] : vector<4xi2> to vector<1xi8>
-// CHECK:             %[[RES_1:.*]] = vector.extract %[[UPCAST_1]][0] : i8 from vector<1xi8>
-// CHECK:             memref.atomic_yield %[[RES_1]] : i8
-// CHECK:           }
-
-// Second atomic RMW:
-// CHECK:           %[[VAL_14:.*]] = arith.constant 1 : index
-// CHECK:           %[[IDX_2:.*]] = arith.addi %[[IDX_1]], %[[VAL_14]] : index
-// CHECK:           %[[VAL_16:.*]] = vector.extract_strided_slice %[[ARG_0]] {offsets = [1], sizes = [3], strides = [1]} : vector<4xi2> to vector<3xi2>
-// CHECK:           %[[V2:.*]] = vector.insert_strided_slice %[[VAL_16]], %[[INIT]] {offsets = [0], strides = [1]} : vector<3xi2> into vector<4xi2>
-// CHECK:           %[[MASK_2:.*]] = arith.constant dense<[true, true, true, false]> : vector<4xi1>
-// CHECK:           memref.generic_atomic_rmw %[[VAL_1]]{{\[}}%[[IDX_2]]] : memref<4xi8> {
-// CHECK:           ^bb0(%[[VAL_20:.*]]: i8):
-// CHECK:             %[[VAL_21:.*]] = vector.from_elements %[[VAL_20]] : vector<1xi8>
-// CHECK:             %[[DONWCAST_2:.*]] = vector.bitcast %[[VAL_21]] : vector<1xi8> to vector<4xi2>
-// CHECK:             %[[SELECT_2:.*]] = arith.select %[[MASK_2]], %[[V2]], %[[DONWCAST_2]] : vector<4xi1>, vector<4xi2>
-// CHECK:             %[[UPCAST_2:.*]] = vector.bitcast %[[SELECT_2]] : vector<4xi2> to vector<1xi8>
-// CHECK:             %[[RES_2:.*]] = vector.extract %[[UPCAST_2]][0] : i8 from vector<1xi8>
-// CHECK:             memref.atomic_yield %[[RES_2]] : i8
-// CHECK:           }
-
-// -----
-
 func.func @vector_store_i2_const_index_two_partial_stores(%arg0: vector<3xi2>) {
   %src = memref.alloc() : memref<3x3xi2>
   %c0 = arith.constant 0 : index

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 0 additions & 5 deletions
@@ -439,11 +439,6 @@ func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {
 
 // -----
 
-// FIXME: This example assumes that the store happens at a byte boundary, but
-// that's not guaranteed. Below is a counter-example with specific dimensions:
-//    vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>
-// TODO: Revisit post #136797
-
 func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
   %0 = memref.alloc(%arg1, %arg2) : memref<?x?xi4>
   vector.store %arg0, %0[%arg3, %arg4] : memref<?x?xi4>, vector<8xi4>
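
The removed FIXME above points at a counter-example in which the byte-boundary assumption breaks even for a static shape. Spelled out as a self-contained sketch (the constants and function name are illustrative, not part of the patch):

func.func @store_i4_not_byte_aligned(%arg0: vector<8xi4>) {
  %0 = memref.alloc() : memref<2x13xi4>
  %c0 = arith.constant 0 : index
  %c3 = arith.constant 3 : index
  // Linear element offset 0 * 13 + 3 = 3, i.e. 12 bits from the base: not a
  // byte boundary, so the "store starts at a byte boundary" assumption fails.
  vector.store %arg0, %0[%c0, %c3] : memref<2x13xi4>, vector<8xi4>
  return
}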
