[MLIR] Add a utility pass to linearize memref
#136797
base: main
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
I think we probably need a type converter that also linearizes other memrefs not handled by this transformation (by collapsing their shapes?), so that at the end we would only deal with 1-dimensional memrefs.
Pull Request Overview
This PR introduces a memref linearizer pass to simplify memory access patterns by flattening higher-ranked memrefs into rank-1 memrefs while supporting vector dialect operations.
- Adds a new pass implementation in FlattenMemRefs.cpp for linearizing memrefs.
- Updates AffineOps.cpp to use the newly defined affine::computeProduct function and removes a duplicate implementation.
- Extends header files in Transforms.h and Passes.h to declare the necessary pattern and pass creation routines.
Reviewed Changes
Copilot reviewed 4 out of 7 changed files in this pull request and generated no comments.
File | Description |
---|---|
mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp | Implements the memref linearizer pass with transformations and pattern rewrites. |
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | Refactors and centralizes the computeProduct logic within the affine namespace. |
mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h | Declares population of flatten memref patterns. |
mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h | Declares the creation routine for the flatten memrefs pass. |
Files not reviewed (3)
- mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td: Language not supported
- mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt: Language not supported
- mlir/test/Dialect/MemRef/flatten_memref.mlir: Language not supported
Comments suppressed due to low confidence (1)
mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp:140
- The function name 'needFlattenning' contains a spelling error. Consider renaming it to 'needFlattening' for clarity and consistency.
static bool needFlattenning(Value val) {
@llvm/pr-subscribers-mlir-affine @llvm/pr-subscribers-mlir-memref

Author: Alan Li (lialan)

Changes

To add a transformation that simplifies memory access patterns, this PR adds a memref linearizer, which is based on the GPU/DecomposeMemRefs pass, with the following changes:

Notes:
- For example, memref<4x8xf32, strided<[8, 1], offset: 100>> becomes memref<32xf32, strided<[1], offset: 100>>.
- The collapsed 1-D size is computed as outermostStride * outermostDimSize.
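As a rough illustration of the rewrite (this sketch is not taken from the patch's tests, and the SSA names are made up), a 2-D load on an identity-layout memref becomes a reinterpret_cast plus a linearized index:

func.func @load_2d(%m: memref<4x8xf32>, %i: index, %j: index) -> f32 {
  %v = memref.load %m[%i, %j] : memref<4x8xf32>
  return %v : f32
}

// After flattening (roughly): a 1-D view of the same buffer and an
// affine-computed linear index %i * 8 + %j.
//   %flat = memref.reinterpret_cast %m to offset: [0], sizes: [32], strides: [1]
//       : memref<4x8xf32> to memref<32xf32>
//   %idx = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%i, %j]
//   %v = memref.load %flat[%idx] : memref<32xf32>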
Patch is 31.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136797.diff

7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index d7050156862df..7580985754843 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -77,6 +77,10 @@ std::unique_ptr<Pass> createExpandStridedMetadataPass();
/// components.
std::unique_ptr<Pass> createExpandReallocPass(bool emitDeallocs = true);
+/// Creates an operation pass to flatten multiple dimensional memrefs into
+/// 1-d memrefs.
+std::unique_ptr<Pass> createFlattenMemrefsPass();
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 651ee05ae1f3c..c87472851fd78 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -253,5 +253,17 @@ def ExpandRealloc : Pass<"expand-realloc"> {
];
}
+def FlattenMemrefsPass : Pass<"flatten-memref"> {
+ let summary = "Flatten a multiple dimensional memref to 1-dimensional";
+ let description = [{
+
+ }];
+
+ let constructor = "mlir::memref::createFlattenMemrefsPass()";
+ let dependentDialects = [
+ "affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
+ ];
+}
+
#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 62a2297c80e78..c2b8cb05be922 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -144,6 +144,8 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
/// ```
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
+void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
+
/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
/// given independencies. If the op is already independent of all
/// independencies, the same AllocaOp result is returned.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index aa49c49062c76..43224de5604ed 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -5022,6 +5022,31 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
return ret;
}
+namespace mlir {
+namespace affine {
+OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+ ArrayRef<OpFoldResult> terms) {
+ int64_t nDynamic = 0;
+ SmallVector<Value> dynamicPart;
+ AffineExpr result = builder.getAffineConstantExpr(1);
+ for (OpFoldResult term : terms) {
+ if (!term)
+ return term;
+ std::optional<int64_t> maybeConst = getConstantIntValue(term);
+ if (maybeConst) {
+ result = result * builder.getAffineConstantExpr(*maybeConst);
+ } else {
+ dynamicPart.push_back(cast<Value>(term));
+ result = result * builder.getAffineSymbolExpr(nDynamic++);
+ }
+ }
+ if (auto constant = dyn_cast<AffineConstantExpr>(result))
+ return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+ return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+} // namespace affine
+} // namespace mlir
+
namespace {
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -5081,27 +5106,6 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
}
};
-OpFoldResult computeProduct(Location loc, OpBuilder &builder,
- ArrayRef<OpFoldResult> terms) {
- int64_t nDynamic = 0;
- SmallVector<Value> dynamicPart;
- AffineExpr result = builder.getAffineConstantExpr(1);
- for (OpFoldResult term : terms) {
- if (!term)
- return term;
- std::optional<int64_t> maybeConst = getConstantIntValue(term);
- if (maybeConst) {
- result = result * builder.getAffineConstantExpr(*maybeConst);
- } else {
- dynamicPart.push_back(cast<Value>(term));
- result = result * builder.getAffineSymbolExpr(nDynamic++);
- }
- }
- if (auto constant = dyn_cast<AffineConstantExpr>(result))
- return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
- return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
-}
-
/// If conseceutive outputs of a delinearize_index are linearized with the same
/// bounds, canonicalize away the redundant arithmetic.
///
@@ -5248,7 +5252,7 @@ struct CancelLinearizeOfDelinearizePortion final
// We use the slice from the linearize's basis above because of the
// "bounds inferred from `disjoint`" case above.
OpFoldResult newSize =
- computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
+ affine::computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
// Trivial case where we can just skip past the delinearize all together
if (m.length == m.delinearize.getNumResults()) {
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index ecab97bc2b8e7..48e8bccd369fa 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
EmulateWideInt.cpp
EmulateNarrowType.cpp
ExtractAddressComputations.cpp
+ FlattenMemRefs.cpp
FoldMemRefAliasOps.cpp
IndependenceTransforms.cpp
MultiBuffer.cpp
@@ -23,6 +24,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
LINK_LIBS PUBLIC
MLIRAffineTransforms
+ MLIRAffineDialect
MLIRAffineUtils
MLIRArithDialect
MLIRArithTransforms
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
new file mode 100644
index 0000000000000..5ec524967444a
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -0,0 +1,356 @@
+//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for flattening an multi-rank memref-related
+// ops into 1-d memref ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_FLATTENMEMREFSPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+static void setInsertionPointToStart(OpBuilder &builder, Value val) {
+ if (auto *parentOp = val.getDefiningOp()) {
+ builder.setInsertionPointAfter(parentOp);
+ } else {
+ builder.setInsertionPointToStart(val.getParentBlock());
+ }
+}
+
+static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
+ OpFoldResult>
+getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
+ ArrayRef<OpFoldResult> subOffsets,
+ ArrayRef<OpFoldResult> subStrides = std::nullopt) {
+ auto sourceType = cast<MemRefType>(source.getType());
+ auto sourceRank = static_cast<unsigned>(sourceType.getRank());
+
+ memref::ExtractStridedMetadataOp newExtractStridedMetadata;
+ {
+ OpBuilder::InsertionGuard g(rewriter);
+ setInsertionPointToStart(rewriter, source);
+ newExtractStridedMetadata =
+ rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+ }
+
+ auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
+
+ auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
+ return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
+ : rewriter.getIndexAttr(dim);
+ };
+
+ OpFoldResult origOffset =
+ getDim(sourceOffset, newExtractStridedMetadata.getOffset());
+ ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
+ OpFoldResult outmostDim =
+ getDim(sourceType.getShape().front(),
+ newExtractStridedMetadata.getSizes().front());
+
+ SmallVector<OpFoldResult> origStrides;
+ origStrides.reserve(sourceRank);
+
+ SmallVector<OpFoldResult> strides;
+ strides.reserve(sourceRank);
+
+ AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+ AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+ for (auto i : llvm::seq(0u, sourceRank)) {
+ OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
+
+ if (!subStrides.empty()) {
+ strides.push_back(affine::makeComposedFoldedAffineApply(
+ rewriter, loc, s0 * s1, {subStrides[i], origStride}));
+ }
+
+ origStrides.emplace_back(origStride);
+ }
+
+ // Compute linearized index:
+ auto &&[expr, values] =
+ computeLinearIndex(rewriter.getIndexAttr(0), origStrides, subOffsets);
+ OpFoldResult linearizedIndex =
+ affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
+
+ // Compute collapsed size: (the outmost stride * outmost dimension).
+ SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
+ OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
+
+ return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
+ origStrides, origOffset, collapsedSize};
+}
+
+static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
+ OpFoldResult in) {
+ if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
+ return rewriter.create<arith::ConstantIndexOp>(
+ loc, cast<IntegerAttr>(offsetAttr).getInt());
+ }
+ return cast<Value>(in);
+}
+
+/// Returns a collapsed memref and the linearized index to access the element
+/// at the specified indices.
+static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
+ Location loc,
+ Value source,
+ ValueRange indices) {
+ auto &&[base, index, strides, offset, collapsedShape] =
+ getFlatOffsetAndStrides(rewriter, loc, source,
+ getAsOpFoldResult(indices));
+
+ return std::make_pair(
+ rewriter.create<memref::ReinterpretCastOp>(
+ loc, source,
+ /* offset = */ offset,
+ /* shapes = */ ArrayRef<OpFoldResult>{collapsedShape},
+ /* strides = */ ArrayRef<OpFoldResult>{strides.back()}),
+ getValueFromOpFoldResult(rewriter, loc, index));
+}
+
+static bool needFlattenning(Value val) {
+ auto type = cast<MemRefType>(val.getType());
+ return type.getRank() > 1;
+}
+
+static bool checkLayout(Value val) {
+ auto type = cast<MemRefType>(val.getType());
+ return type.getLayout().isIdentity() ||
+ isa<StridedLayoutAttr>(type.getLayout());
+}
+
+namespace {
+template <typename T>
+static Value getTargetMemref(T op) {
+ if constexpr (std::is_same_v<T, memref::LoadOp>) {
+ return op.getMemref();
+ } else if constexpr (std::is_same_v<T, vector::LoadOp>) {
+ return op.getBase();
+ } else if constexpr (std::is_same_v<T, memref::StoreOp>) {
+ return op.getMemref();
+ } else if constexpr (std::is_same_v<T, vector::StoreOp>) {
+ return op.getBase();
+ } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
+ return op.getBase();
+ } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
+ return op.getBase();
+ } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
+ return op.getSource();
+ } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
+ return op.getSource();
+ }
+ return {};
+}
+
+template <typename T>
+static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
+ Value offset) {
+ if constexpr (std::is_same_v<T, memref::LoadOp>) {
+ auto newLoad = rewriter.create<memref::LoadOp>(
+ op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
+ newLoad->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newLoad.getResult());
+ } else if constexpr (std::is_same_v<T, vector::LoadOp>) {
+ auto newLoad = rewriter.create<vector::LoadOp>(
+ op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
+ newLoad->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newLoad.getResult());
+ } else if constexpr (std::is_same_v<T, memref::StoreOp>) {
+ auto newStore = rewriter.create<memref::StoreOp>(
+ op->getLoc(), op->getOperands().front(), flatMemref,
+ ValueRange{offset});
+ newStore->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newStore);
+ } else if constexpr (std::is_same_v<T, vector::StoreOp>) {
+ auto newStore = rewriter.create<vector::StoreOp>(
+ op->getLoc(), op->getOperands().front(), flatMemref,
+ ValueRange{offset});
+ newStore->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newStore);
+ } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
+ auto newTransferRead = rewriter.create<vector::TransferReadOp>(
+ op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
+ op.getPadding());
+ rewriter.replaceOp(op, newTransferRead.getResult());
+ } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
+ auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
+ op->getLoc(), op.getVector(), flatMemref, ValueRange{offset});
+ rewriter.replaceOp(op, newTransferWrite);
+ } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
+ auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
+ op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
+ op.getMask(), op.getPassThru());
+ newMaskedLoad->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newMaskedLoad.getResult());
+ } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
+ auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
+ op->getLoc(), flatMemref, ValueRange{offset}, op.getMask(),
+ op.getValueToStore());
+ newMaskedStore->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newMaskedStore);
+ } else {
+ op.emitOpError("unimplemented: do not know how to replace op.");
+ }
+}
+
+template <typename T>
+struct MemRefRewritePatternBase : public OpRewritePattern<T> {
+ using OpRewritePattern<T>::OpRewritePattern;
+ LogicalResult matchAndRewrite(T op,
+ PatternRewriter &rewriter) const override {
+ Value memref = getTargetMemref<T>(op);
+ if (!needFlattenning(memref) || !checkLayout(memref))
+ return rewriter.notifyMatchFailure(op,
+ "nothing to do or unsupported layout");
+ auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
+ rewriter, op->getLoc(), memref, op.getIndices());
+ replaceOp<T>(op, rewriter, flatMemref, offset);
+ return success();
+ }
+};
+
+struct FlattenMemrefLoad : public MemRefRewritePatternBase<memref::LoadOp> {
+ using MemRefRewritePatternBase<memref::LoadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorLoad : public MemRefRewritePatternBase<vector::LoadOp> {
+ using MemRefRewritePatternBase<vector::LoadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenMemrefStore : public MemRefRewritePatternBase<memref::StoreOp> {
+ using MemRefRewritePatternBase<memref::StoreOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorStore : public MemRefRewritePatternBase<vector::StoreOp> {
+ using MemRefRewritePatternBase<vector::StoreOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorMaskedLoad
+ : public MemRefRewritePatternBase<vector::MaskedLoadOp> {
+ using MemRefRewritePatternBase<
+ vector::MaskedLoadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorMaskedStore
+ : public MemRefRewritePatternBase<vector::MaskedStoreOp> {
+ using MemRefRewritePatternBase<
+ vector::MaskedStoreOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorTransferRead
+ : public MemRefRewritePatternBase<vector::TransferReadOp> {
+ using MemRefRewritePatternBase<
+ vector::TransferReadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorTransferWrite
+ : public MemRefRewritePatternBase<vector::TransferWriteOp> {
+ using MemRefRewritePatternBase<
+ vector::TransferWriteOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(memref::SubViewOp op,
+ PatternRewriter &rewriter) const override {
+ Value memref = op.getSource();
+ if (!needFlattenning(memref))
+ return rewriter.notifyMatchFailure(op, "nothing to do");
+
+ if (!checkLayout(memref))
+ return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+ Location loc = op.getLoc();
+ SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
+ SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
+ SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
+ auto &&[base, finalOffset, strides, _, __] =
+ getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);
+
+ auto srcType = cast<MemRefType>(memref.getType());
+ auto resultType = cast<MemRefType>(op.getType());
+ unsigned subRank = static_cast<unsigned>(resultType.getRank());
+
+ llvm::SmallBitVector droppedDims = op.getDroppedDims();
+
+ SmallVector<OpFoldResult> finalSizes;
+ finalSizes.reserve(subRank);
+
+ SmallVector<OpFoldResult> finalStrides;
+ finalStrides.reserve(subRank);
+
+ for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
+ if (droppedDims.test(i))
+ continue;
+
+ finalSizes.push_back(subSizes[i]);
+ finalStrides.push_back(strides[i]);
+ }
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, resultType, base, finalOffset, finalSizes, finalStrides);
+ return success();
+ }
+};
+
+struct FlattenMemrefsPass
+ : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
+ using Base::Base;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<affine::AffineDialect, arith::ArithDialect,
+ memref::MemRefDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+
+ memref::populateFlattenMemrefsPatterns(patterns);
+
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
+
+void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
+ patterns.insert<FlattenMemrefLoad, FlattenMemrefStore, FlattenSubview,
+ FlattenVectorMaskedLoad, FlattenVectorMaskedStore,
+ FlattenVectorLoad, FlattenVectorStore,
+ FlattenVectorTransferRead, FlattenVectorTransferWrite>(
+ patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::memref::createFlattenMemrefsPass() {
+ return std::make_unique<FlattenMemrefsPass>();
+}
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatte...
[truncated]
Nice!
> I think we probably need a type converter that also linearizes other memrefs not handled by this transformation

What memrefs would that entail? I was actually going to ask: what about the more basic cases where a simple memref.cast would be sufficient?

Thanks!
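For what it's worth, memref.cast cannot change rank, so on its own it would not get a multi-dimensional memref down to 1-D. For the contiguous identity-layout case the rank reduction could also be written with memref.collapse_shape; this is shown purely as a point of comparison, not as what this pass emits:

%flat = memref.collapse_shape %m [[0, 1]] : memref<4x8xf32> into memref<32xf32>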
affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);

// Compute collapsed size: (the outmost stride * outmost dimension).
SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
I don't think you want "outermost" here, but max(size[dim] * stride[dim]), unless you know the layout is contiguous / row-major-ish. Consider a column-major memref.
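A small example of the concern (hypothetical IR, not from the patch):

func.func @col_major(%m: memref<4x8xf32, strided<[1, 4]>>, %i: index, %j: index) -> f32 {
  %v = memref.load %m[%i, %j] : memref<4x8xf32, strided<[1, 4]>>
  return %v : f32
}
// Here "outermost stride * outermost size" is 1 * 4 = 4, yet the linearized
// index %i + 4 * %j can be as large as 3 + 4 * 7 = 31, so the flattened view
// needs max(size[d] * stride[d]) = max(4 * 1, 8 * 4) = 32 elements.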
Okay, I changed the way it is computed, so now we need to make sure it only works with row-major layouts.
(Force-pushed from 598f3b9 to c9db9bf.)
Can you take the GPU dialect tests for the pass this was based on and switch them over to this version scheduled on a GPU module?
Also, I'd like to add a test:
func.func @load_from_subview(%x: memref<8x8xf32>, %col: index, %row: index) -> f32 {
  %subview = memref.subview %x[0, 0] [4, 4] [1, 1] : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1]>>
  %ret = memref.load %subview[%row, %col] : memref<4x4xf32, strided<[8, 1]>>
  return %ret : f32
}
to make sure that this doesn't become
memref.load[(row * 8 + col)] : memref<16xf32>
Well I think it actually transforms into what you have described above, because the transformation folds
But that's a bug, because the index into the memref will be larger than the size. Please fix the PR so that it becomes a load from a memref<64xf32> instead. (My sense of the fix is that the getLinearizedSizeAndOffset utility should be doing
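For the subview test above, the result this comment asks for would look roughly like the following (a sketch of the intended behavior, not the PR's current output):

%flat = memref.reinterpret_cast %x to offset: [0], sizes: [64], strides: [1]
    : memref<8x8xf32> to memref<64xf32>
%idx = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%row, %col]
%ret = memref.load %flat[%idx] : memref<64xf32>
// i.e. the 1-D size covers the whole underlying 8x8 buffer (64 elements),
// so the linearized index %row * 8 + %col always stays in bounds.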