Skip to content

[MLIR] Add a utility pass to linearize memref #136797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -245,5 +245,15 @@ def ExpandReallocPass : Pass<"expand-realloc"> {
];
}

def FlattenMemrefsPass : Pass<"flatten-memref"> {
let summary = "Flatten a multiple dimensional memref to 1-dimensional";
let description = [{

}];
let dependentDialects = [
"affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
];
}

#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES

2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
/// ```
void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);

void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);

/// Build a new memref::AllocaOp whose dynamic sizes are independent of all
/// given independencies. If the op is already independent of all
/// independencies, the same AllocaOp result is returned.
Expand Down
48 changes: 26 additions & 22 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5083,6 +5083,31 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
return ret;
}

namespace mlir {
namespace affine {
OpFoldResult computeProduct(Location loc, OpBuilder &builder,
ArrayRef<OpFoldResult> terms) {
int64_t nDynamic = 0;
SmallVector<Value> dynamicPart;
AffineExpr result = builder.getAffineConstantExpr(1);
for (OpFoldResult term : terms) {
if (!term)
return term;
std::optional<int64_t> maybeConst = getConstantIntValue(term);
if (maybeConst) {
result = result * builder.getAffineConstantExpr(*maybeConst);
} else {
dynamicPart.push_back(cast<Value>(term));
result = result * builder.getAffineSymbolExpr(nDynamic++);
}
}
if (auto constant = dyn_cast<AffineConstantExpr>(result))
return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
}
} // namespace affine
} // namespace mlir

namespace {
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
Expand Down Expand Up @@ -5142,27 +5167,6 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
}
};

OpFoldResult computeProduct(Location loc, OpBuilder &builder,
ArrayRef<OpFoldResult> terms) {
int64_t nDynamic = 0;
SmallVector<Value> dynamicPart;
AffineExpr result = builder.getAffineConstantExpr(1);
for (OpFoldResult term : terms) {
if (!term)
return term;
std::optional<int64_t> maybeConst = getConstantIntValue(term);
if (maybeConst) {
result = result * builder.getAffineConstantExpr(*maybeConst);
} else {
dynamicPart.push_back(cast<Value>(term));
result = result * builder.getAffineSymbolExpr(nDynamic++);
}
}
if (auto constant = dyn_cast<AffineConstantExpr>(result))
return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
}

/// If conseceutive outputs of a delinearize_index are linearized with the same
/// bounds, canonicalize away the redundant arithmetic.
///
Expand Down Expand Up @@ -5309,7 +5313,7 @@ struct CancelLinearizeOfDelinearizePortion final
// We use the slice from the linearize's basis above because of the
// "bounds inferred from `disjoint`" case above.
OpFoldResult newSize =
computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
affine::computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);

// Trivial case where we can just skip past the delinearize all together
if (m.length == m.delinearize.getNumResults()) {
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
EmulateWideInt.cpp
EmulateNarrowType.cpp
ExtractAddressComputations.cpp
FlattenMemRefs.cpp
FoldMemRefAliasOps.cpp
IndependenceTransforms.cpp
MultiBuffer.cpp
Expand All @@ -23,6 +24,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms

LINK_LIBS PUBLIC
MLIRAffineTransforms
MLIRAffineDialect
MLIRAffineUtils
MLIRArithDialect
MLIRArithTransforms
Expand Down
Loading
Loading