diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index a46f73350bb3c..a8d135caa74f0 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -245,5 +245,15 @@ def ExpandReallocPass : Pass<"expand-realloc"> {
   ];
 }
 
+def FlattenMemrefsPass : Pass<"flatten-memref"> {
+  let summary = "Flatten multi-dimensional memrefs into 1-dimensional memrefs";
+  let description = [{
+    Flattens multi-dimensional memref accesses into 1-D memref accesses.
+  }];
+  let dependentDialects = [
+      "affine::AffineDialect", "memref::MemRefDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_MEMREF_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index 62a2297c80e78..c2b8cb05be922 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -144,6 +144,8 @@ FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
 /// ```
 void populateExtractAddressComputationsPatterns(RewritePatternSet &patterns);
 
+void populateFlattenMemrefsPatterns(RewritePatternSet &patterns);
+
 /// Build a new memref::AllocaOp whose dynamic sizes are independent of all
 /// given independencies. If the op is already independent of all
 /// independencies, the same AllocaOp result is returned.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index eb23403a68813..e729e73f6ae0c 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -5083,6 +5083,31 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
   return ret;
 }
 
+namespace mlir {
+namespace affine {
+OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+                            ArrayRef<OpFoldResult> terms) {
+  int64_t nDynamic = 0;
+  SmallVector<Value> dynamicPart;
+  AffineExpr result = builder.getAffineConstantExpr(1);
+  for (OpFoldResult term : terms) {
+    if (!term)
+      return term;
+    std::optional<int64_t> maybeConst = getConstantIntValue(term);
+    if (maybeConst) {
+      result = result * builder.getAffineConstantExpr(*maybeConst);
+    } else {
+      dynamicPart.push_back(cast<Value>(term));
+      result = result * builder.getAffineSymbolExpr(nDynamic++);
+    }
+  }
+  if (auto constant = dyn_cast<AffineConstantExpr>(result))
+    return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+} // namespace affine
+} // namespace mlir
+
 namespace {
 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
 /// %...d)`.
@@ -5142,27 +5167,6 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
   }
 };
 
-OpFoldResult computeProduct(Location loc, OpBuilder &builder,
-                            ArrayRef<OpFoldResult> terms) {
-  int64_t nDynamic = 0;
-  SmallVector<Value> dynamicPart;
-  AffineExpr result = builder.getAffineConstantExpr(1);
-  for (OpFoldResult term : terms) {
-    if (!term)
-      return term;
-    std::optional<int64_t> maybeConst = getConstantIntValue(term);
-    if (maybeConst) {
-      result = result * builder.getAffineConstantExpr(*maybeConst);
-    } else {
-      dynamicPart.push_back(cast<Value>(term));
-      result = result * builder.getAffineSymbolExpr(nDynamic++);
-    }
-  }
-  if (auto constant = dyn_cast<AffineConstantExpr>(result))
-    return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
-  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
-}
-
 /// If conseceutive outputs of a delinearize_index are linearized with the same
 /// bounds, canonicalize away the redundant arithmetic.
 ///
@@ -5309,7 +5313,7 @@ struct CancelLinearizeOfDelinearizePortion final
       // We use the slice from the linearize's basis above because of the
       // "bounds inferred from `disjoint`" case above.
       OpFoldResult newSize =
-          computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
+          affine::computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
 
       // Trivial case where we can just skip past the delinearize all together
       if (m.length == m.delinearize.getNumResults()) {
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index ecab97bc2b8e7..48e8bccd369fa 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   EmulateWideInt.cpp
   EmulateNarrowType.cpp
   ExtractAddressComputations.cpp
+  FlattenMemRefs.cpp
   FoldMemRefAliasOps.cpp
   IndependenceTransforms.cpp
   MultiBuffer.cpp
@@ -23,6 +24,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   LINK_LIBS PUBLIC
   MLIRAffineTransforms
+  MLIRAffineDialect
   MLIRAffineUtils
   MLIRArithDialect
   MLIRArithTransforms
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
new file mode 100644
index 0000000000000..5dc4f9ffb151e
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp
@@ -0,0 +1,322 @@
+//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns for flattening multi-rank memref-related ops
+// into 1-d memref ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <utility>
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_FLATTENMEMREFSPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
+                                      OpFoldResult in) {
+  if (Attribute offsetAttr = dyn_cast<Attribute>(in)) {
+    return rewriter.create<arith::ConstantIndexOp>(
+        loc, cast<IntegerAttr>(offsetAttr).getInt());
+  }
+  return cast<Value>(in);
+}
+
+static bool hasDynamicDim(ArrayRef<OpFoldResult> dims) {
+  for (auto &&dim : dims) {
+    auto constant = getConstantIntValue(dim);
+    if (!constant || *constant < 0) {
+      return true;
+    }
+  }
+  return false;
+}
+
+static OpFoldResult computeStaticShape(OpBuilder &builder, Location loc,
+                                       ArrayRef<OpFoldResult> dims,
+                                       ArrayRef<OpFoldResult> strides) {
+  // max(dims[i] * strides[i]) for i = 0, 1, ..., n-1
+  int64_t maxSize = 1;
+  for (auto &&[dim, stride] : llvm::zip(dims, strides)) {
+    AffineExpr s0, s1;
+    bindSymbols(builder.getContext(), s0, s1);
+    OpFoldResult size = affine::makeComposedFoldedAffineApply(
+        builder, loc, s0 * s1, ArrayRef<OpFoldResult>{dim, stride});
+    auto constant = getConstantIntValue(size);
+    assert(constant && "expected constant value");
+    maxSize = std::max(maxSize, *constant);
+  }
+  return builder.getIndexAttr(maxSize);
+}
+
+static OpFoldResult computeDynamicShape(OpBuilder &builder, Location loc,
+                                        ArrayRef<OpFoldResult> dims,
+                                        ArrayRef<OpFoldResult> strides) {
+
+  SmallVector<AffineExpr> symbols(2 * dims.size());
+  bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
+  SmallVector<AffineExpr> productExpressions;
+  SmallVector<Value> values;
+  size_t symbolIndex = 0;
+  for (auto &&[dim, stride] : llvm::zip(dims, strides)) {
+    AffineExpr dimExpr = symbols[symbolIndex++];
+    AffineExpr strideExpr = symbols[symbolIndex++];
+    productExpressions.push_back(dimExpr * strideExpr);
+    values.push_back(getValueFromOpFoldResult(builder, loc, dim));
+    values.push_back(getValueFromOpFoldResult(builder, loc, stride));
+  }
+
+  AffineMap maxMap = AffineMap::get(0, symbols.size(), productExpressions,
+                                    builder.getContext());
+  Value maxValue =
+      builder.create<affine::AffineMaxOp>(loc, maxMap, values).getResult();
+  return maxValue;
+}
+
+/// Given dimension sizes [d1, d2, ...] and strides [s1, s2, ...], compute the
+/// span of the memref.
+static OpFoldResult computeSize(OpBuilder &builder, Location loc,
+                                ArrayRef<OpFoldResult> dims,
+                                ArrayRef<OpFoldResult> strides) {
+  assert(dims.size() == strides.size() &&
+         "number of dimensions and strides should be equal");
+  if (hasDynamicDim(dims) || hasDynamicDim(strides)) {
+    return computeDynamicShape(builder, loc, dims, strides);
+  }
+  return computeStaticShape(builder, loc, dims, strides);
+}
+
+/// Returns a collapsed memref and the linearized index to access the element
+/// at the specified indices.
+static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
+                                                         Location loc,
+                                                         Value source,
+                                                         ValueRange indices) {
+  int64_t sourceOffset;
+  SmallVector<int64_t> sourceStrides;
+  auto sourceType = cast<MemRefType>(source.getType());
+  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset))) {
+    assert(false);
+  }
+
+  memref::ExtractStridedMetadataOp stridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(loc, source);
+
+  auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth();
+  OpFoldResult linearizedIndices;
+  memref::LinearizedMemRefInfo linearizedInfo;
+  std::tie(linearizedInfo, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          rewriter, loc, typeBit, typeBit,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(),
+          getAsOpFoldResult(indices));
+
+  return std::make_pair(
+      rewriter.create<memref::ReinterpretCastOp>(
+          loc, source,
+          /* offset = */ linearizedInfo.linearizedOffset,
+          /* shapes = */
+          ArrayRef<OpFoldResult>{computeSize(
+              rewriter, loc, stridedMetadata.getConstifiedMixedSizes(),
+              stridedMetadata.getConstifiedMixedStrides())},
+          /* strides = */
+          ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
+      getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
+}
+
+static bool needFlattening(Value val) {
+  auto type = cast<MemRefType>(val.getType());
+  return type.getRank() > 1;
+}
+
+static bool checkLayout(Value val) {
+  auto type = cast<MemRefType>(val.getType());
+  return type.getLayout().isIdentity() ||
+         isa<StridedLayoutAttr>(type.getLayout());
+}
+
+namespace {
+static Value getTargetMemref(Operation *op) {
+  return llvm::TypeSwitch<Operation *, Value>(op)
+      .template Case<memref::LoadOp, memref::StoreOp, memref::AllocaOp,
+                     memref::AllocOp>([](auto op) { return op.getMemref(); })
+      .template Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
+                     vector::MaskedStoreOp>(
+          [](auto op) { return op.getBase(); })
+      .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
+          [](auto op) { return op.getSource(); })
+      .Default([](auto) { return Value{}; });
+}
+
+template <typename T>
+static void castResult(T oper, T newOper, Location loc,
+                       PatternRewriter &rewriter) {
+  memref::ExtractStridedMetadataOp stridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper);
+  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+      oper, cast<MemRefType>(oper.getType()), newOper,
+      /*offset=*/rewriter.getIndexAttr(0),
+      stridedMetadata.getConstifiedMixedSizes(),
+      stridedMetadata.getConstifiedMixedStrides());
+}
+
+template <typename T>
+static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
+                      Value offset) {
+  auto loc = op->getLoc();
+  llvm::TypeSwitch<Operation *>(op.getOperation())
+      .template Case<memref::AllocOp>([&](auto oper) {
+        auto newAlloc = rewriter.create<memref::AllocOp>(
+            loc, cast<MemRefType>(flatMemref.getType()),
+            oper.getAlignmentAttr());
+        castResult(oper, newAlloc, loc, rewriter);
+      })
+      .template Case<memref::AllocaOp>([&](auto oper) {
+        auto newAlloca = rewriter.create<memref::AllocaOp>(
+            loc, cast<MemRefType>(flatMemref.getType()),
+            oper.getAlignmentAttr());
+        castResult(oper, newAlloca, loc, rewriter);
+      })
+      .template Case<memref::LoadOp>([&](auto op) {
+        auto newLoad = rewriter.create<memref::LoadOp>(
+            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
+        newLoad->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newLoad.getResult());
+      })
+      .template Case<memref::StoreOp>([&](auto op) {
+        auto newStore = rewriter.create<memref::StoreOp>(
+            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
+        newStore->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newStore);
+      })
+      .template Case<vector::LoadOp>([&](auto op) {
+        auto newLoad = rewriter.create<vector::LoadOp>(
+            loc, op->getResultTypes(), flatMemref, ValueRange{offset});
+        newLoad->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newLoad.getResult());
+      })
+      .template Case<vector::StoreOp>([&](auto op) {
+        auto newStore = rewriter.create<vector::StoreOp>(
+            loc, op->getOperands().front(), flatMemref, ValueRange{offset});
+        newStore->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newStore);
+      })
+      .template Case<vector::MaskedLoadOp>([&](auto op) {
+        auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>(
+            loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(),
+            op.getPassThru());
+        newMaskedLoad->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newMaskedLoad.getResult());
+      })
+      .template Case<vector::MaskedStoreOp>([&](auto op) {
+        auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>(
+            loc, flatMemref, ValueRange{offset}, op.getMask(),
+            op.getValueToStore());
+        newMaskedStore->setAttrs(op->getAttrs());
+        rewriter.replaceOp(op, newMaskedStore);
+      })
+      .template Case<vector::TransferReadOp>([&](auto op) {
+        auto newTransferRead = rewriter.create<vector::TransferReadOp>(
+            loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding());
+        rewriter.replaceOp(op, newTransferRead.getResult());
+      })
+      .template Case<vector::TransferWriteOp>([&](auto op) {
+        auto newTransferWrite = rewriter.create<vector::TransferWriteOp>(
+            loc, op.getVector(), flatMemref, ValueRange{offset});
+        rewriter.replaceOp(op, newTransferWrite);
+      })
+      .Default([&](auto op) {
+        op->emitOpError("unimplemented: do not know how to replace op.");
+      });
+}
+
+template <typename T>
+static ValueRange getIndices(T op) {
+  if constexpr (std::is_same_v<T, memref::AllocaOp> ||
+                std::is_same_v<T, memref::AllocOp>) {
+    return ValueRange{};
+  } else {
+    return op.getIndices();
+  }
+}
+
+template <typename T>
+struct MemRefRewritePattern : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+  LogicalResult matchAndRewrite(T op,
+                                PatternRewriter &rewriter) const override {
+    Value memref = getTargetMemref(op);
+    if (!needFlattening(memref) || !checkLayout(memref))
+      return failure();
+    auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
+        rewriter, op->getLoc(), memref, getIndices(op));
+    replaceOp(op, rewriter, flatMemref, offset);
+    return success();
+  }
+};
+
+struct FlattenMemrefsPass
+    : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
+  using Base::Base;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<affine::AffineDialect, memref::MemRefDialect,
+                    vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+
+    memref::populateFlattenMemrefsPatterns(patterns);
+
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
+  patterns.insert<MemRefRewritePattern<memref::LoadOp>,
+                  MemRefRewritePattern<memref::StoreOp>,
+                  MemRefRewritePattern<memref::AllocOp>,
+                  MemRefRewritePattern<memref::AllocaOp>,
+                  MemRefRewritePattern<vector::LoadOp>,
+                  MemRefRewritePattern<vector::StoreOp>,
+                  MemRefRewritePattern<vector::MaskedLoadOp>,
+                  MemRefRewritePattern<vector::MaskedStoreOp>,
+                  MemRefRewritePattern<vector::TransferReadOp>,
+                  MemRefRewritePattern<vector::TransferWriteOp>>(
+      patterns.getContext());
+}
diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir
new file mode 100644
index 0000000000000..486963395a51a
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir
@@ -0,0 +1,273 @@
+// RUN: mlir-opt --flatten-memref %s --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func @load_scalar_from_memref(%input: memref<4x8xf32, strided<[8, 1], offset: 100>>) -> f32 {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %value = memref.load %input[%c1, %c2] : memref<4x8xf32, strided<[8, 1], offset: 100>>
+  return %value : f32
+}
+// CHECK-LABEL: func @load_scalar_from_memref
+// CHECK-NEXT: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [100], sizes: [32], strides: [1]
+// CHECK-SAME: memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
+// CHECK-NEXT: memref.load %[[REINT]][%[[C10]]] : memref<32xf32, strided<[1], offset: 100>>
+
+
+// -----
+
+func.func @load_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[?, ?], offset: ?>>, %row: index, %col: index) -> f32 {
+  %value = memref.load %input[%col, %row] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return %value : f32
+}
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)>
+// CHECK: func @load_scalar_from_memref_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[SIZES]]#0, %[[STRIDES]]#0, %[[SIZES]]#1, %[[STRIDES]]#1]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[1], offset: ?>>
+// CHECK: memref.load %[[REINT]][%[[IDX]]]
+
+// -----
+
+func.func @load_scalar_from_memref_static_dim(%input: memref<8x12xf32, strided<[24, 2], offset: 100>>) -> f32 {
+  %c7 = arith.constant 7 : index
+  %c10 = arith.constant 10 : index
+  %value = memref.load %input[%c7, %c10] : memref<8x12xf32, strided<[24, 2], offset: 100>>
+  return %value : f32
+}
+
+// CHECK-LABEL: func @load_scalar_from_memref_static_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<8x12xf32, strided<[24, 2], offset: 100>>)
+// CHECK: %[[C188:.*]] = arith.constant 188 : index
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [192], strides: [1] : memref<8x12xf32, strided<[24, 2], offset: 100>> to memref<192xf32, strided<[1], offset: 100>>
+// CHECK: memref.load %[[REINT]][%[[C188]]] : memref<192xf32, strided<[1], offset: 100>>
+
+// -----
+
+func.func @store_scalar_from_memref_padded(%input: memref<4x8xf32, strided<[18, 2], offset: 100>>, %row: index, %col: index, %value: f32) {
+  memref.store %value, %input[%col, %row] : memref<4x8xf32, strided<[18, 2], offset: 100>>
+  return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 18 + s1 * 2)>
+// CHECK: func @store_scalar_from_memref_padded
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[18, 2], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]] : memref<72xf32, strided<[1], offset: 100>>
+
+// -----
+
+func.func @store_scalar_from_memref_dynamic_dim(%input: memref<?x?xf32, strided<[?, ?], offset: ?>>, %row: index, %col: index, %value: f32) {
+  memref.store %value, %input[%col, %row] : memref<?x?xf32, strided<[?, ?], offset: ?>>
+  return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s2 * s3)>
+// CHECK: #[[MAP1:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1, s2 * s3)>
+// CHECK: func @store_scalar_from_memref_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: f32)
+// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[STRIDES]]#0, %[[ARG1]], %[[STRIDES]]#1]
+// CHECK: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[SIZES]]#0, %[[STRIDES]]#0, %[[SIZES]]#1, %[[STRIDES]]#1]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[OFFSET]]], sizes: [%[[SIZE]]], strides: [1]
+// CHECK: memref.store %[[ARG3]], %[[REINT]][%[[IDX]]]
+
+// -----
+
+func.func @load_vector_from_memref(%input: memref<4x8xf32>) -> vector<8xf32> {
+  %c3 = arith.constant 3 : index
+  %c6 = arith.constant 6 : index
+  %value = vector.load %input[%c3, %c6] : memref<4x8xf32>, vector<8xf32>
+  return %value : vector<8xf32>
+}
+// CHECK-LABEL: func @load_vector_from_memref
+// CHECK: %[[C30:.*]] = arith.constant 30
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %arg0 to offset: [0], sizes: [32], strides: [1]
+// CHECK-NEXT: vector.load %[[REINT]][%[[C30]]]
+
+// -----
+
+func.func @load_vector_from_memref_odd(%input: memref<3x7xi2>) -> vector<3xi2> {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %value = vector.load %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2>
+  return %value : vector<3xi2>
+}
+// CHECK-LABEL: func @load_vector_from_memref_odd
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK-NEXT: vector.load %[[REINT]][%[[C10]]]
+
+// -----
+
+func.func @load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: index, %col: index) -> vector<3xi2> {
+  %value = vector.load %input[%col, %row] : memref<3x7xi2>, vector<3xi2>
+  return %value : vector<3xi2>
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @load_vector_from_memref_dynamic
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK: vector.load %[[REINT]][%[[IDX]]] : memref<21xi2, strided<[1]>>, vector<3xi2>
+
+// -----
+
+func.func @store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi2>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  vector.store %value, %input[%c1, %c3] : memref<3x7xi2>, vector<3xi2>
+  return
+}
+// CHECK-LABEL: func @store_vector_to_memref_odd
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>)
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK-NEXT: vector.store %[[ARG1]], %[[REINT]][%[[C10]]] : memref<21xi2, strided<[1]>
+
+// -----
+
+func.func @store_vector_to_memref_dynamic(%input: memref<3x7xi2>, %value: vector<3xi2>, %row: index, %col: index) {
+  vector.store %value, %input[%col, %row] : memref<3x7xi2>, vector<3xi2>
+  return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @store_vector_to_memref_dynamic
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [21], strides: [1]
+// CHECK: vector.store %[[ARG1]], %[[REINT]][%[[IDX]]]
+
+// -----
+
+func.func @mask_store_vector_to_memref_odd(%input: memref<3x7xi2>, %value: vector<3xi2>, %mask: vector<3xi1>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  vector.maskedstore %input[%c1, %c3], %mask, %value : memref<3x7xi2>, vector<3xi1>, vector<3xi2>
+  return
+}
+// CHECK-LABEL: func @mask_store_vector_to_memref_odd
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: vector<3xi1>)
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast
+// CHECK: vector.maskedstore %[[REINT]][%[[C10]]], %[[ARG2]], %[[ARG1]]
+
+// -----
+
+func.func @mask_store_vector_to_memref_dynamic(%input: memref<3x7xi2>, %value: vector<3xi2>, %row: index, %col: index, %mask: vector<3xi1>) {
+  vector.maskedstore %input[%col, %row], %mask, %value : memref<3x7xi2>, vector<3xi1>, vector<3xi2>
+  return
+}
+// CHECK: #map = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @mask_store_vector_to_memref_dynamic
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: vector<3xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: vector<3xi1>)
+// CHECK: %[[IDX:.*]] = affine.apply #map()[%[[ARG3]], %[[ARG2]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: vector.maskedstore %[[REINT]][%[[IDX]]], %[[ARG4]], %[[ARG1]]
+
+// -----
+func.func @mask_load_vector_from_memref_odd(%input: memref<3x7xi2>, %mask: vector<3xi1>, %passthru: vector<3xi2>) -> vector<3xi2> {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %result = vector.maskedload %input[%c1, %c3], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+  return %result : vector<3xi2>
+}
+// CHECK-LABEL: func @mask_load_vector_from_memref_odd
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[MASK:.*]]: vector<3xi1>, %[[PASSTHRU:.*]]: vector<3xi2>)
+// CHECK: %[[C10:.*]] = arith.constant 10 : index
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [21], strides: [1]
+// CHECK: vector.maskedload %[[REINT]][%[[C10]]], %[[MASK]], %[[PASSTHRU]]
+
+// -----
+
+func.func @mask_load_vector_from_memref_dynamic(%input: memref<3x7xi2>, %row: index, %col: index, %mask: vector<3xi1>, %passthru: vector<3xi2>) -> vector<3xi2> {
+  %result = vector.maskedload %input[%col, %row], %mask, %passthru : memref<3x7xi2>, vector<3xi1>, vector<3xi2> into vector<3xi2>
+  return %result : vector<3xi2>
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 7 + s1)>
+// CHECK: func @mask_load_vector_from_memref_dynamic
+// CHECK-SAME: (%[[ARG0:.*]]: memref<3x7xi2>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<3xi1>, %[[ARG4:.*]]: vector<3xi2>)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: vector.maskedload %[[REINT]][%[[IDX]]], %[[ARG3]]
+
+// -----
+
+func.func @transfer_read_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) -> vector<8xi2> {
+  %c0 = arith.constant 0 : i2
+  %0 = vector.transfer_read %input[%col, %row], %c0 : memref<4x8xi2>, vector<8xi2>
+  return %0 : vector<8xi2>
+}
+// CHECK-LABEL: func @transfer_read_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: %[[C0:.*]] = arith.constant 0 : i2
+// CHECK: %[[IDX:.*]] = affine.apply #map()[%[[ARG3]], %[[ARG2]]]
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK-NEXT: vector.transfer_read %[[REINT]][%[[IDX]]], %[[C0]]
+
+// -----
+
+func.func @transfer_write_memref(%input: memref<4x8xi2>, %value: vector<8xi2>, %row: index, %col: index) {
+  vector.transfer_write %value, %input[%col, %row] : vector<8xi2>, memref<4x8xi2>
+  return
+}
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+// CHECK: func @transfer_write_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xi2>, %[[ARG1:.*]]: vector<8xi2>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG3]], %[[ARG2]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]]
+// CHECK: vector.transfer_write %[[ARG1]], %[[REINT]][%[[IDX]]]
+
+// -----
+
+func.func @alloc() -> memref<4x8xf32> {
+  %0 = memref.alloc() : memref<4x8xf32>
+  return %0 : memref<4x8xf32>
+}
+
+// CHECK-LABEL: func @alloc
+// CHECK-SAME: () -> memref<4x8xf32>
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32, strided<[1]>>
+// CHECK-NEXT: %[[REINT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [8, 1] : memref<32xf32, strided<[1]>> to memref<4x8xf32>
+
+// -----
+
+func.func @alloca() -> memref<4x8xf32> {
+  %0 = memref.alloca() : memref<4x8xf32>
+  return %0 : memref<4x8xf32>
+}
+
+// CHECK-LABEL: func.func @alloca() -> memref<4x8xf32>
+// CHECK: %[[ALLOC:.*]] = memref.alloca() : memref<32xf32, strided<[1]>>
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [8, 1] : memref<32xf32, strided<[1]>> to memref<4x8xf32>
+
+// -----
+
+func.func @chained_alloc_load() -> vector<8xf32> {
+  %c3 = arith.constant 3 : index
+  %c6 = arith.constant 6 : index
+  %0 = memref.alloc() : memref<4x8xf32>
+  %value = vector.load %0[%c3, %c6] : memref<4x8xf32>, vector<8xf32>
+  return %value : vector<8xf32>
+}
+
+// CHECK-LABEL: func @chained_alloc_load
+// CHECK-SAME: () -> vector<8xf32>
+// CHECK-NEXT: %[[C30:.*]] = arith.constant 30 : index
+// CHECK-NEXT: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32, strided<[1]>>
+// CHECK-NEXT: vector.load %[[ALLOC]][%[[C30]]] : memref<32xf32, strided<[1]>>, vector<8xf32>
+
+// -----
+
+func.func @load_scalar_from_memref_static_dim_col_major(%input: memref<4x8xf32, strided<[1, 4], offset: 100>>, %row: index, %col: index) -> f32 {
+  %value = memref.load %input[%col, %row] : memref<4x8xf32, strided<[1, 4], offset: 100>>
+  return %value : f32
+}
+
+// CHECK: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s0 + s1 * 4)>
+// CHECK: func @load_scalar_from_memref_static_dim_col_major
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8xf32, strided<[1, 4], offset: 100>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK: %[[IDX:.*]] = affine.apply #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
+// CHECK: %[[REINT:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [100], sizes: [32], strides: [1] : memref<4x8xf32, strided<[1, 4], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
+// CHECK: memref.load %[[REINT]][%[[IDX]]] : memref<32xf32, strided<[1], offset: 100>>
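For reference, the rewrite these patterns perform looks like the following illustrative MLIR snippet, assembled from the load_scalar_from_memref test above. The SSA names %flat and %c10 are only for readability, and the constant 10 is the folded linearized index 1 * 8 + 2:

  // Input: a 2-D strided access.
  %value = memref.load %input[%c1, %c2]
      : memref<4x8xf32, strided<[8, 1], offset: 100>>

  // After --flatten-memref: a 1-D reinterpreted view plus a linearized index.
  %c10 = arith.constant 10 : index
  %flat = memref.reinterpret_cast %input to offset: [100], sizes: [32], strides: [1]
      : memref<4x8xf32, strided<[8, 1], offset: 100>> to memref<32xf32, strided<[1], offset: 100>>
  %value = memref.load %flat[%c10] : memref<32xf32, strided<[1], offset: 100>>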