[MLIR] Add a utility pass to linearize memref #136797

Open · wants to merge 9 commits into main

Conversation

@lialan (Member) commented Apr 23, 2025

To provide a transformation that simplifies memory access patterns, this PR adds a memref linearizer based on the GPU/DecomposeMemRefs pass, with the following changes:

  • Support vector dialect ops.
  • Instead of decomposing memrefs to rank-0 memrefs, flatten higher-ranked memrefs to rank-1.

Notes:

  • After the linearization, a MemRef's offset is kept, so a memref<4x8xf32, strided<[8, 1], offset: 100>> becomes memref<32xf32, strided<[1], offset: 100>>.
  • It also works with dynamic shapes, strides, and offsets (see the test cases for details).
  • The shape of the cast memref is a flattened 1-D shape whose size is computed as outermostStride * outermostDimSize (see the sketch after these notes).
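
A minimal before/after sketch of the intended rewrite for a fully static case (hand-written illustration, not copied from the test file; the pass goes through memref.extract_strided_metadata and affine folding, so the exact IR it emits may differ):

// Before: 2-D strided load.
func.func @load_2d(%arg0: memref<4x8xf32, strided<[8, 1], offset: 100>>,
                   %i: index, %j: index) -> f32 {
  %v = memref.load %arg0[%i, %j] : memref<4x8xf32, strided<[8, 1], offset: 100>>
  return %v : f32
}

// After flattening: the memref is reinterpreted as rank-1 (offset preserved,
// size = outermostStride * outermostDimSize = 8 * 4 = 32) and the access
// index is linearized as %i * 8 + %j.
func.func @load_2d(%arg0: memref<4x8xf32, strided<[8, 1], offset: 100>>,
                   %i: index, %j: index) -> f32 {
  %flat = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1]
      : memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
  %idx = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%i, %j]
  %v = memref.load %flat[%idx] : memref<32xf32, strided<[1], offset: 100>>
  return %v : f32
}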


github-actions bot commented Apr 23, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@lialan (Member, Author) commented Apr 23, 2025

I think we probably need a type converter that also linearizes other memrefs not handled by this transformation (by collapsing their shapes?), so that in the end we would only deal with 1-dimensional memrefs.

@lialan lialan requested a review from Copilot April 23, 2025 01:49
@Copilot (Copilot AI) left a comment

Pull Request Overview

This PR introduces a memref linearizer pass to simplify memory access patterns by flattening higher-ranked memrefs into rank-1 memrefs while supporting vector dialect operations.

  • Adds a new pass implementation in FlattenMemRefs.cpp for linearizing memrefs.
  • Updates AffineOps.cpp to use the newly defined affine::computeProduct function and removes the duplicate implementation.
  • Extends header files in Transforms.h and Passes.h to declare the necessary pattern and pass creation routines.

Reviewed Changes

Copilot reviewed 4 out of 7 changed files in this pull request and generated no comments.

File: Description
  • mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp: Implements the memref linearizer pass with transformations and pattern rewrites.
  • mlir/lib/Dialect/Affine/IR/AffineOps.cpp: Refactors and centralizes the computeProduct logic within the affine namespace.
  • mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h: Declares the population function for the flatten-memref patterns.
  • mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h: Declares the creation routine for the flatten-memrefs pass.
Files not reviewed (3)
  • mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td: Language not supported
  • mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt: Language not supported
  • mlir/test/Dialect/MemRef/flatten_memref.mlir: Language not supported
Comments suppressed due to low confidence (1)

mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp:140

  • The function name 'needFlattenning' contains a spelling error. Consider renaming it to 'needFlattening' for clarity and consistency.
static bool needFlattenning(Value val) {

@llvmbot (Member) commented Apr 23, 2025

@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Alan Li (lialan)

Patch is 31.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136797.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h (+4)
  • (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td (+12)
  • (modified) mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h (+2)
  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+26-22)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt (+2)
  • (added) mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp (+356)
  • (added) mlir/test/Dialect/MemRef/flatten_memref.mlir (+225)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
index d7050156862df..7580985754843 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -77,6 +77,10 @@ std::unique_ptr<Pass> createExpandStridedMetadataPass();
 /// components.
 std::unique_ptr<Pass> createExpandReallocPass(bool emitDeallocs = true);
 
+/// Creates an operation pass to flatten multiple dimensional memrefs into
+/// 1-d memrefs.
+std::unique_ptr<Pass> createFlattenMemrefsPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index 651ee05ae1f3c..c87472851fd78 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -253,5 +253,17 @@ def ExpandRealloc : Pass<"expand-realloc"> {
   ];
 }
 
+def FlattenMemrefsPass : Pass<"flatten-memref"> {
+  let summary = "Flatten a multiple dimensional memref to 1-dimensional";
+  let description = [{
+
+  }];
+
+  let constructor = "mlir::memref::createFlattenMemrefsPass()";
+  let dependentDialects = [
+      "affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
 
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 62a2297c80e78..c2b8cb05be922 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -144,6 +144,8 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
 /// ```
 void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
+void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
+
 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
 /// given independencies. If the op is already independent of all
 /// independencies, the same AllocaOp result is returned.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index aa49c49062c76..43224de5604ed 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -5022,6 +5022,31 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
   return ret;
 }
 
+namespace mlir {
+namespace affine {
+OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+                            ArrayRef<OpFoldResult> terms) {
+  int64_t nDynamic = 0;
+  SmallVector<Value> dynamicPart;
+  AffineExpr result = builder.getAffineConstantExpr(1);
+  for (OpFoldResult term : terms) {
+    if (!term)
+      return term;
+    std::optional<int64_t> maybeConst = getConstantIntValue(term);
+    if (maybeConst) {
+      result = result * builder.getAffineConstantExpr(*maybeConst);
+    } else {
+      dynamicPart.push_back(cast<Value>(term));
+      result = result * builder.getAffineSymbolExpr(nDynamic++);
+    }
+  }
+  if (auto constant = dyn_cast<AffineConstantExpr>(result))
+    return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+} // namespace affine
+} // namespace mlir
+
 namespace {
 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -5081,27 +5106,6 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
   }
 };
 
-OpFoldResult computeProduct(Location loc, OpBuilder &builder,
-                            ArrayRef<OpFoldResult> terms) {
-  int64_t nDynamic = 0;
-  SmallVector<Value> dynamicPart;
-  AffineExpr result = builder.getAffineConstantExpr(1);
-  for (OpFoldResult term : terms) {
-    if (!term)
-      return term;
-    std::optional<int64_t> maybeConst = getConstantIntValue(term);
-    if (maybeConst) {
-      result = result * builder.getAffineConstantExpr(*maybeConst);
-    } else {
-      dynamicPart.push_back(cast<Value>(term));
-      result = result * builder.getAffineSymbolExpr(nDynamic++);
-    }
-  }
-  if (auto constant = dyn_cast<AffineConstantExpr>(result))
-    return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
-  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
-}
-
 /// If conseceutive outputs of a delinearize_index are linearized with the same
 /// bounds, canonicalize away the redundant arithmetic.
 ///
@@ -5248,7 +5252,7 @@ struct CancelLinearizeOfDelinearizePortion final
       // We use the slice from the linearize's basis above because of the
       // "bounds inferred from `disjoint`" case above.
       OpFoldResult newSize =
-          computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
+          affine::computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
 
       // Trivial case where we can just skip past the delinearize all together
       if (m.length == m.delinearize.getNumResults()) {
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index ecab97bc2b8e7..48e8bccd369fa 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   EmulateWideInt.cpp
   EmulateNarrowType.cpp
   ExtractAddressComputations.cpp
+  FlattenMemRefs.cpp
   FoldMemRefAliasOps.cpp
   IndependenceTransforms.cpp
   MultiBuffer.cpp
@@ -23,6 +24,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
 
   LINK_LIBS PUBLIC
   MLIRAffineTransforms
+  MLIRAffineDialect
   MLIRAffineUtils
   MLIRArithDialect
   MLIRArithTransforms
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
new file mode 100644
index 0000000000000..5ec524967444a
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -0,0 +1,356 @@
+//===----- FlattenMemRefs.cpp - MemRef ops flattener pass  ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for flattening an multi-rank memref-related
+// ops into 1-d memref ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_FLATTENMEMREFSPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+static void setInsertionPointToStart(OpBuilder &builder, Value val) {
+  if (auto *parentOp = val.getDefiningOp()) {
+    builder.setInsertionPointAfter(parentOp);
+  } else {
+    builder.setInsertionPointToStart(val.getParentBlock());
+  }
+}
+
+static std::tuple<Value, OpFoldResult, SmallVector<OpFoldResult>, OpFoldResult,
+                  OpFoldResult>
+getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source,
+                        ArrayRef<OpFoldResult> subOffsets,
+                        ArrayRef<OpFoldResult> subStrides = std::nullopt) {
+  auto sourceType = cast<MemRefType>(source.getType());
+  auto sourceRank = static_cast<unsigned>(sourceType.getRank());
+
+  memref::ExtractStridedMetadataOp newExtractStridedMetadata;
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    setInsertionPointToStart(rewriter, source);
+    newExtractStridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+  }
+
+  auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset();
+
+  auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult {
+    return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal)
+                                      : rewriter.getIndexAttr(dim);
+  };
+
+  OpFoldResult origOffset =
+      getDim(sourceOffset, newExtractStridedMetadata.getOffset());
+  ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides();
+  OpFoldResult outmostDim =
+      getDim(sourceType.getShape().front(),
+             newExtractStridedMetadata.getSizes().front());
+
+  SmallVector<OpFoldResult> origStrides;
+  origStrides.reserve(sourceRank);
+
+  SmallVector<OpFoldResult> strides;
+  strides.reserve(sourceRank);
+
+  AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
+  AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
+  for (auto i : llvm::seq(0u, sourceRank)) {
+    OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]);
+
+    if (!subStrides.empty()) {
+      strides.push_back(affine::makeComposedFoldedAffineApply(
+          rewriter, loc, s0 * s1, {subStrides[i], origStride}));
+    }
+
+    origStrides.emplace_back(origStride);
+  }
+
+  // Compute linearized index:
+  auto &&[expr, values] =
+      computeLinearIndex(rewriter.getIndexAttr(0), origStrides, subOffsets);
+  OpFoldResult linearizedIndex =
+      affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);
+
+  // Compute collapsed size: (the outmost stride * outmost dimension).
+  SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
+  OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops);
+
+  return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex,
+          origStrides, origOffset, collapsedSize};
+}
+
+static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
+                                      OpFoldResult in) {
+  if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
+    return rewriter.create<arith::ConstantIndexOp>(
+        loc, cast<IntegerAttr>(offsetAttr).getInt());
+  }
+  return cast<Value>(in);
+}
+
+/// Returns a collapsed memref and the linearized index to access the element
+/// at the specified indices.
+static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
+                                                         Location loc,
+                                                         Value source,
+                                                         ValueRange indices) {
+  auto &&[base, index, strides, offset, collapsedShape] =
+      getFlatOffsetAndStrides(rewriter, loc, source,
+                              getAsOpFoldResult(indices));
+
+  return std::make_pair(
+      rewriter.create<memref::ReinterpretCastOp>(
+          loc, source,
+          /* offset = */ offset,
+          /* shapes = */ ArrayRef<OpFoldResult>{collapsedShape},
+          /* strides = */ ArrayRef<OpFoldResult>{strides.back()}),
+      getValueFromOpFoldResult(rewriter, loc, index));
+}
+
+static bool needFlattenning(Value val) {
+  auto type = cast<MemRefType>(val.getType());
+  return type.getRank() > 1;
+}
+
+static bool checkLayout(Value val) {
+  auto type = cast<MemRefType>(val.getType());
+  return type.getLayout().isIdentity() ||
+         isa<StridedLayoutAttr>(type.getLayout());
+}
+
+namespace {
+template <typename T>
+static Value getTargetMemref(T op) {
+  if constexpr (std::is_same_v<T, memref::LoadOp>) {
+    return op.getMemref();
+  } else if constexpr (std::is_same_v<T, vector::LoadOp>) {
+    return op.getBase();
+  } else if constexpr (std::is_same_v<T, memref::StoreOp>) {
+    return op.getMemref();
+  } else if constexpr (std::is_same_v<T, vector::StoreOp>) {
+    return op.getBase();
+  } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
+    return op.getBase();
+  } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
+    return op.getBase();
+  } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
+    return op.getSource();
+  } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
+    return op.getSource();
+  }
+  return {};
+}
+
+template <typename T>
+static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
+                      Value offset) {
+  if constexpr (std::is_same_v<T, memref::LoadOp>) {
+    auto newLoad = rewriter.create<memref::LoadOp>(
+        op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
+    newLoad->setAttrs(op->getAttrs());
+    rewriter.replaceOp(op, newLoad.getResult());
+  } else if constexpr (std::is_same_v<T, vector::LoadOp>) {
+    auto newLoad = rewriter.create<vector::LoadOp>(
+        op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset});
+    newLoad->setAttrs(op->getAttrs());
+    rewriter.replaceOp(op, newLoad.getResult());
+  } else if constexpr (std::is_same_v<T, memref::StoreOp>) {
+    auto newStore = rewriter.create<memref::StoreOp>(
+        op->getLoc(), op->getOperands().front(), flatMemref,
+        ValueRange{offset});
+    newStore->setAttrs(op->getAttrs());
+    rewriter.replaceOp(op, newStore);
+  } else if constexpr (std::is_same_v<T, vector::StoreOp>) {
+    auto newStore = rewriter.create<vector::StoreOp>(
+        op->getLoc(), op->getOperands().front(), flatMemref,
+        ValueRange{offset});
+    newStore->setAttrs(op->getAttrs());
+    rewriter.replaceOp(op, newStore);
+  } else if constexpr (std::is_same_v<T, vector::TransferReadOp>) {
+    auto newTransferRead = rewriter.create<vector::TransferReadOp>(
+        op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
+        op.getPadding());
+    rewriter.replaceOp(op, newTransferRead.getResult());
+  } else if constexpr (std::is_same_v<T, vector::TransferWriteOp>) {
+    auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
+        op->getLoc(), op.getVector(), flatMemref, ValueRange{offset});
+    rewriter.replaceOp(op, newTransferWrite);
+  } else if constexpr (std::is_same_v<T, vector::MaskedLoadOp>) {
+    auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
+        op->getLoc(), op.getType(), flatMemref, ValueRange{offset},
+        op.getMask(), op.getPassThru());
+    newMaskedLoad->setAttrs(op->getAttrs());
+    rewriter.replaceOp(op, newMaskedLoad.getResult());
+  } else if constexpr (std::is_same_v<T, vector::MaskedStoreOp>) {
+    auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
+        op->getLoc(), flatMemref, ValueRange{offset}, op.getMask(),
+        op.getValueToStore());
+    newMaskedStore->setAttrs(op->getAttrs());
+    rewriter.replaceOp(op, newMaskedStore);
+  } else {
+    op.emitOpError("unimplemented: do not know how to replace op.");
+  }
+}
+
+template <typename T>
+struct MemRefRewritePatternBase : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    Value memref = getTargetMemref<T>(op);
+    if (!needFlattenning(memref) || !checkLayout(memref))
+      return rewriter.notifyMatchFailure(op,
+                                         "nothing to do or unsupported layout");
+    auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
+        rewriter, op->getLoc(), memref, op.getIndices());
+    replaceOp<T>(op, rewriter, flatMemref, offset);
+    return success();
+  }
+};
+
+struct FlattenMemrefLoad : public MemRefRewritePatternBase<memref::LoadOp> {
+  using MemRefRewritePatternBase<memref::LoadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorLoad : public MemRefRewritePatternBase<vector::LoadOp> {
+  using MemRefRewritePatternBase<vector::LoadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenMemrefStore : public MemRefRewritePatternBase<memref::StoreOp> {
+  using MemRefRewritePatternBase<memref::StoreOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorStore : public MemRefRewritePatternBase<vector::StoreOp> {
+  using MemRefRewritePatternBase<vector::StoreOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorMaskedLoad
+    : public MemRefRewritePatternBase<vector::MaskedLoadOp> {
+  using MemRefRewritePatternBase<
+      vector::MaskedLoadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorMaskedStore
+    : public MemRefRewritePatternBase<vector::MaskedStoreOp> {
+  using MemRefRewritePatternBase<
+      vector::MaskedStoreOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorTransferRead
+    : public MemRefRewritePatternBase<vector::TransferReadOp> {
+  using MemRefRewritePatternBase<
+      vector::TransferReadOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenVectorTransferWrite
+    : public MemRefRewritePatternBase<vector::TransferWriteOp> {
+  using MemRefRewritePatternBase<
+      vector::TransferWriteOp>::MemRefRewritePatternBase;
+};
+
+struct FlattenSubview : public OpRewritePattern<memref::SubViewOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::SubViewOp op,
+                                PatternRewriter &rewriter) const override {
+    Value memref = op.getSource();
+    if (!needFlattenning(memref))
+      return rewriter.notifyMatchFailure(op, "nothing to do");
+
+    if (!checkLayout(memref))
+      return rewriter.notifyMatchFailure(op, "unsupported layout");
+
+    Location loc = op.getLoc();
+    SmallVector<OpFoldResult> subOffsets = op.getMixedOffsets();
+    SmallVector<OpFoldResult> subSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> subStrides = op.getMixedStrides();
+    auto &&[base, finalOffset, strides, _, __] =
+        getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides);
+
+    auto srcType = cast<MemRefType>(memref.getType());
+    auto resultType = cast<MemRefType>(op.getType());
+    unsigned subRank = static_cast<unsigned>(resultType.getRank());
+
+    llvm::SmallBitVector droppedDims = op.getDroppedDims();
+
+    SmallVector<OpFoldResult> finalSizes;
+    finalSizes.reserve(subRank);
+
+    SmallVector<OpFoldResult> finalStrides;
+    finalStrides.reserve(subRank);
+
+    for (auto i : llvm::seq(0u, static_cast<unsigned>(srcType.getRank()))) {
+      if (droppedDims.test(i))
+        continue;
+
+      finalSizes.push_back(subSizes[i]);
+      finalStrides.push_back(strides[i]);
+    }
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, resultType, base, finalOffset, finalSizes, finalStrides);
+    return success();
+  }
+};
+
+struct FlattenMemrefsPass
+    : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
+  using Base::Base;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect, arith::ArithDialect,
+                    memref::MemRefDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+
+    memref::populateFlattenMemrefsPatterns(patterns);
+
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
+  patterns.insert<FlattenMemrefLoad, FlattenMemrefStore, FlattenSubview,
+                  FlattenVectorMaskedLoad, FlattenVectorMaskedStore,
+                  FlattenVectorLoad, FlattenVectorStore,
+                  FlattenVectorTransferRead, FlattenVectorTransferWrite>(
+      patterns.getContext());
+}
+
+std::unique_ptr<Pass> mlir::memref::createFlattenMemrefsPass() {
+  return std::make_unique<FlattenMemrefsPass>();
+}
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatte...
[truncated]

@banach-space (Contributor) left a comment

Nice!

I think we probably need a type converter that also linearizes other memrefs not handled by this transformation

What memrefs would that entail? I was actually going to ask: what about the more basic cases where a simple memref.cast would be sufficient?

Thanks!

affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values);

// Compute collapsed size: (the outmost stride * outmost dimension).
SmallVector<OpFoldResult> ops{origStrides.front(), outmostDim};
Contributor

I don't think you want "outermost" here, but max(size[dim] * stride[dim]), unless you know the layout is contiguous / row-major-ish. Consider a column-major memref.
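
As a concrete illustration of that concern (a hypothetical column-major-style layout, not a case from this patch):

// memref<4x8xf32, strided<[1, 4]>>: dim 0 has stride 1, dim 1 has stride 4.
// outermostStride * outermostDimSize  = 1 * 4 = 4   (undercounts the span)
// max over d of strides[d] * sizes[d] = 4 * 8 = 32  (actual span in elements)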

@lialan (Member, Author)

Okay, I changed how it is computed, so now we need to make sure it only works with row-major layouts.

@lialan lialan force-pushed the lialan/flatten_memref branch from 598f3b9 to c9db9bf Compare May 1, 2025 18:26
@lialan lialan requested review from banach-space and krzysz00 May 1, 2025 20:00
@krzysz00 (Contributor) left a comment

Can you take the GPU dialect tests for the pass this was based on and switch them over to this version scheduled on a GPU module?

Also, I'd like to add a test

func.func @load_from_subview(%x: memref<8x8xf32>, %col: index, %row: index) -> f32 {
  %subview = memref.subview %x[0, 0] [4, 4] [1, 1] : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1]>>
  %ret = memref.load %subview[%row, %col] : memref<4x4xf32, strided<[8, 1]>>
  return %ret : f32
}

to make sure that this doesn't become

memref.load[(row * 8 + col)] : memref<16xf32>

@lialan (Member, Author) commented May 2, 2025

Can you take the GPU dialect tests for the pass this was based on and switch them over to this version scheduled on a GPU module?

Also, I'd like to add a test

func.func @load_from_subview(%x: memref<8x8xf32>, %col: index, %row: index) -> f32 {
  %subview = memref.subview %x[0, 0] [4, 4] [1, 1] : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1]>>
  %ret = memref.load %subview[%row, %col] : memref<4x4xf32, strided<[8, 1]>>
  return %ret : f32
}

to make sure that this doesn't become

memref.load[(row * 8 + col)] : memref<16xf32>

Well, I think it actually transforms into what you described above, because the transformation folds the memref.subview into the use of the memref.

@krzysz00 (Contributor) commented May 2, 2025

But that's a bug, because the index into the memref will be larger than the size.

Please fix the PR so that it becomes a load from a memref<64xf32> instead.

(My sense of the fix is that the getLinearizedSizeAndOffset utility should compute max over d in [0, rank) of strides[d] * sizes[d] for non-identity layouts.)
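
For reference, a hedged sketch of what the flattened form of the @load_from_subview body would look like under that suggestion (illustrative only; the exact ops and folding the pass produces may differ):

// The subview folds into the access: the flat memref keeps the full
// 8x8 = 64-element span, and the index is linearized with the subview's
// strides, i.e. %row * 8 + %col.
%flat = memref.reinterpret_cast %x to offset: [0], sizes: [64], strides: [1]
    : memref<8x8xf32> to memref<64xf32>
%idx = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%row, %col]
%ret = memref.load %flat[%idx] : memref<64xf32>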
