From b57c07fea33fb1f696b6410d14e1fbfff23babf6 Mon Sep 17 00:00:00 2001
From: Alan Li
Date: Tue, 22 Apr 2025 14:42:38 -0400
Subject: [PATCH 01/18] First commit.

---
 .../include/mlir/Dialect/MemRef/Transforms/Passes.td | 12 ++++++++++++
 .../mlir/Dialect/MemRef/Transforms/Transforms.h      |  3 +++
 mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt    |  2 ++
 3 files changed, 17 insertions(+)

diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index a46f73350bb3c..a2a9047bda808 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -245,5 +245,17 @@ def ExpandReallocPass : Pass<"expand-realloc"> {
   ];
 }
 
+def FlattenMemrefsPass : Pass<"flatten-memref"> {
+  let summary = "Flatten a multiple dimensional memref to 1-dimensional";
+  let description = [{
+
+  }];
+
+  let constructor = "mlir::memref::createFlattenMemrefsPass()";
+  let dependentDialects = [
+      "affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
 
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 62a2297c80e78..0649bf9c099f9 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -144,6 +144,9 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
 /// ```
 void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
+
+void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
+
 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
 /// given independencies. If the op is already independent of all
 /// independencies, the same AllocaOp result is returned.
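The new pass's description field is still empty at this point in the series. As a rough illustration of the rewrite it is meant to perform (mirroring the first test case added in the next patch; the function and value names below are only illustrative), a rank-2 strided load becomes a linearized index plus a rank-1 reinterpret_cast of the same buffer:

  // Before --flatten-memref:
  func.func @example(%src: memref<4x8xf32, strided<[8, 1], offset: 100>>) -> f32 {
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %v = memref.load %src[%c1, %c2] : memref<4x8xf32, strided<[8, 1], offset: 100>>
    return %v : f32
  }

  // After: the index is linearized using the static strides (1 * 8 + 2 * 1 = 10)
  // and the source is viewed as a flat 32-element buffer with the same offset.
  func.func @example(%src: memref<4x8xf32, strided<[8, 1], offset: 100>>) -> f32 {
    %c10 = arith.constant 10 : index
    %flat = memref.reinterpret_cast %src to offset: [100], sizes: [32], strides: [1]
        : memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
    %v = memref.load %flat[%c10] : memref<32xf32, strided<[1], offset: 100>>
    return %v : f32
  }

When the shape or strides are dynamic, the later test cases show the same structure, except that the linearized index and flattened size are computed from memref.extract_strided_metadata results via affine.apply rather than being folded to constants.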
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt index ecab97bc2b8e7..48e8bccd369fa 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms EmulateWideInt.cpp EmulateNarrowType.cpp ExtractAddressComputations.cpp + FlattenMemRefs.cpp FoldMemRefAliasOps.cpp IndependenceTransforms.cpp MultiBuffer.cpp @@ -23,6 +24,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms LINK_LIBS PUBLIC MLIRAffineTransforms + MLIRAffineDialect MLIRAffineUtils MLIRArithDialect MLIRArithTransforms From c003e70ea3d9d43a0679304fdfeceeb4f18940cb Mon Sep 17 00:00:00 2001 From: Alan Li Date: Tue, 22 Apr 2025 14:45:51 -0400 Subject: [PATCH 02/18] Adding test cases --- mlir/test/Dialect/MemRef/flatten_memref.mlir | 225 +++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 mlir/test/Dialect/MemRef/flatten_memref.mlir diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir new file mode 100644 index 0000000000000..6c9b09985acf7 --- /dev/null +++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir @@ -0,0 +1,225 @@ +// RUN: mlir-opt --flatten-memref %s --split-input-file --verify-diagnostics | FileCheck %s + +func.func @load_scalar_from_memref(%input: memref<4x8xf32, strided<[8, 1], offset: 100>>) -> f32 { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %value = memref.load %input[%c1, %c2] : memref<4x8xf32, strided<[8, 1], offset: 100>> + return %value : f32 +} +// CHECK: func @load_scalar_from_memref +// CHECK: %[[C10:.*]] = arith.constant 10 : index +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1] +// CHECK-SAME: memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>> +// CHECK: memref.load %[[REINT]][%[[C10]]] : memref<32xf32, strided<[1], offset: 100>> + +// ----- + +func.func @load_scalar_from_memref_static_dim_2(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index) -> f32 { + %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>> + return %value : f32 +} +// CHECK: [[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 12)> +// CHECK: func @load_scalar_from_memref_static_dim_2 +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 12], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) +// CHECK: %[[IDX:.*]] = affine.apply [[MAP]]()[%[[ARG2]], %[[ARG1]]] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [12] +// CHECK-SAME: to memref<32xf32, strided<[12], offset: 100>> +// CHECK: memref.load %[[REINT]][%[[IDX]]] + +// ----- + +func.func @load_scalar_from_memref_dynamic_dim(%input: memref>, %row: index, %col: index) -> f32 { + %value = memref.load %input[%col, %row] : memref> + return %value : f32 +} + +// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)> +// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK: func @load_scalar_from_memref_dynamic_dim +// CHECK-SAME: (%[[ARG0:.*]]: memref>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]] +// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1] +// CHECK: %[[SIZE:.*]] = affine.apply 
#[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1] +// CHECK: memref.load %[[REINT]][%[[IDX]]] + +// ----- + +func.func @load_scalar_from_memref_subview(%input: memref<4x8xf32>, %row: index, %col: index) -> memref<1x1xf32, strided<[8, 1], offset: ?>> { + %subview = memref.subview %input[%col, %row] [1, 1] [1, 1] : memref<4x8xf32> to memref<1x1xf32, strided<[8, 1], offset: ?>> + return %subview : memref<1x1xf32, strided<[8, 1], offset: ?>> +} +// CHECK: func @load_scalar_from_memref_subview + +// ----- + +func.func @store_scalar_from_memref_static_dim(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index, %value: f32) { + memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>> + return +} +// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 12)> +// CHECK: func @store_scalar_from_memref_static_dim +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 12], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32) +// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] +// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]] : memref<32xf32, strided<[12], offset: 100>> + +// ----- + +func.func @store_scalar_from_memref_dynamic_dim(%input: memref>, %row: index, %col: index, %value: f32) { + memref.store %value, %input[%col, %row] : memref> + return +} +// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)> +// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK: func @store_scalar_from_memref_dynamic_dim +// CHECK-SAME: (%[[ARG0:.*]]: memref>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32) +// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]] +// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1] +// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1] +// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]] + +// ----- + +func.func @load_vector_from_memref(%input: memref<4x8xf32>) -> vector<8xf32> { + %c3 = arith.constant 3 : index + %c6 = arith.constant 6 : index + %value = vector.load %input[%c3, %c6] : memref<4x8xf32>, vector<8xf32> + return %value : vector<8xf32> +} +// CHECK: func @load_vector_from_memref +// CHECK: %[[C30:.*]] = arith.constant 30 +// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1] +// CHECK-NEXT: vector.load %[[REINT]][%[[C30]]] + +// ----- + +func.func @load_vector_from_memref_odd(%input: memref<3x7xi2>) -> vector<3xi2> { + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %value = vector.load %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2> + return %value : vector<3xi2> +} +// CHECK: func @load_vector_from_memref_odd +// CHECK: %[[C10:.*]] = arith.constant 10 : index +// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast +// CHECK-NEXT: vector.load %[[REINT]][%[[C10]]] + +// ----- + +func.func @load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: index, %col: index) -> vector<3xi2> { + %value = vector.load %input[%col, %row] : memref<3x7xi2>, vector<3xi2> + return %value : vector<3xi2> 
+} +// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)> +// CHECK: func @load_vector_from_memref_dynamic +// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]() +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast +// CHECK: vector.load %[[REINT]][%[[IDX]]] : memref<21xi2, strided<[1]>>, vector<3xi2> + +// ----- + +func.func @store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi2>) { + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + vector.store %value, %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2> + return +} +// CHECK: func @store_vector_to_memref_odd +// CHECK: %[[C10:.*]] = arith.constant 10 : index +// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast +// CHECK-NEXT: vector.store %arg1, %[[REINT]][%[[C10]]] : memref<21xi2, strided<[1]> + +// ----- + +func.func @store_vector_to_memref_dynamic(%input: memref<3x7xi2>, %value: vector<3xi2>, %row: index, %col: index) { + vector.store %value, %input[%col, %row] : memref<3x7xi2>, vector<3xi2> + return +} +// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)> +// CHECK: func @store_vector_to_memref_dynamic +// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) +// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [21], strides: [1] +// CHECK: vector.store %[[ARG1]], %[[REINT]][%[[IDX]]] + +// ----- + +func.func @mask_store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi2>, %mask: vector<3xi1>) { + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + vector.maskedstore %input[%c1, %c3], %mask, %value : memref<3x7xi2>, vector<3xi1>, vector<3xi2> + return +} +// CHECK: func @mask_store_vector_to_memref_odd +// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: vector<3xi1>) +// CHECK: %[[C10:.*]] = arith.constant 10 : index +// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast +// CHECK: vector.maskedstore %[[REINT]][%[[C10]]], %[[ARG2]], %[[ARG1]] + +// ----- + +func.func @mask_store_vector_to_memref_dynamic(%input: memref<3x7xi2>, %value: vector<3xi2>, %row: index, %col: index, %mask: vector<3xi1>) { + vector.maskedstore %input[%col, %row], %mask, %value : memref<3x7xi2>, vector<3xi1>, vector<3xi2> + return +} +// CHECK: #map = affine_map<()[s0, s1] -> (s0 * 7 + s1)> +// CHECK: func @mask_store_vector_to_memref_dynamic +// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: vector<3xi1>) +// CHECK: %[[IDX:.*]] = affine.apply #map()[%[[ARG3]], %[[ARG2]]] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] +// CHECK: vector.maskedstore %[[REINT]][%[[IDX]]], %[[ARG4]], %[[ARG1]] + +// ----- +func.func @mask_load_vector_from_memref_odd(%input: memref<3x7xi2>, %mask: vector<3xi1>, %passthru: vector<3xi2>) -> vector<3xi2> { + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %result = vector.maskedload %input[%c1, %c3], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2> + return %result : vector<3xi2> +} +// CHECK: func @mask_load_vector_from_memref_odd +// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[MASK:.*]]: vector<3xi1>, %[[PASSTHRU:.*]]: vector<3xi2>) +// CHECK: %[[C10:.*]] = arith.constant 10 : index +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [21], strides: [1] +// CHECK: vector.maskedload 
%[[REINT]][%[[C10]]], %[[MASK]], %[[PASSTHRU]] + +// ----- + +func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: index, %col: index, %mask: vector<3xi1>, %passthru: vector<3xi2>) -> vector<3xi2> { + %result = vector.maskedload %input[%col, %row], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2> + return %result : vector<3xi2> +} +// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)> +// CHECK: func @mask_load_vector_from_memref_dynamic +// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<3xi1>, %[[ARG4:.*]]: vector<3xi2>) +// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] +// CHECK: vector.maskedload %[[REINT]][%[[IDX]]], %[[ARG3]] + +// ----- + +func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> { + %c0 = arith.constant 0 : i2 + %0 = vector.transfer_read %input[%col, %row], %c0 : memref<4x8xi2>, vector<8xi2> + return %0 : vector<8xi2> +} +// CHECK: func @transfer_read_memref +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) +// CHECK: %[[C0:.*]] = arith.constant 0 : i2 +// CHECK: %[[IDX:.*]] = affine.apply #map()[%[[ARG3]], %[[ARG2]]] +// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] +// CHECK-NEXT: vector.transfer_read %[[REINT]][%[[IDX]]], %[[C0]] + +// ----- + +func.func @transfer_write_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) { + vector.transfer_write %value, %input[%col, %row] : vector<8xi2>, memref<4x8xi2> + return +} +// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)> +// CHECK: func @transfer_write_memref +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) +// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] +// CHECK: vector.transfer_write %[[ARG1]], %[[REINT]][%[[IDX]]] From 9170789b65289d8fe8f53253a2d989f9b33cd630 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Tue, 22 Apr 2025 16:23:39 -0400 Subject: [PATCH 03/18] missing the file --- .../MemRef/Transforms/FlattenMemRefs.cpp | 357 ++++++++++++++++++ 1 file changed, 357 insertions(+) create mode 100644 mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp new file mode 100644 index 0000000000000..6685896624536 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -0,0 +1,357 @@ +//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains patterns for flattening an multi-rank memref-related +// ops into 1-d memref ops. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + + +namespace mlir { +namespace memref { +#define GEN_PASS_DEF_FLATTENMEMREFSPASS +#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" +} // namespace memref +} // namespace mlir + +using namespace mlir; + +static void setInsertionPointToStart(OpBuilder &builder, Value val) { + if (auto *parentOp = val.getDefiningOp()) { + builder.setInsertionPointAfter(parentOp); + } else { + builder.setInsertionPointToStart(val.getParentBlock()); + } +} + +static std::tuple, OpFoldResult, + OpFoldResult> +getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, + ArrayRef subOffsets, + ArrayRef subStrides = std::nullopt) { + auto sourceType = cast(source.getType()); + auto sourceRank = static_cast(sourceType.getRank()); + + memref::ExtractStridedMetadataOp newExtractStridedMetadata; + { + OpBuilder::InsertionGuard g(rewriter); + setInsertionPointToStart(rewriter, source); + newExtractStridedMetadata = + rewriter.create(loc, source); + } + + auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset(); + + auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult { + return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal) + : rewriter.getIndexAttr(dim); + }; + + OpFoldResult origOffset = + getDim(sourceOffset, newExtractStridedMetadata.getOffset()); + ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides(); + OpFoldResult outmostDim = + getDim(sourceType.getShape().front(), + newExtractStridedMetadata.getSizes().front()); + + SmallVector origStrides; + origStrides.reserve(sourceRank); + + SmallVector strides; + strides.reserve(sourceRank); + + AffineExpr s0 = rewriter.getAffineSymbolExpr(0); + AffineExpr s1 = rewriter.getAffineSymbolExpr(1); + for (auto i : llvm::seq(0u, sourceRank)) { + OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]); + + if (!subStrides.empty()) { + strides.push_back(affine::makeComposedFoldedAffineApply( + rewriter, loc, s0 * s1, {subStrides[i], origStride})); + } + + origStrides.emplace_back(origStride); + } + + // Compute linearized index: + auto &&[expr, values] = + computeLinearIndex(rewriter.getIndexAttr(0), origStrides, subOffsets); + OpFoldResult linearizedIndex = + affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values); + + // Compute collapsed size: (the outmost stride * outmost dimension). 
+ SmallVector ops{origStrides.front(), outmostDim}; + OpFoldResult collapsedSize = computeProduct(loc, rewriter, ops); + + return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex, + origStrides, origOffset, collapsedSize}; +} + +static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, + OpFoldResult in) { + if (Attribute offsetAttr = dyn_cast(in)) { + return rewriter.create( + loc, cast(offsetAttr).getInt()); + } + return cast(in); +} + +/// Returns a collapsed memref and the linearized index to access the element +/// at the specified indices. +static std::pair getFlattenMemrefAndOffset(OpBuilder &rewriter, + Location loc, + Value source, + ValueRange indices) { + auto &&[base, index, strides, offset, collapsedShape] = + getFlatOffsetAndStrides(rewriter, loc, source, + getAsOpFoldResult(indices)); + + return std::make_pair( + rewriter.create( + loc, source, + /* offset = */ offset, + /* shapes = */ ArrayRef{collapsedShape}, + /* strides = */ ArrayRef{strides.back()}), + getValueFromOpFoldResult(rewriter, loc, index)); +} + +static bool needFlattenning(Value val) { + auto type = cast(val.getType()); + return type.getRank() > 1; +} + +static bool checkLayout(Value val) { + auto type = cast(val.getType()); + return type.getLayout().isIdentity() || + isa(type.getLayout()); +} + +namespace { +template +static Value getTargetMemref(T op) { + if constexpr (std::is_same_v) { + return op.getMemref(); + } else if constexpr (std::is_same_v) { + return op.getBase(); + } else if constexpr (std::is_same_v) { + return op.getMemref(); + } else if constexpr (std::is_same_v) { + return op.getBase(); + } else if constexpr (std::is_same_v) { + return op.getBase(); + } else if constexpr (std::is_same_v) { + return op.getBase(); + } else if constexpr (std::is_same_v) { + return op.getSource(); + } else if constexpr (std::is_same_v) { + return op.getSource(); + } + return {}; +} + +template +static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, + Value offset) { + if constexpr (std::is_same_v) { + auto newLoad = rewriter.create( + op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset}); + newLoad->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newLoad.getResult()); + } else if constexpr (std::is_same_v) { + auto newLoad = rewriter.create( + op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset}); + newLoad->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newLoad.getResult()); + } else if constexpr (std::is_same_v) { + auto newStore = rewriter.create( + op->getLoc(), op->getOperands().front(), flatMemref, + ValueRange{offset}); + newStore->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newStore); + } else if constexpr (std::is_same_v) { + auto newStore = rewriter.create( + op->getLoc(), op->getOperands().front(), flatMemref, + ValueRange{offset}); + newStore->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newStore); + } else if constexpr (std::is_same_v) { + auto newTransferRead = rewriter.create( + op->getLoc(), op.getType(), flatMemref, ValueRange{offset}, + op.getPadding()); + rewriter.replaceOp(op, newTransferRead.getResult()); + } else if constexpr (std::is_same_v) { + auto newTransferWrite = rewriter.create( + op->getLoc(), op.getVector(), flatMemref, ValueRange{offset}); + rewriter.replaceOp(op, newTransferWrite); + } else if constexpr (std::is_same_v) { + auto newMaskedLoad = rewriter.create( + op->getLoc(), op.getType(), flatMemref, ValueRange{offset}, + op.getMask(), op.getPassThru()); + 
newMaskedLoad->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newMaskedLoad.getResult()); + } else if constexpr (std::is_same_v) { + auto newMaskedStore = rewriter.create( + op->getLoc(), flatMemref, ValueRange{offset}, op.getMask(), + op.getValueToStore()); + newMaskedStore->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newMaskedStore); + } else { + op.emitOpError("unimplemented: do not know how to replace op."); + } +} + +template +struct MemRefRewritePatternBase : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(T op, + PatternRewriter &rewriter) const override { + Value memref = getTargetMemref(op); + if (!needFlattenning(memref) || !checkLayout(memref)) + return rewriter.notifyMatchFailure(op, + "nothing to do or unsupported layout"); + auto &&[flatMemref, offset] = getFlattenMemrefAndOffset( + rewriter, op->getLoc(), memref, op.getIndices()); + replaceOp(op, rewriter, flatMemref, offset); + return success(); + } +}; + +struct FlattenMemrefLoad : public MemRefRewritePatternBase { + using MemRefRewritePatternBase::MemRefRewritePatternBase; +}; + +struct FlattenVectorLoad : public MemRefRewritePatternBase { + using MemRefRewritePatternBase::MemRefRewritePatternBase; +}; + +struct FlattenMemrefStore : public MemRefRewritePatternBase { + using MemRefRewritePatternBase::MemRefRewritePatternBase; +}; + +struct FlattenVectorStore : public MemRefRewritePatternBase { + using MemRefRewritePatternBase::MemRefRewritePatternBase; +}; + +struct FlattenVectorMaskedLoad + : public MemRefRewritePatternBase { + using MemRefRewritePatternBase< + vector::MaskedLoadOp>::MemRefRewritePatternBase; +}; + +struct FlattenVectorMaskedStore + : public MemRefRewritePatternBase { + using MemRefRewritePatternBase< + vector::MaskedStoreOp>::MemRefRewritePatternBase; +}; + +struct FlattenVectorTransferRead + : public MemRefRewritePatternBase { + using MemRefRewritePatternBase< + vector::TransferReadOp>::MemRefRewritePatternBase; +}; + +struct FlattenVectorTransferWrite + : public MemRefRewritePatternBase { + using MemRefRewritePatternBase< + vector::TransferWriteOp>::MemRefRewritePatternBase; +}; + +struct FlattenSubview : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::SubViewOp op, + PatternRewriter &rewriter) const override { + Value memref = op.getSource(); + if (!needFlattenning(memref)) + return rewriter.notifyMatchFailure(op, "nothing to do"); + + if (!checkLayout(memref)) + return rewriter.notifyMatchFailure(op, "unsupported layout"); + + Location loc = op.getLoc(); + SmallVector subOffsets = op.getMixedOffsets(); + SmallVector subSizes = op.getMixedSizes(); + SmallVector subStrides = op.getMixedStrides(); + auto &&[base, finalOffset, strides, _, __] = + getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides); + + auto srcType = cast(memref.getType()); + auto resultType = cast(op.getType()); + unsigned subRank = static_cast(resultType.getRank()); + + llvm::SmallBitVector droppedDims = op.getDroppedDims(); + + SmallVector finalSizes; + finalSizes.reserve(subRank); + + SmallVector finalStrides; + finalStrides.reserve(subRank); + + for (auto i : llvm::seq(0u, static_cast(srcType.getRank()))) { + if (droppedDims.test(i)) + continue; + + finalSizes.push_back(subSizes[i]); + finalStrides.push_back(strides[i]); + } + + rewriter.replaceOpWithNewOp( + op, resultType, base, finalOffset, finalSizes, finalStrides); + return success(); + } +}; + +struct FlattenMemrefsPass : public 
mlir::memref::impl::FlattenMemrefsPassBase { + using Base::Base; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + + memref::populateFlattenMemrefsPatterns(patterns); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) { + patterns.insert( + patterns.getContext()); +} + +std::unique_ptr mlir::memref::createFlattenMemrefsPass() { + return std::make_unique(); +} + From abe4379bbcd525eee99bfccedc6864ca1cab14a5 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Tue, 22 Apr 2025 21:00:29 -0400 Subject: [PATCH 04/18] Fix linking issue --- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 48 ++++++++++--------- .../MemRef/Transforms/FlattenMemRefs.cpp | 2 +- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index eb23403a68813..e729e73f6ae0c 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -5083,6 +5083,31 @@ SmallVector AffineLinearizeIndexOp::getPaddedBasis() { return ret; } +namespace mlir { +namespace affine { +OpFoldResult computeProduct(Location loc, OpBuilder &builder, + ArrayRef terms) { + int64_t nDynamic = 0; + SmallVector dynamicPart; + AffineExpr result = builder.getAffineConstantExpr(1); + for (OpFoldResult term : terms) { + if (!term) + return term; + std::optional maybeConst = getConstantIntValue(term); + if (maybeConst) { + result = result * builder.getAffineConstantExpr(*maybeConst); + } else { + dynamicPart.push_back(cast(term)); + result = result * builder.getAffineSymbolExpr(nDynamic++); + } + } + if (auto constant = dyn_cast(result)) + return getAsIndexOpFoldResult(builder.getContext(), constant.getValue()); + return builder.create(loc, result, dynamicPart).getResult(); +} +} // namespace affine +} // namespace mlir + namespace { /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1, /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c, @@ -5142,27 +5167,6 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final } }; -OpFoldResult computeProduct(Location loc, OpBuilder &builder, - ArrayRef terms) { - int64_t nDynamic = 0; - SmallVector dynamicPart; - AffineExpr result = builder.getAffineConstantExpr(1); - for (OpFoldResult term : terms) { - if (!term) - return term; - std::optional maybeConst = getConstantIntValue(term); - if (maybeConst) { - result = result * builder.getAffineConstantExpr(*maybeConst); - } else { - dynamicPart.push_back(cast(term)); - result = result * builder.getAffineSymbolExpr(nDynamic++); - } - } - if (auto constant = dyn_cast(result)) - return getAsIndexOpFoldResult(builder.getContext(), constant.getValue()); - return builder.create(loc, result, dynamicPart).getResult(); -} - /// If conseceutive outputs of a delinearize_index are linearized with the same /// bounds, canonicalize away the redundant arithmetic. /// @@ -5309,7 +5313,7 @@ struct CancelLinearizeOfDelinearizePortion final // We use the slice from the linearize's basis above because of the // "bounds inferred from `disjoint`" case above. 
OpFoldResult newSize = - computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge); + affine::computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge); // Trivial case where we can just skip past the delinearize all together if (m.length == m.delinearize.getNumResults()) { diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 6685896624536..8299dc1716121 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -103,7 +103,7 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, // Compute collapsed size: (the outmost stride * outmost dimension). SmallVector ops{origStrides.front(), outmostDim}; - OpFoldResult collapsedSize = computeProduct(loc, rewriter, ops); + OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops); return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex, origStrides, origOffset, collapsedSize}; From d4da14ad207ca57cf3d76b2d6e22c948d47d3845 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Tue, 22 Apr 2025 22:47:17 -0400 Subject: [PATCH 05/18] linting --- mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h | 1 - mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h index 0649bf9c099f9..c2b8cb05be922 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h @@ -144,7 +144,6 @@ FailureOr multiBuffer(memref::AllocOp allocOp, /// ``` void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns); - void populateFlattenMemrefsPatterns(RewritePatternSet &patterns); /// Build a new memref::AllocaOp whose dynamic sizes are independent of all diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 8299dc1716121..5ec524967444a 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -14,8 +14,8 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -28,7 +28,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" - namespace mlir { namespace memref { #define GEN_PASS_DEF_FLATTENMEMREFSPASS @@ -323,7 +322,8 @@ struct FlattenSubview : public OpRewritePattern { } }; -struct FlattenMemrefsPass : public mlir::memref::impl::FlattenMemrefsPassBase { +struct FlattenMemrefsPass + : public mlir::memref::impl::FlattenMemrefsPassBase { using Base::Base; void getDependentDialects(DialectRegistry ®istry) const override { @@ -354,4 +354,3 @@ void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) { std::unique_ptr mlir::memref::createFlattenMemrefsPass() { return std::make_unique(); } - From 6aad0e7cfda39b291247053af0410a127d3cd001 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Wed, 23 Apr 2025 10:10:03 -0400 Subject: [PATCH 06/18] Fix misspelling --- 
mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 5ec524967444a..dda02be9a9c3a 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -136,7 +136,7 @@ static std::pair getFlattenMemrefAndOffset(OpBuilder &rewriter, getValueFromOpFoldResult(rewriter, loc, index)); } -static bool needFlattenning(Value val) { +static bool needFlattening(Value val) { auto type = cast(val.getType()); return type.getRank() > 1; } @@ -227,7 +227,7 @@ struct MemRefRewritePatternBase : public OpRewritePattern { LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const override { Value memref = getTargetMemref(op); - if (!needFlattenning(memref) || !checkLayout(memref)) + if (!needFlattening(memref) || !checkLayout(memref)) return rewriter.notifyMatchFailure(op, "nothing to do or unsupported layout"); auto &&[flatMemref, offset] = getFlattenMemrefAndOffset( @@ -283,7 +283,7 @@ struct FlattenSubview : public OpRewritePattern { LogicalResult matchAndRewrite(memref::SubViewOp op, PatternRewriter &rewriter) const override { Value memref = op.getSource(); - if (!needFlattenning(memref)) + if (!needFlattening(memref)) return rewriter.notifyMatchFailure(op, "nothing to do"); if (!checkLayout(memref)) From ce849953b7ae10c42bc1d7b02f4664eed87d4bef Mon Sep 17 00:00:00 2001 From: Alan Li Date: Fri, 25 Apr 2025 21:49:41 -0400 Subject: [PATCH 07/18] amend comments --- .../mlir/Dialect/MemRef/Transforms/Passes.td | 2 - .../MemRef/Transforms/FlattenMemRefs.cpp | 203 +++++++----------- 2 files changed, 79 insertions(+), 126 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td index a2a9047bda808..a8d135caa74f0 100644 --- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td @@ -250,8 +250,6 @@ def FlattenMemrefsPass : Pass<"flatten-memref"> { let description = [{ }]; - - let constructor = "mlir::memref::createFlattenMemrefsPass()"; let dependentDialects = [ "affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect" ]; diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index dda02be9a9c3a..8336d9b5715e6 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -27,6 +27,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/TypeSwitch.h" namespace mlir { namespace memref { @@ -148,135 +149,90 @@ static bool checkLayout(Value val) { } namespace { -template -static Value getTargetMemref(T op) { - if constexpr (std::is_same_v) { - return op.getMemref(); - } else if constexpr (std::is_same_v) { - return op.getBase(); - } else if constexpr (std::is_same_v) { - return op.getMemref(); - } else if constexpr (std::is_same_v) { - return op.getBase(); - } else if constexpr (std::is_same_v) { - return op.getBase(); - } else if constexpr (std::is_same_v) { - return op.getBase(); - } else if constexpr (std::is_same_v) { - return op.getSource(); - } else if constexpr (std::is_same_v) { - return op.getSource(); - } - return {}; +static Value getTargetMemref(Operation *op) { + return 
llvm::TypeSwitch(op) + .template Case( + [](auto op) { return op.getMemref(); }) + .template Case( + [](auto op) { return op.getBase(); }) + .template Case( + [](auto op) { return op.getSource(); }) + .Default([](auto) { return Value{}; }); } -template -static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, - Value offset) { - if constexpr (std::is_same_v) { - auto newLoad = rewriter.create( - op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset}); - newLoad->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newLoad.getResult()); - } else if constexpr (std::is_same_v) { - auto newLoad = rewriter.create( - op->getLoc(), op->getResultTypes(), flatMemref, ValueRange{offset}); - newLoad->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newLoad.getResult()); - } else if constexpr (std::is_same_v) { - auto newStore = rewriter.create( - op->getLoc(), op->getOperands().front(), flatMemref, - ValueRange{offset}); - newStore->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newStore); - } else if constexpr (std::is_same_v) { - auto newStore = rewriter.create( - op->getLoc(), op->getOperands().front(), flatMemref, - ValueRange{offset}); - newStore->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newStore); - } else if constexpr (std::is_same_v) { - auto newTransferRead = rewriter.create( - op->getLoc(), op.getType(), flatMemref, ValueRange{offset}, - op.getPadding()); - rewriter.replaceOp(op, newTransferRead.getResult()); - } else if constexpr (std::is_same_v) { - auto newTransferWrite = rewriter.create( - op->getLoc(), op.getVector(), flatMemref, ValueRange{offset}); - rewriter.replaceOp(op, newTransferWrite); - } else if constexpr (std::is_same_v) { - auto newMaskedLoad = rewriter.create( - op->getLoc(), op.getType(), flatMemref, ValueRange{offset}, - op.getMask(), op.getPassThru()); - newMaskedLoad->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newMaskedLoad.getResult()); - } else if constexpr (std::is_same_v) { - auto newMaskedStore = rewriter.create( - op->getLoc(), flatMemref, ValueRange{offset}, op.getMask(), - op.getValueToStore()); - newMaskedStore->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newMaskedStore); - } else { - op.emitOpError("unimplemented: do not know how to replace op."); - } +static void replaceOp(Operation *op, PatternRewriter &rewriter, + Value flatMemref, Value offset) { + auto loc = op->getLoc(); + llvm::TypeSwitch(op) + .Case([&](auto op) { + auto newLoad = rewriter.create( + loc, op->getResultTypes(), flatMemref, ValueRange{offset}); + newLoad->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newLoad.getResult()); + }) + .Case([&](auto op) { + auto newStore = rewriter.create( + loc, op->getOperands().front(), flatMemref, ValueRange{offset}); + newStore->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newStore); + }) + .Case([&](auto op) { + auto newLoad = rewriter.create( + loc, op->getResultTypes(), flatMemref, ValueRange{offset}); + newLoad->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newLoad.getResult()); + }) + .Case([&](auto op) { + auto newStore = rewriter.create( + loc, op->getOperands().front(), flatMemref, ValueRange{offset}); + newStore->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newStore); + }) + .Case([&](auto op) { + auto newMaskedLoad = rewriter.create( + loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(), + op.getPassThru()); + newMaskedLoad->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newMaskedLoad.getResult()); + }) + .Case([&](auto op) { + auto newMaskedStore = 
rewriter.create( + loc, flatMemref, ValueRange{offset}, op.getMask(), + op.getValueToStore()); + newMaskedStore->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newMaskedStore); + }) + .Case([&](auto op) { + auto newTransferRead = rewriter.create( + loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding()); + rewriter.replaceOp(op, newTransferRead.getResult()); + }) + .Case([&](auto op) { + auto newTransferWrite = rewriter.create( + loc, op.getVector(), flatMemref, ValueRange{offset}); + rewriter.replaceOp(op, newTransferWrite); + }) + .Default([&](auto op) { + op->emitOpError("unimplemented: do not know how to replace op."); + }); } template -struct MemRefRewritePatternBase : public OpRewritePattern { +struct MemRefRewritePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const override { - Value memref = getTargetMemref(op); + Value memref = getTargetMemref(op); if (!needFlattening(memref) || !checkLayout(memref)) - return rewriter.notifyMatchFailure(op, - "nothing to do or unsupported layout"); + return failure(); auto &&[flatMemref, offset] = getFlattenMemrefAndOffset( rewriter, op->getLoc(), memref, op.getIndices()); - replaceOp(op, rewriter, flatMemref, offset); + replaceOp(op, rewriter, flatMemref, offset); return success(); } }; -struct FlattenMemrefLoad : public MemRefRewritePatternBase { - using MemRefRewritePatternBase::MemRefRewritePatternBase; -}; - -struct FlattenVectorLoad : public MemRefRewritePatternBase { - using MemRefRewritePatternBase::MemRefRewritePatternBase; -}; - -struct FlattenMemrefStore : public MemRefRewritePatternBase { - using MemRefRewritePatternBase::MemRefRewritePatternBase; -}; - -struct FlattenVectorStore : public MemRefRewritePatternBase { - using MemRefRewritePatternBase::MemRefRewritePatternBase; -}; - -struct FlattenVectorMaskedLoad - : public MemRefRewritePatternBase { - using MemRefRewritePatternBase< - vector::MaskedLoadOp>::MemRefRewritePatternBase; -}; - -struct FlattenVectorMaskedStore - : public MemRefRewritePatternBase { - using MemRefRewritePatternBase< - vector::MaskedStoreOp>::MemRefRewritePatternBase; -}; - -struct FlattenVectorTransferRead - : public MemRefRewritePatternBase { - using MemRefRewritePatternBase< - vector::TransferReadOp>::MemRefRewritePatternBase; -}; - -struct FlattenVectorTransferWrite - : public MemRefRewritePatternBase { - using MemRefRewritePatternBase< - vector::TransferWriteOp>::MemRefRewritePatternBase; -}; - struct FlattenSubview : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -284,7 +240,7 @@ struct FlattenSubview : public OpRewritePattern { PatternRewriter &rewriter) const override { Value memref = op.getSource(); if (!needFlattening(memref)) - return rewriter.notifyMatchFailure(op, "nothing to do"); + return rewriter.notifyMatchFailure(op, "already flattened"); if (!checkLayout(memref)) return rewriter.notifyMatchFailure(op, "unsupported layout"); @@ -344,13 +300,12 @@ struct FlattenMemrefsPass } // namespace void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) { - patterns.insert( - patterns.getContext()); -} - -std::unique_ptr mlir::memref::createFlattenMemrefsPass() { - return std::make_unique(); + patterns + .insert, + MemRefRewritePattern, + MemRefRewritePattern, + MemRefRewritePattern, + MemRefRewritePattern, + MemRefRewritePattern, FlattenSubview>( + patterns.getContext()); } From 7b097f47abcfaec5e12eba1096c0f3913aba509b Mon Sep 17 00:00:00 2001 From: Alan Li 
Date: Mon, 28 Apr 2025 09:31:23 -0400 Subject: [PATCH 08/18] Not working yet. --- .../MemRef/Transforms/FlattenMemRefs.cpp | 68 ++++++++++++++++++- 1 file changed, 66 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 8336d9b5715e6..65d92cc8e8c52 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -46,6 +46,69 @@ static void setInsertionPointToStart(OpBuilder &builder, Value val) { } } +OpFoldResult computeMemRefSpan(Value memref, OpBuilder &builder) { + Location loc = memref.getLoc(); + MemRefType type = cast(memref.getType()); + ArrayRef shape = type.getShape(); + + // Check for empty memref + if (type.hasStaticShape() && + llvm::any_of(shape, [](int64_t dim) { return dim == 0; })) { + return builder.getIndexAttr(0); + } + + // Get strides of the memref + SmallVector strides; + int64_t offset; + if (failed(type.getStridesAndOffset(strides, offset))) { + // Cannot extract strides, return a dynamic value + return Value(); + } + + // Static case: compute at compile time if possible + if (type.hasStaticShape()) { + int64_t span = 0; + for (unsigned i = 0; i < type.getRank(); ++i) { + span += (shape[i] - 1) * strides[i]; + } + return builder.getIndexAttr(span); + } + + // Dynamic case: emit IR to compute at runtime + Value result = builder.create(loc, 0); + + for (unsigned i = 0; i < type.getRank(); ++i) { + // Get dimension size + Value dimSize; + if (shape[i] == ShapedType::kDynamic) { + dimSize = builder.create(loc, memref, i); + } else { + dimSize = builder.create(loc, shape[i]); + } + + // Compute (dim - 1) + Value one = builder.create(loc, 1); + Value dimMinusOne = builder.create(loc, dimSize, one); + + // Get stride + Value stride; + if (strides[i] == ShapedType::kDynamicStrideOrOffset) { + // For dynamic strides, need to extract from memref descriptor + // This would require runtime support, possibly using extractStride + // As a placeholder, return a dynamic value + return Value(); + } else { + stride = builder.create(loc, strides[i]); + } + + // Add (dim - 1) * stride to result + Value term = builder.create(loc, dimMinusOne, stride); + result = builder.create(loc, result, term); + } + + return result; +} + static std::tuple, OpFoldResult, OpFoldResult> getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, @@ -102,8 +165,9 @@ getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values); // Compute collapsed size: (the outmost stride * outmost dimension). 
- SmallVector ops{origStrides.front(), outmostDim}; - OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops); + //SmallVector ops{origStrides.front(), outmostDim}; + //OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops); + OpFoldResult collapsedSize = computeMemRefSpan(source, rewriter); return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex, origStrides, origOffset, collapsedSize}; From c9db9bffd282bb0dbc4313eb065e15e92d10f24b Mon Sep 17 00:00:00 2001 From: Alan Li Date: Thu, 1 May 2025 12:33:32 -0400 Subject: [PATCH 09/18] Some updates --- .../MemRef/Transforms/FlattenMemRefs.cpp | 207 +++++------------- mlir/test/Dialect/MemRef/flatten_memref.mlir | 26 ++- 2 files changed, 69 insertions(+), 164 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 65d92cc8e8c52..ba4a00f9c0ed6 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -38,141 +39,6 @@ namespace memref { using namespace mlir; -static void setInsertionPointToStart(OpBuilder &builder, Value val) { - if (auto *parentOp = val.getDefiningOp()) { - builder.setInsertionPointAfter(parentOp); - } else { - builder.setInsertionPointToStart(val.getParentBlock()); - } -} - -OpFoldResult computeMemRefSpan(Value memref, OpBuilder &builder) { - Location loc = memref.getLoc(); - MemRefType type = cast(memref.getType()); - ArrayRef shape = type.getShape(); - - // Check for empty memref - if (type.hasStaticShape() && - llvm::any_of(shape, [](int64_t dim) { return dim == 0; })) { - return builder.getIndexAttr(0); - } - - // Get strides of the memref - SmallVector strides; - int64_t offset; - if (failed(type.getStridesAndOffset(strides, offset))) { - // Cannot extract strides, return a dynamic value - return Value(); - } - - // Static case: compute at compile time if possible - if (type.hasStaticShape()) { - int64_t span = 0; - for (unsigned i = 0; i < type.getRank(); ++i) { - span += (shape[i] - 1) * strides[i]; - } - return builder.getIndexAttr(span); - } - - // Dynamic case: emit IR to compute at runtime - Value result = builder.create(loc, 0); - - for (unsigned i = 0; i < type.getRank(); ++i) { - // Get dimension size - Value dimSize; - if (shape[i] == ShapedType::kDynamic) { - dimSize = builder.create(loc, memref, i); - } else { - dimSize = builder.create(loc, shape[i]); - } - - // Compute (dim - 1) - Value one = builder.create(loc, 1); - Value dimMinusOne = builder.create(loc, dimSize, one); - - // Get stride - Value stride; - if (strides[i] == ShapedType::kDynamicStrideOrOffset) { - // For dynamic strides, need to extract from memref descriptor - // This would require runtime support, possibly using extractStride - // As a placeholder, return a dynamic value - return Value(); - } else { - stride = builder.create(loc, strides[i]); - } - - // Add (dim - 1) * stride to result - Value term = builder.create(loc, dimMinusOne, stride); - result = builder.create(loc, result, term); - } - - return result; -} - -static std::tuple, OpFoldResult, - OpFoldResult> 
-getFlatOffsetAndStrides(OpBuilder &rewriter, Location loc, Value source, - ArrayRef subOffsets, - ArrayRef subStrides = std::nullopt) { - auto sourceType = cast(source.getType()); - auto sourceRank = static_cast(sourceType.getRank()); - - memref::ExtractStridedMetadataOp newExtractStridedMetadata; - { - OpBuilder::InsertionGuard g(rewriter); - setInsertionPointToStart(rewriter, source); - newExtractStridedMetadata = - rewriter.create(loc, source); - } - - auto &&[sourceStrides, sourceOffset] = sourceType.getStridesAndOffset(); - - auto getDim = [&](int64_t dim, Value dimVal) -> OpFoldResult { - return ShapedType::isDynamic(dim) ? getAsOpFoldResult(dimVal) - : rewriter.getIndexAttr(dim); - }; - - OpFoldResult origOffset = - getDim(sourceOffset, newExtractStridedMetadata.getOffset()); - ValueRange sourceStridesVals = newExtractStridedMetadata.getStrides(); - OpFoldResult outmostDim = - getDim(sourceType.getShape().front(), - newExtractStridedMetadata.getSizes().front()); - - SmallVector origStrides; - origStrides.reserve(sourceRank); - - SmallVector strides; - strides.reserve(sourceRank); - - AffineExpr s0 = rewriter.getAffineSymbolExpr(0); - AffineExpr s1 = rewriter.getAffineSymbolExpr(1); - for (auto i : llvm::seq(0u, sourceRank)) { - OpFoldResult origStride = getDim(sourceStrides[i], sourceStridesVals[i]); - - if (!subStrides.empty()) { - strides.push_back(affine::makeComposedFoldedAffineApply( - rewriter, loc, s0 * s1, {subStrides[i], origStride})); - } - - origStrides.emplace_back(origStride); - } - - // Compute linearized index: - auto &&[expr, values] = - computeLinearIndex(rewriter.getIndexAttr(0), origStrides, subOffsets); - OpFoldResult linearizedIndex = - affine::makeComposedFoldedAffineApply(rewriter, loc, expr, values); - - // Compute collapsed size: (the outmost stride * outmost dimension). 
- //SmallVector ops{origStrides.front(), outmostDim}; - //OpFoldResult collapsedSize = affine::computeProduct(loc, rewriter, ops); - OpFoldResult collapsedSize = computeMemRefSpan(source, rewriter); - - return {newExtractStridedMetadata.getBaseBuffer(), linearizedIndex, - origStrides, origOffset, collapsedSize}; -} - static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in) { if (Attribute offsetAttr = dyn_cast(in)) { @@ -188,17 +54,36 @@ static std::pair getFlattenMemrefAndOffset(OpBuilder &rewriter, Location loc, Value source, ValueRange indices) { - auto &&[base, index, strides, offset, collapsedShape] = - getFlatOffsetAndStrides(rewriter, loc, source, - getAsOpFoldResult(indices)); + int64_t sourceOffset; + SmallVector sourceStrides; + auto sourceType = cast(source.getType()); + if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) { + assert(false); + } + + memref::ExtractStridedMetadataOp stridedMetadata = + rewriter.create(loc, source); + + auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth(); + OpFoldResult linearizedIndices; + memref::LinearizedMemRefInfo linearizedInfo; + std::tie(linearizedInfo, linearizedIndices) = + memref::getLinearizedMemRefOffsetAndSize( + rewriter, loc, typeBit, typeBit, + stridedMetadata.getConstifiedMixedOffset(), + stridedMetadata.getConstifiedMixedSizes(), + stridedMetadata.getConstifiedMixedStrides(), + getAsOpFoldResult(indices)); return std::make_pair( rewriter.create( loc, source, - /* offset = */ offset, - /* shapes = */ ArrayRef{collapsedShape}, - /* strides = */ ArrayRef{strides.back()}), - getValueFromOpFoldResult(rewriter, loc, index)); + /* offset = */ linearizedInfo.linearizedOffset, + /* shapes = */ ArrayRef{linearizedInfo.linearizedSize}, + /* strides = */ + ArrayRef{ + stridedMetadata.getConstifiedMixedStrides().back()}), + getValueFromOpFoldResult(rewriter, loc, linearizedIndices)); } static bool needFlattening(Value val) { @@ -313,8 +198,23 @@ struct FlattenSubview : public OpRewritePattern { SmallVector subOffsets = op.getMixedOffsets(); SmallVector subSizes = op.getMixedSizes(); SmallVector subStrides = op.getMixedStrides(); - auto &&[base, finalOffset, strides, _, __] = - getFlatOffsetAndStrides(rewriter, loc, memref, subOffsets, subStrides); + + // base, finalOffset, strides + memref::ExtractStridedMetadataOp stridedMetadata = + rewriter.create(loc, memref); + + auto sourceType = cast(memref.getType()); + auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth(); + OpFoldResult linearizedIndices; + memref::LinearizedMemRefInfo linearizedInfo; + std::tie(linearizedInfo, linearizedIndices) = + memref::getLinearizedMemRefOffsetAndSize( + rewriter, loc, typeBit, typeBit, + stridedMetadata.getConstifiedMixedOffset(), + stridedMetadata.getConstifiedMixedSizes(), + stridedMetadata.getConstifiedMixedStrides(), op.getMixedOffsets()); + auto finalOffset = linearizedInfo.linearizedOffset; + auto strides = stridedMetadata.getConstifiedMixedStrides(); auto srcType = cast(memref.getType()); auto resultType = cast(op.getType()); @@ -337,7 +237,7 @@ struct FlattenSubview : public OpRewritePattern { } rewriter.replaceOpWithNewOp( - op, resultType, base, finalOffset, finalSizes, finalStrides); + op, resultType, memref, finalOffset, finalSizes, finalStrides); return success(); } }; @@ -364,12 +264,13 @@ struct FlattenMemrefsPass } // namespace void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) { - patterns - .insert, - MemRefRewritePattern, - 
MemRefRewritePattern, - MemRefRewritePattern, - MemRefRewritePattern, - MemRefRewritePattern, FlattenSubview>( - patterns.getContext()); + patterns.insert, + MemRefRewritePattern, + MemRefRewritePattern, + MemRefRewritePattern, + MemRefRewritePattern, + MemRefRewritePattern, + MemRefRewritePattern, + MemRefRewritePattern, FlattenSubview>( + patterns.getContext()); } diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir index 6c9b09985acf7..f65e12ad6916d 100644 --- a/mlir/test/Dialect/MemRef/flatten_memref.mlir +++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir @@ -6,7 +6,7 @@ func.func @load_scalar_from_memref(%input: memref<4x8xf32, strided<[8, 1], offse %value = memref.load %input[%c1, %c2] : memref<4x8xf32, strided<[8, 1], offset: 100>> return %value : f32 } -// CHECK: func @load_scalar_from_memref +// CHECK-LABEL: func @load_scalar_from_memref // CHECK: %[[C10:.*]] = arith.constant 10 : index // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1] // CHECK-SAME: memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>> @@ -18,6 +18,7 @@ func.func @load_scalar_from_memref_static_dim_2(%input: memref<4x8xf32, strided< %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>> return %value : f32 } + // CHECK: [[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 12)> // CHECK: func @load_scalar_from_memref_static_dim_2 // CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 12], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) @@ -39,7 +40,7 @@ func.func @load_scalar_from_memref_dynamic_dim(%input: memref>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]] // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1] -// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0] +// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[SIZES]]#0, %[[SIZES]]#1] // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1] // CHECK: memref.load %[[REINT]][%[[IDX]]] @@ -49,7 +50,9 @@ func.func @load_scalar_from_memref_subview(%input: memref<4x8xf32>, %row: index, %subview = memref.subview %input[%col, %row] [1, 1] [1, 1] : memref<4x8xf32> to memref<1x1xf32, strided<[8, 1], offset: ?>> return %subview : memref<1x1xf32, strided<[8, 1], offset: ?>> } -// CHECK: func @load_scalar_from_memref_subview +// CHECK-LABEL: func @load_scalar_from_memref_subview +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) +// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [1, 1], strides: [8, 1] // ----- @@ -76,7 +79,7 @@ func.func @store_scalar_from_memref_dynamic_dim(%input: memref>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32) // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]] // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1] -// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0] +// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[SIZES]]#0, %[[SIZES]]#1] // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: 
[%[[STRIDES]]#1] // CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]] @@ -88,7 +91,7 @@ func.func @load_vector_from_memref(%input: memref<4x8xf32>) -> vector<8xf32> { %value = vector.load %input[%c3, %c6] : memref<4x8xf32>, vector<8xf32> return %value : vector<8xf32> } -// CHECK: func @load_vector_from_memref +// CHECK-LABEL: func @load_vector_from_memref // CHECK: %[[C30:.*]] = arith.constant 30 // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1] // CHECK-NEXT: vector.load %[[REINT]][%[[C30]]] @@ -101,7 +104,7 @@ func.func @load_vector_from_memref_odd(%input: memref<3x7xi2>) -> vector<3xi2> { %value = vector.load %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2> return %value : vector<3xi2> } -// CHECK: func @load_vector_from_memref_odd +// CHECK-LABEL: func @load_vector_from_memref_odd // CHECK: %[[C10:.*]] = arith.constant 10 : index // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast // CHECK-NEXT: vector.load %[[REINT]][%[[C10]]] @@ -126,10 +129,11 @@ func.func @store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi vector.store %value, %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2> return } -// CHECK: func @store_vector_to_memref_odd +// CHECK-LABEL: func @store_vector_to_memref_odd +// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>) // CHECK: %[[C10:.*]] = arith.constant 10 : index // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast -// CHECK-NEXT: vector.store %arg1, %[[REINT]][%[[C10]]] : memref<21xi2, strided<[1]> +// CHECK-NEXT: vector.store %[[ARG1]], %[[REINT]][%[[C10]]] : memref<21xi2, strided<[1]> // ----- @@ -152,7 +156,7 @@ func.func @mask_store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vecto vector.maskedstore %input[%c1, %c3], %mask, %value : memref<3x7xi2>, vector<3xi1>, vector<3xi2> return } -// CHECK: func @mask_store_vector_to_memref_odd +// CHECK-LABEL: func @mask_store_vector_to_memref_odd // CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: vector<3xi1>) // CHECK: %[[C10:.*]] = arith.constant 10 : index // CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast @@ -178,7 +182,7 @@ func.func @mask_load_vector_from_memref_odd(%input: memref<3x7xi2>, %mask: vecto %result = vector.maskedload %input[%c1, %c3], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2> return %result : vector<3xi2> } -// CHECK: func @mask_load_vector_from_memref_odd +// CHECK-LABEL: func @mask_load_vector_from_memref_odd // CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[MASK:.*]]: vector<3xi1>, %[[PASSTHRU:.*]]: vector<3xi2>) // CHECK: %[[C10:.*]] = arith.constant 10 : index // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [21], strides: [1] @@ -204,7 +208,7 @@ func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %r %0 = vector.transfer_read %input[%col, %row], %c0 : memref<4x8xi2>, vector<8xi2> return %0 : vector<8xi2> } -// CHECK: func @transfer_read_memref +// CHECK-LABEL: func @transfer_read_memref // CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) // CHECK: %[[C0:.*]] = arith.constant 0 : i2 // CHECK: %[[IDX:.*]] = affine.apply #map()[%[[ARG3]], %[[ARG2]]] From 01d131e936f03844a3f919a1a70cde916273be36 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Mon, 5 May 2025 23:38:08 -0400 Subject: [PATCH 10/18] Remove subview --- .../MemRef/Transforms/FlattenMemRefs.cpp | 62 +------------------ 
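[Illustrative aside, not part of the patch series: taken together, these patterns rewrite a multi-dimensional memref access into a memref.reinterpret_cast to a 1-D memref plus a linearized index. A minimal sketch, assuming the 4x8, strided<[8, 1], offset: 100> layout exercised by the tests:

    %v = memref.load %m[%i, %j] : memref<4x8xf32, strided<[8, 1], offset: 100>>

is expected to become, roughly:

    %idx  = affine.apply affine_map<()[s0, s1] -> (s0 * 8 + s1)>()[%i, %j]
    %flat = memref.reinterpret_cast %m to offset: [100], sizes: [32], strides: [1]
              : memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
    %v    = memref.load %flat[%idx] : memref<32xf32, strided<[1], offset: 100>>

When both indices are constants the linearized index folds to a constant, e.g. [1, 2] -> 1 * 8 + 2 = 10. The names %m, %i, %j, %idx, and %flat are illustrative, not taken from the patch.]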
mlir/test/Dialect/MemRef/flatten_memref.mlir | 10 --- 2 files changed, 1 insertion(+), 71 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index ba4a00f9c0ed6..32fe64bb616bc 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -182,66 +182,6 @@ struct MemRefRewritePattern : public OpRewritePattern { } }; -struct FlattenSubview : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(memref::SubViewOp op, - PatternRewriter &rewriter) const override { - Value memref = op.getSource(); - if (!needFlattening(memref)) - return rewriter.notifyMatchFailure(op, "already flattened"); - - if (!checkLayout(memref)) - return rewriter.notifyMatchFailure(op, "unsupported layout"); - - Location loc = op.getLoc(); - SmallVector subOffsets = op.getMixedOffsets(); - SmallVector subSizes = op.getMixedSizes(); - SmallVector subStrides = op.getMixedStrides(); - - // base, finalOffset, strides - memref::ExtractStridedMetadataOp stridedMetadata = - rewriter.create(loc, memref); - - auto sourceType = cast(memref.getType()); - auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth(); - OpFoldResult linearizedIndices; - memref::LinearizedMemRefInfo linearizedInfo; - std::tie(linearizedInfo, linearizedIndices) = - memref::getLinearizedMemRefOffsetAndSize( - rewriter, loc, typeBit, typeBit, - stridedMetadata.getConstifiedMixedOffset(), - stridedMetadata.getConstifiedMixedSizes(), - stridedMetadata.getConstifiedMixedStrides(), op.getMixedOffsets()); - auto finalOffset = linearizedInfo.linearizedOffset; - auto strides = stridedMetadata.getConstifiedMixedStrides(); - - auto srcType = cast(memref.getType()); - auto resultType = cast(op.getType()); - unsigned subRank = static_cast(resultType.getRank()); - - llvm::SmallBitVector droppedDims = op.getDroppedDims(); - - SmallVector finalSizes; - finalSizes.reserve(subRank); - - SmallVector finalStrides; - finalStrides.reserve(subRank); - - for (auto i : llvm::seq(0u, static_cast(srcType.getRank()))) { - if (droppedDims.test(i)) - continue; - - finalSizes.push_back(subSizes[i]); - finalStrides.push_back(strides[i]); - } - - rewriter.replaceOpWithNewOp( - op, resultType, memref, finalOffset, finalSizes, finalStrides); - return success(); - } -}; - struct FlattenMemrefsPass : public mlir::memref::impl::FlattenMemrefsPassBase { using Base::Base; @@ -271,6 +211,6 @@ void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) { MemRefRewritePattern, MemRefRewritePattern, MemRefRewritePattern, - MemRefRewritePattern, FlattenSubview>( + MemRefRewritePattern>( patterns.getContext()); } diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir index f65e12ad6916d..a182ae58683dd 100644 --- a/mlir/test/Dialect/MemRef/flatten_memref.mlir +++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir @@ -46,16 +46,6 @@ func.func @load_scalar_from_memref_dynamic_dim(%input: memref, %row: index, %col: index) -> memref<1x1xf32, strided<[8, 1], offset: ?>> { - %subview = memref.subview %input[%col, %row] [1, 1] [1, 1] : memref<4x8xf32> to memref<1x1xf32, strided<[8, 1], offset: ?>> - return %subview : memref<1x1xf32, strided<[8, 1], offset: ?>> -} -// CHECK-LABEL: func @load_scalar_from_memref_subview -// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -// CHECK-NEXT: %[[REINT:.*]] = 
memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [1, 1], strides: [8, 1] - -// ----- - func.func @store_scalar_from_memref_static_dim(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index, %value: f32) { memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>> return From 1a8aeb48528990e579bda896d0ccc01606304bd7 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Tue, 6 May 2025 13:13:17 -0400 Subject: [PATCH 11/18] Adding memref alloc/alloca. --- .../MemRef/Transforms/FlattenMemRefs.cpp | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 32fe64bb616bc..e4d29a9f9e0f4 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -28,8 +28,11 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" +#include + namespace mlir { namespace memref { #define GEN_PASS_DEF_FLATTENMEMREFSPASS @@ -182,6 +185,29 @@ struct MemRefRewritePattern : public OpRewritePattern { } }; +// For any memref op that emits a new memref. +template +struct MemRefSourceRewritePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(T op, + PatternRewriter &rewriter) const override { + if (!needFlattening(op) || !checkLayout(op)) + return failure(); + MemRefType sourceType = cast(op.getType()); + + // Get flattened size, no strides. + auto dimSizes = llvm::to_vector(sourceType.getShape()); + auto flattenedSize = std::accumulate( + dimSizes.begin(), dimSizes.end(), 1, std::multiplies()); + auto flatMemrefType = MemRefType::get( + /*shape=*/{flattenedSize}, sourceType.getElementType(), + /*layout=*/nullptr, sourceType.getMemorySpace()); + rewriter.replaceOpWithNewOp( + op, flatMemrefType); + return success(); + } +}; + struct FlattenMemrefsPass : public mlir::memref::impl::FlattenMemrefsPassBase { using Base::Base; @@ -206,6 +232,8 @@ struct FlattenMemrefsPass void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) { patterns.insert, MemRefRewritePattern, + MemRefSourceRewritePattern, + MemRefSourceRewritePattern, MemRefRewritePattern, MemRefRewritePattern, MemRefRewritePattern, From de5ae5d315e1bf96fe45f1403ef77e6cbeb9414d Mon Sep 17 00:00:00 2001 From: Alan Li Date: Tue, 6 May 2025 15:07:23 -0400 Subject: [PATCH 12/18] support alloc/alloca --- .../MemRef/Transforms/FlattenMemRefs.cpp | 120 ++++++++++++------ mlir/test/Dialect/MemRef/flatten_memref.mlir | 21 +++ 2 files changed, 102 insertions(+), 39 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index e4d29a9f9e0f4..29b868ecde46d 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -103,8 +103,8 @@ static bool checkLayout(Value val) { namespace { static Value getTargetMemref(Operation *op) { return llvm::TypeSwitch(op) - .template Case( - [](auto op) { return op.getMemref(); }) + .template Case([](auto op) { return op.getMemref(); }) .template Case( [](auto op) { return op.getBase(); }) @@ -113,54 +113,78 @@ static Value getTargetMemref(Operation *op) { .Default([](auto) { return Value{}; }); } -static void replaceOp(Operation *op, PatternRewriter 
&rewriter, - Value flatMemref, Value offset) { +template +static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, + Value offset) { auto loc = op->getLoc(); - llvm::TypeSwitch(op) - .Case([&](auto op) { + llvm::TypeSwitch(op.getOperation()) + .template Case([&](auto oper) { + // grab flatMemref's type, and replace op with a new one. Then + // reinterpret it back. + auto flatMemrefType = cast(flatMemref.getType()); + auto loc = oper.getLoc(); + auto newAlloc = rewriter.create( + loc, flatMemrefType, oper.getAlignmentAttr()); + auto originalType = cast(oper.getType()); + + auto rank = originalType.getRank(); + SmallVector sizes, strides; + sizes.resize(rank); + strides.resize(rank); + int64_t staticStride = 1; + for (int i = rank - 1; i >= 0; --i) { + sizes[i] = rewriter.getIndexAttr(originalType.getShape()[i]); + strides[i] = rewriter.getIndexAttr(staticStride); + staticStride *= originalType.getShape()[i]; + } + rewriter.replaceOpWithNewOp( + op, originalType, newAlloc, + /*offset=*/rewriter.getIndexAttr(0), sizes, strides); + }) + .template Case([&](auto op) { auto newLoad = rewriter.create( loc, op->getResultTypes(), flatMemref, ValueRange{offset}); newLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newLoad.getResult()); }) - .Case([&](auto op) { + .template Case([&](auto op) { auto newStore = rewriter.create( loc, op->getOperands().front(), flatMemref, ValueRange{offset}); newStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newStore); }) - .Case([&](auto op) { + .template Case([&](auto op) { auto newLoad = rewriter.create( loc, op->getResultTypes(), flatMemref, ValueRange{offset}); newLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newLoad.getResult()); }) - .Case([&](auto op) { + .template Case([&](auto op) { auto newStore = rewriter.create( loc, op->getOperands().front(), flatMemref, ValueRange{offset}); newStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newStore); }) - .Case([&](auto op) { + .template Case([&](auto op) { auto newMaskedLoad = rewriter.create( loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(), op.getPassThru()); newMaskedLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newMaskedLoad.getResult()); }) - .Case([&](auto op) { + .template Case([&](auto op) { auto newMaskedStore = rewriter.create( loc, flatMemref, ValueRange{offset}, op.getMask(), op.getValueToStore()); newMaskedStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newMaskedStore); }) - .Case([&](auto op) { + .template Case([&](auto op) { auto newTransferRead = rewriter.create( loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding()); rewriter.replaceOp(op, newTransferRead.getResult()); }) - .Case([&](auto op) { + .template Case([&](auto op) { auto newTransferWrite = rewriter.create( loc, op.getVector(), flatMemref, ValueRange{offset}); rewriter.replaceOp(op, newTransferWrite); @@ -170,6 +194,16 @@ static void replaceOp(Operation *op, PatternRewriter &rewriter, }); } +template +static ValueRange getIndices(T op) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ValueRange{}; + } else { + return op.getIndices(); + } +} + template struct MemRefRewritePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -179,34 +213,42 @@ struct MemRefRewritePattern : public OpRewritePattern { if (!needFlattening(memref) || !checkLayout(memref)) return failure(); auto &&[flatMemref, offset] = getFlattenMemrefAndOffset( - rewriter, op->getLoc(), memref, op.getIndices()); - replaceOp(op, rewriter, flatMemref, 
offset); + rewriter, op->getLoc(), memref, getIndices(op)); + replaceOp(op, rewriter, flatMemref, offset); return success(); } }; -// For any memref op that emits a new memref. -template -struct MemRefSourceRewritePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(T op, - PatternRewriter &rewriter) const override { - if (!needFlattening(op) || !checkLayout(op)) - return failure(); - MemRefType sourceType = cast(op.getType()); - - // Get flattened size, no strides. - auto dimSizes = llvm::to_vector(sourceType.getShape()); - auto flattenedSize = std::accumulate( - dimSizes.begin(), dimSizes.end(), 1, std::multiplies()); - auto flatMemrefType = MemRefType::get( - /*shape=*/{flattenedSize}, sourceType.getElementType(), - /*layout=*/nullptr, sourceType.getMemorySpace()); - rewriter.replaceOpWithNewOp( - op, flatMemrefType); - return success(); - } -}; +// // For any memref op that emits a new memref. +// template +// struct MemRefSourceRewritePattern : public OpRewritePattern { +// using OpRewritePattern::OpRewritePattern; +// LogicalResult matchAndRewrite(T op, +// PatternRewriter &rewriter) const override { +// if (!needFlattening(op) || !checkLayout(op)) +// return failure(); +// MemRefType sourceType = cast(op.getType()); + +// auto mixedSizes = op.getMixedSizes(); + +// // Get flattened size, no strides. +// auto flattenedSize = std::accumulate( +// mixedSizes.begin(), mixedSizes.end(), 1, +// [](int64_t a, OpFoldResult b) { +// return a * getConstantIntValue(b).value_or(1); +// }); + +// auto flatMemrefType = MemRefType::get( +// /*shape=*/{flattenedSize}, sourceType.getElementType(), +// /*layout=*/nullptr, sourceType.getMemorySpace()); +// auto newSource = rewriter.create( +// op.getLoc(), flatMemrefType, op.getDynamicSizes()); +// auto reinterpretCast = rewriter.create( +// op.getLoc(), sourceType, newSource, op.getOffset(), +// op.getMixedSizes(), op.getStrides()); +// return success(); +// } +// }; struct FlattenMemrefsPass : public mlir::memref::impl::FlattenMemrefsPassBase { @@ -232,8 +274,8 @@ struct FlattenMemrefsPass void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) { patterns.insert, MemRefRewritePattern, - MemRefSourceRewritePattern, - MemRefSourceRewritePattern, + MemRefRewritePattern, + MemRefRewritePattern, MemRefRewritePattern, MemRefRewritePattern, MemRefRewritePattern, diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir index a182ae58683dd..1cfeaabfb1115 100644 --- a/mlir/test/Dialect/MemRef/flatten_memref.mlir +++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir @@ -217,3 +217,24 @@ func.func @transfer_write_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, % // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]] // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] // CHECK: vector.transfer_write %[[ARG1]], %[[REINT]][%[[IDX]]] + +// ----- + +func.func @alloc_4x8_f32() -> memref<4x8xf32> { + // Allocate a memref of size 4x8 with f32 elements. + // The memref is uninitialized by default. + %0 = memref.alloc() : memref<4x8xf32> + + // Return the allocated memref. 
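  // Illustrative sketch (not CHECK lines from this patch): once the flattening
  // patterns handle allocations, this alloc is expected to be rewritten to a 1-D
  // allocation that is cast back to the original shape, roughly:
  //   %flat = memref.alloc() : memref<32xf32, strided<[1]>>
  //   %orig = memref.reinterpret_cast %flat to offset: [0], sizes: [4, 8], strides: [8, 1]
  //             : memref<32xf32, strided<[1]>> to memref<4x8xf32>
  // %flat and %orig are illustrative names.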
+ return %0 : memref<4x8xf32> +} + +// ----- + +func.func @chained_alloc_load() -> vector<8xf32> { + %c3 = arith.constant 3 : index + %c6 = arith.constant 6 : index + %0 = memref.alloc() : memref<4x8xf32> + %value = vector.load %0[%c3, %c6] : memref<4x8xf32>, vector<8xf32> + return %value : vector<8xf32> +} \ No newline at end of file From 44605e254008e13194fae054f936b94dae278629 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Tue, 6 May 2025 19:11:47 -0400 Subject: [PATCH 13/18] Support Alloc/Alloca. --- .../MemRef/Transforms/FlattenMemRefs.cpp | 70 ++++++------------- mlir/test/Dialect/MemRef/flatten_memref.mlir | 28 ++++++-- 2 files changed, 43 insertions(+), 55 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 29b868ecde46d..779cdb10ac0a5 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -119,27 +119,28 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, auto loc = op->getLoc(); llvm::TypeSwitch(op.getOperation()) .template Case([&](auto oper) { - // grab flatMemref's type, and replace op with a new one. Then - // reinterpret it back. - auto flatMemrefType = cast(flatMemref.getType()); - auto loc = oper.getLoc(); auto newAlloc = rewriter.create( - loc, flatMemrefType, oper.getAlignmentAttr()); - auto originalType = cast(oper.getType()); - - auto rank = originalType.getRank(); - SmallVector sizes, strides; - sizes.resize(rank); - strides.resize(rank); - int64_t staticStride = 1; - for (int i = rank - 1; i >= 0; --i) { - sizes[i] = rewriter.getIndexAttr(originalType.getShape()[i]); - strides[i] = rewriter.getIndexAttr(staticStride); - staticStride *= originalType.getShape()[i]; - } + loc, cast(flatMemref.getType()), + oper.getAlignmentAttr()); + memref::ExtractStridedMetadataOp stridedMetadata = + rewriter.create(loc, oper); + rewriter.replaceOpWithNewOp( + op, cast(oper.getType()), newAlloc, + /*offset=*/rewriter.getIndexAttr(0), + stridedMetadata.getConstifiedMixedSizes(), + stridedMetadata.getConstifiedMixedStrides()); + }) + .template Case([&](auto oper) { + auto newAlloca = rewriter.create( + loc, cast(flatMemref.getType()), + oper.getAlignmentAttr()); + memref::ExtractStridedMetadataOp stridedMetadata = + rewriter.create(loc, oper); rewriter.replaceOpWithNewOp( - op, originalType, newAlloc, - /*offset=*/rewriter.getIndexAttr(0), sizes, strides); + op, cast(oper.getType()), newAlloca, + /*offset=*/rewriter.getIndexAttr(0), + stridedMetadata.getConstifiedMixedSizes(), + stridedMetadata.getConstifiedMixedStrides()); }) .template Case([&](auto op) { auto newLoad = rewriter.create( @@ -219,37 +220,6 @@ struct MemRefRewritePattern : public OpRewritePattern { } }; -// // For any memref op that emits a new memref. -// template -// struct MemRefSourceRewritePattern : public OpRewritePattern { -// using OpRewritePattern::OpRewritePattern; -// LogicalResult matchAndRewrite(T op, -// PatternRewriter &rewriter) const override { -// if (!needFlattening(op) || !checkLayout(op)) -// return failure(); -// MemRefType sourceType = cast(op.getType()); - -// auto mixedSizes = op.getMixedSizes(); - -// // Get flattened size, no strides. 
-// auto flattenedSize = std::accumulate( -// mixedSizes.begin(), mixedSizes.end(), 1, -// [](int64_t a, OpFoldResult b) { -// return a * getConstantIntValue(b).value_or(1); -// }); - -// auto flatMemrefType = MemRefType::get( -// /*shape=*/{flattenedSize}, sourceType.getElementType(), -// /*layout=*/nullptr, sourceType.getMemorySpace()); -// auto newSource = rewriter.create( -// op.getLoc(), flatMemrefType, op.getDynamicSizes()); -// auto reinterpretCast = rewriter.create( -// op.getLoc(), sourceType, newSource, op.getOffset(), -// op.getMixedSizes(), op.getStrides()); -// return success(); -// } -// }; - struct FlattenMemrefsPass : public mlir::memref::impl::FlattenMemrefsPassBase { using Base::Base; diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir index 1cfeaabfb1115..8af34c13c2a36 100644 --- a/mlir/test/Dialect/MemRef/flatten_memref.mlir +++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir @@ -220,15 +220,27 @@ func.func @transfer_write_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, % // ----- -func.func @alloc_4x8_f32() -> memref<4x8xf32> { - // Allocate a memref of size 4x8 with f32 elements. - // The memref is uninitialized by default. +func.func @alloc() -> memref<4x8xf32> { %0 = memref.alloc() : memref<4x8xf32> + return %0 : memref<4x8xf32> +} + +// CHECK-LABEL: func @alloc +// CHECK-SAME: () -> memref<4x8xf32> +// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32, strided<[1]>> +// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [8, 1] : memref<32xf32, strided<[1]>> to memref<4x8xf32> + +// ----- - // Return the allocated memref. +func.func @alloca() -> memref<4x8xf32> { + %0 = memref.alloca() : memref<4x8xf32> return %0 : memref<4x8xf32> } +// CHECK-LABEL: func.func @alloca() -> memref<4x8xf32> +// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<32xf32, strided<[1]>> +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [8, 1] : memref<32xf32, strided<[1]>> to memref<4x8xf32> + // ----- func.func @chained_alloc_load() -> vector<8xf32> { @@ -237,4 +249,10 @@ func.func @chained_alloc_load() -> vector<8xf32> { %0 = memref.alloc() : memref<4x8xf32> %value = vector.load %0[%c3, %c6] : memref<4x8xf32>, vector<8xf32> return %value : vector<8xf32> -} \ No newline at end of file +} + +// CHECK-LABEL: func @chained_alloc_load +// CHECK-SAME: () -> vector<8xf32> +// CHECK-NEXT: %[[C30:.*]] = arith.constant 30 : index +// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32, strided<[1]>> +// CHECK-NEXT: vector.load %[[ALLOC]][%[[C30]]] : memref<32xf32, strided<[1]>>, vector<8xf32> From ad09d3cb965656927fba793aa2af625f228c2ba1 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Tue, 6 May 2025 19:46:42 -0400 Subject: [PATCH 14/18] refactor --- .../MemRef/Transforms/FlattenMemRefs.cpp | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 779cdb10ac0a5..e79c996c6f948 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -113,6 +113,18 @@ static Value getTargetMemref(Operation *op) { .Default([](auto) { return Value{}; }); } +template +static void castResult(T oper, T newOper, Location loc, + PatternRewriter &rewriter) { + memref::ExtractStridedMetadataOp stridedMetadata = + rewriter.create(loc, oper); + 
rewriter.replaceOpWithNewOp( + oper, cast(oper.getType()), newOper, + /*offset=*/rewriter.getIndexAttr(0), + stridedMetadata.getConstifiedMixedSizes(), + stridedMetadata.getConstifiedMixedStrides()); +} + template static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, Value offset) { @@ -122,25 +134,13 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, auto newAlloc = rewriter.create( loc, cast(flatMemref.getType()), oper.getAlignmentAttr()); - memref::ExtractStridedMetadataOp stridedMetadata = - rewriter.create(loc, oper); - rewriter.replaceOpWithNewOp( - op, cast(oper.getType()), newAlloc, - /*offset=*/rewriter.getIndexAttr(0), - stridedMetadata.getConstifiedMixedSizes(), - stridedMetadata.getConstifiedMixedStrides()); + castResult(oper, newAlloc, loc, rewriter); }) .template Case([&](auto oper) { auto newAlloca = rewriter.create( loc, cast(flatMemref.getType()), oper.getAlignmentAttr()); - memref::ExtractStridedMetadataOp stridedMetadata = - rewriter.create(loc, oper); - rewriter.replaceOpWithNewOp( - op, cast(oper.getType()), newAlloca, - /*offset=*/rewriter.getIndexAttr(0), - stridedMetadata.getConstifiedMixedSizes(), - stridedMetadata.getConstifiedMixedStrides()); + castResult(oper, newAlloca, loc, rewriter); }) .template Case([&](auto op) { auto newLoad = rewriter.create( From e31e055911ae29a1c9903beac0892c3ad39429b5 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Wed, 7 May 2025 00:05:31 -0400 Subject: [PATCH 15/18] Change the way how linearized sizes are computed. --- .../MemRef/Transforms/FlattenMemRefs.cpp | 72 ++++++++++++++++++- 1 file changed, 69 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index e79c996c6f948..d212b59782f3f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -51,6 +51,70 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, return cast(in); } +static bool hasDynamicDim(ArrayRef dims) { + for (auto &&dim : dims) { + auto constant = getConstantIntValue(dim); + if (!constant || *constant < 0) { + return true; + } + } + return false; +} + +static OpFoldResult computeStaticShape(OpBuilder &builder, Location loc, + ArrayRef dims, + ArrayRef strides) { + // max(dims[i] * strides[i]) for i = 0, 1, ..., n-1 + int64_t maxSize = 1; + for (auto &&[dim, stride] : llvm::zip(dims, strides)) { + AffineExpr s0, s1; + bindSymbols(builder.getContext(), s0, s1); + OpFoldResult size = affine::makeComposedFoldedAffineApply( + builder, loc, s0 * s1, ArrayRef{dim, stride}); + auto constant = getConstantIntValue(size); + assert(constant && "expected constant value"); + maxSize = *constant; + } + return builder.getIndexAttr(maxSize); +} + +static OpFoldResult computeDynamicShape(OpBuilder &builder, Location loc, + ArrayRef dims, + ArrayRef strides) { + + SmallVector symbols(2 * dims.size()); + bindSymbolsList(builder.getContext(), MutableArrayRef{symbols}); + SmallVector productExpressions; + SmallVector values; + size_t symbolIndex = 0; + for (auto &&[dim, stride] : llvm::zip(dims, strides)) { + AffineExpr dimExpr = symbols[symbolIndex++]; + AffineExpr strideExpr = symbols[symbolIndex++]; + productExpressions.push_back(dimExpr * strideExpr); + values.push_back(getValueFromOpFoldResult(builder, loc, dim)); + values.push_back(getValueFromOpFoldResult(builder, loc, stride)); + } + + AffineMap maxMap = AffineMap::get(0, symbols.size(), 
productExpressions, + builder.getContext()); + Value maxValue = + builder.create(loc, maxMap, values).getResult(); + return maxValue; +} + +/// Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the +/// span of the memref. +static OpFoldResult computeSpan(OpBuilder &builder, Location loc, + ArrayRef dims, + ArrayRef strides) { + assert(dims.size() == strides.size() && + "number of dimensions and strides should be equal"); + if (hasDynamicDim(dims) || hasDynamicDim(strides)) { + return computeDynamicShape(builder, loc, dims, strides); + } + return computeStaticShape(builder, loc, dims, strides); +} + /// Returns a collapsed memref and the linearized index to access the element /// at the specified indices. static std::pair getFlattenMemrefAndOffset(OpBuilder &rewriter, @@ -82,10 +146,12 @@ static std::pair getFlattenMemrefAndOffset(OpBuilder &rewriter, rewriter.create( loc, source, /* offset = */ linearizedInfo.linearizedOffset, - /* shapes = */ ArrayRef{linearizedInfo.linearizedSize}, + /* shapes = */ + ArrayRef{computeSpan( + rewriter, loc, stridedMetadata.getConstifiedMixedSizes(), + stridedMetadata.getConstifiedMixedStrides())}, /* strides = */ - ArrayRef{ - stridedMetadata.getConstifiedMixedStrides().back()}), + ArrayRef{rewriter.getIndexAttr(1)}), getValueFromOpFoldResult(rewriter, loc, linearizedIndices)); } From 4219f6b7e76e71afef72d2cd54e3305cbc0c64ab Mon Sep 17 00:00:00 2001 From: Alan Li Date: Wed, 7 May 2025 09:19:43 -0400 Subject: [PATCH 16/18] update tests --- .../MemRef/Transforms/FlattenMemRefs.cpp | 6 +-- mlir/test/Dialect/MemRef/flatten_memref.mlir | 38 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index d212b59782f3f..5dc4f9ffb151e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -73,7 +73,7 @@ static OpFoldResult computeStaticShape(OpBuilder &builder, Location loc, builder, loc, s0 * s1, ArrayRef{dim, stride}); auto constant = getConstantIntValue(size); assert(constant && "expected constant value"); - maxSize = *constant; + maxSize = std::max(maxSize, *constant); } return builder.getIndexAttr(maxSize); } @@ -104,7 +104,7 @@ static OpFoldResult computeDynamicShape(OpBuilder &builder, Location loc, /// Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the /// span of the memref. 
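// Editorial note, not part of the patch: the span computed below is
// max_i(dims[i] * strides[i]). For example, memref<8x12xf32, strided<[24, 2]>>
// gives max(8 * 24, 12 * 2) = 192, matching the sizes: [192] reinterpret_cast
// checked by the tests updated later in this series.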
-static OpFoldResult computeSpan(OpBuilder &builder, Location loc, +static OpFoldResult computeSize(OpBuilder &builder, Location loc, ArrayRef dims, ArrayRef strides) { assert(dims.size() == strides.size() && @@ -147,7 +147,7 @@ static std::pair getFlattenMemrefAndOffset(OpBuilder &rewriter, loc, source, /* offset = */ linearizedInfo.linearizedOffset, /* shapes = */ - ArrayRef{computeSpan( + ArrayRef{computeSize( rewriter, loc, stridedMetadata.getConstifiedMixedSizes(), stridedMetadata.getConstifiedMixedStrides())}, /* strides = */ diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir index 8af34c13c2a36..d93eedbc3efd2 100644 --- a/mlir/test/Dialect/MemRef/flatten_memref.mlir +++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir @@ -14,17 +14,17 @@ func.func @load_scalar_from_memref(%input: memref<4x8xf32, strided<[8, 1], offse // ----- -func.func @load_scalar_from_memref_static_dim_2(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: index, %col: index) -> f32 { - %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>> +func.func @load_scalar_from_memref_static_dim_col_major(%input: memref<4x8xf32, strided<[1, 4], offset: 100>>, %row: index, %col: index) -> f32 { + %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[1, 4], offset: 100>> return %value : f32 } -// CHECK: [[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 12)> -// CHECK: func @load_scalar_from_memref_static_dim_2 -// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 12], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) +// CHECK: [[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 4)> +// CHECK: func @load_scalar_from_memref_static_dim_col_major +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[1, 4], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) // CHECK: %[[IDX:.*]] = affine.apply [[MAP]]()[%[[ARG2]], %[[ARG1]]] -// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [12] -// CHECK-SAME: to memref<32xf32, strided<[12], offset: 100>> +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] +// CHECK-SAME: to memref<32xf32, strided<[1], offset: 100>> // CHECK: memref.load %[[REINT]][%[[IDX]]] // ----- @@ -35,27 +35,27 @@ func.func @load_scalar_from_memref_dynamic_dim(%input: memref (s0 * s1 + s2 * s3)> -// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)> // CHECK: func @load_scalar_from_memref_dynamic_dim // CHECK-SAME: (%[[ARG0:.*]]: memref>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]] // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1] -// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[SIZES]]#0, %[[SIZES]]#1] -// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1] +// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[SIZES]]#0, %[[STRIDES]]#0, %[[SIZES]]#1, %[[STRIDES]]#1] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [1] : memref> to memref> // CHECK: memref.load %[[REINT]][%[[IDX]]] // ----- -func.func @store_scalar_from_memref_static_dim(%input: memref<4x8xf32, strided<[8, 12], offset: 100>>, %row: 
index, %col: index, %value: f32) { - memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[8, 12], offset: 100>> +func.func @store_scalar_from_memref_padded(%input: memref<4x8xf32, strided<[18, 2], offset: 100>>, %row: index, %col: index, %value: f32) { + memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[18, 2], offset: 100>> return } -// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1 * 12)> -// CHECK: func @store_scalar_from_memref_static_dim -// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[8, 12], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32) +// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 18 + s1 * 2)> +// CHECK: func @store_scalar_from_memref_padded +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[18, 2], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32) // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]] // CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] -// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]] : memref<32xf32, strided<[12], offset: 100>> +// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]] : memref<72xf32, strided<[1], offset: 100>> // ----- @@ -64,13 +64,13 @@ func.func @store_scalar_from_memref_dynamic_dim(%input: memref (s0 * s1 + s2 * s3)> -// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1] -> (s0 * s1)> +// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)> // CHECK: func @store_scalar_from_memref_dynamic_dim // CHECK-SAME: (%[[ARG0:.*]]: memref>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32) // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]] // CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1] -// CHECK: %[[SIZE:.*]] = affine.apply #[[MAP1]]()[%[[SIZES]]#0, %[[SIZES]]#1] -// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [%[[STRIDES]]#1] +// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[SIZES]]#0, %[[STRIDES]]#0, %[[SIZES]]#1, %[[STRIDES]]#1] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [1] // CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]] // ----- From 0322a199b3bddd64b5514c25893ca63f80e134c6 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Wed, 7 May 2025 09:38:52 -0400 Subject: [PATCH 17/18] More tests --- mlir/test/Dialect/MemRef/flatten_memref.mlir | 49 +++++++++++++------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir index d93eedbc3efd2..486963395a51a 100644 --- a/mlir/test/Dialect/MemRef/flatten_memref.mlir +++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir @@ -7,25 +7,11 @@ func.func @load_scalar_from_memref(%input: memref<4x8xf32, strided<[8, 1], offse return %value : f32 } // CHECK-LABEL: func @load_scalar_from_memref -// CHECK: %[[C10:.*]] = arith.constant 10 : index -// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1] +// CHECK-NEXT: %[[C10:.*]] = arith.constant 10 : index +// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1] // CHECK-SAME: memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>> -// CHECK: memref.load %[[REINT]][%[[C10]]] : memref<32xf32, strided<[1], offset: 
100>> - -// ----- - -func.func @load_scalar_from_memref_static_dim_col_major(%input: memref<4x8xf32, strided<[1, 4], offset: 100>>, %row: index, %col: index) -> f32 { - %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[1, 4], offset: 100>> - return %value : f32 -} +// CHECK-NEXT: memref.load %[[REINT]][%[[C10]]] : memref<32xf32, strided<[1], offset: 100>> -// CHECK: [[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 4)> -// CHECK: func @load_scalar_from_memref_static_dim_col_major -// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[1, 4], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -// CHECK: %[[IDX:.*]] = affine.apply [[MAP]]()[%[[ARG2]], %[[ARG1]]] -// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] -// CHECK-SAME: to memref<32xf32, strided<[1], offset: 100>> -// CHECK: memref.load %[[REINT]][%[[IDX]]] // ----- @@ -46,6 +32,21 @@ func.func @load_scalar_from_memref_dynamic_dim(%input: memref>) -> f32 { + %c7 = arith.constant 7 : index + %c10 = arith.constant 10 : index + %value = memref.load %input[%c7, %c10] : memref<8x12xf32, strided<[24, 2], offset: 100>> + return %value : f32 +} + +// CHECK-LABEL: func @load_scalar_from_memref_static_dim +// CHECK-SAME: (%[[ARG0:.*]]: memref<8x12xf32, strided<[24, 2], offset: 100>>) +// CHECK: %[[C188:.*]] = arith.constant 188 : index +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [192], strides: [1] : memref<8x12xf32, strided<[24, 2], offset: 100>> to memref<192xf32, strided<[1], offset: 100>> +// CHECK: memref.load %[[REINT]][%[[C188]]] : memref<192xf32, strided<[1], offset: 100>> + +// ----- + func.func @store_scalar_from_memref_padded(%input: memref<4x8xf32, strided<[18, 2], offset: 100>>, %row: index, %col: index, %value: f32) { memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[18, 2], offset: 100>> return @@ -256,3 +257,17 @@ func.func @chained_alloc_load() -> vector<8xf32> { // CHECK-NEXT: %[[C30:.*]] = arith.constant 30 : index // CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32, strided<[1]>> // CHECK-NEXT: vector.load %[[ALLOC]][%[[C30]]] : memref<32xf32, strided<[1]>>, vector<8xf32> + +// ----- + +func.func @load_scalar_from_memref_static_dim_col_major(%input: memref<4x8xf32, strided<[1, 4], offset: 100>>, %row: index, %col: index) -> f32 { + %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[1, 4], offset: 100>> + return %value : f32 +} + +// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 4)> +// CHECK: func @load_scalar_from_memref_static_dim_col_major +// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[1, 4], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) +// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]] +// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[1, 4], offset: 100>> to memref<32xf32, strided<[1], offset: 100>> +// CHECK: memref.load %[[REINT]][%[[IDX]]] : memref<32xf32, strided<[1], offset: 100>> From fd9c7f9501bb0011482f2e101ebc632a5fd4da73 Mon Sep 17 00:00:00 2001 From: Alan Li Date: Thu, 8 May 2025 12:16:14 -0400 Subject: [PATCH 18/18] simplify folds --- .../MemRef/Transforms/FlattenMemRefs.cpp | 61 ++++--------------- 1 file changed, 11 insertions(+), 50 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 5dc4f9ffb151e..936463f716d9a 100644 --- 
a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -51,68 +51,29 @@ static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, return cast(in); } -static bool hasDynamicDim(ArrayRef dims) { - for (auto &&dim : dims) { - auto constant = getConstantIntValue(dim); - if (!constant || *constant < 0) { - return true; - } - } - return false; -} - -static OpFoldResult computeStaticShape(OpBuilder &builder, Location loc, - ArrayRef dims, - ArrayRef strides) { - // max(dims[i] * strides[i]) for i = 0, 1, ..., n-1 - int64_t maxSize = 1; - for (auto &&[dim, stride] : llvm::zip(dims, strides)) { - AffineExpr s0, s1; - bindSymbols(builder.getContext(), s0, s1); - OpFoldResult size = affine::makeComposedFoldedAffineApply( - builder, loc, s0 * s1, ArrayRef{dim, stride}); - auto constant = getConstantIntValue(size); - assert(constant && "expected constant value"); - maxSize = std::max(maxSize, *constant); - } - return builder.getIndexAttr(maxSize); -} - -static OpFoldResult computeDynamicShape(OpBuilder &builder, Location loc, - ArrayRef dims, - ArrayRef strides) { - +/// Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the +/// span of the memref. +static OpFoldResult computeSize(OpBuilder &builder, Location loc, + ArrayRef dims, + ArrayRef strides) { + assert(dims.size() == strides.size() && + "number of dimensions and strides should be equal"); SmallVector symbols(2 * dims.size()); bindSymbolsList(builder.getContext(), MutableArrayRef{symbols}); SmallVector productExpressions; - SmallVector values; + SmallVector values; size_t symbolIndex = 0; for (auto &&[dim, stride] : llvm::zip(dims, strides)) { AffineExpr dimExpr = symbols[symbolIndex++]; AffineExpr strideExpr = symbols[symbolIndex++]; productExpressions.push_back(dimExpr * strideExpr); - values.push_back(getValueFromOpFoldResult(builder, loc, dim)); - values.push_back(getValueFromOpFoldResult(builder, loc, stride)); + values.push_back(dim); + values.push_back(stride); } AffineMap maxMap = AffineMap::get(0, symbols.size(), productExpressions, builder.getContext()); - Value maxValue = - builder.create(loc, maxMap, values).getResult(); - return maxValue; -} - -/// Given dimension size [d1, d2, ...] and strides [s1, s2, ...], compute the -/// span of the memref. -static OpFoldResult computeSize(OpBuilder &builder, Location loc, - ArrayRef dims, - ArrayRef strides) { - assert(dims.size() == strides.size() && - "number of dimensions and strides should be equal"); - if (hasDynamicDim(dims) || hasDynamicDim(strides)) { - return computeDynamicShape(builder, loc, dims, strides); - } - return computeStaticShape(builder, loc, dims, strides); + return affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values); } /// Returns a collapsed memref and the linearized index to access the element