[mlir][scf][vector] Add scf.parallel vectorizer #94168


Open
Hardcode84 wants to merge 10 commits into main from scf_vectorize_cont

Conversation

Contributor

@Hardcode84 Hardcode84 commented Jun 2, 2024

Add scf.parallel vectorizer utilities and a test pass.

Add 2 functions:

  • getLoopVectorizeInfo - collects scf.parallel loop vectorization info for the specified dimension and target vector size. Returns the number of ops that can potentially be vectorized, the vectorization factor, and whether masked mode can be used.
  • vectorizeLoop - unrolls the specified scf.parallel dimension factor times and vectorizes ops where possible. Non-vectorizable ops are replicated.

scf.reduce reductions are supported and will use a vector reduction where possible.
Ops with nested regions besides scf.reduce are not supported yet.

The vectorizer has 2 modes (see the worked example below):

  • Masked - run the loop for ceildiv(count, factor) iterations and use masked vector ops to handle out-of-bounds accesses.
  • Non-masked - run the loop for floordiv(count, factor) iterations and add a second loop to handle the remaining items.
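
For concreteness, here is a toy trip-count calculation for a loop whose bound distance is 10, vectorized with factor 4 (plain C++, purely illustrative; the names are made up and not part of the patch):

// Toy illustration of the two modes, not code from the patch.
constexpr unsigned count = 10, factor = 4;
constexpr unsigned maskedIters = (count + factor - 1) / factor; // ceildiv = 3; the last iteration has only 2 active lanes
constexpr unsigned vectorIters = count / factor;                // floordiv = 2 full vector iterations...
constexpr unsigned remainder   = count - vectorIters * factor;  // ...plus a second (scalar) loop over the 2 remaining items
static_assert(maskedIters == 3 && vectorIters == 2 && remainder == 2, "sanity check");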

Upstreaming from the numba-mlir project: https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/SCFVectorize.cpp

Some initial upstreaming work by @makslevental
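
The intended usage is analysis first, then transform. Below is a minimal driver sketch based on the declarations in SCFVectorize.h from this patch; the function name, the collect-then-rewrite walk, and the choice to always try dimension 0 with a fixed bitwidth are illustrative assumptions, not part of the patch:

// Hypothetical driver; assumes only the API declared in SCFVectorize.h below.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/SCFVectorize.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static void tryVectorizeParallelLoops(Operation *root, unsigned vectorBitwidth) {
  // Collect candidates first so the walk does not iterate over IR we rewrite.
  llvm::SmallVector<scf::ParallelOp> candidates;
  root->walk([&](scf::ParallelOp loop) { candidates.push_back(loop); });

  for (scf::ParallelOp loop : candidates) {
    // Ask the analysis whether dimension 0 of this loop can be vectorized
    // for the given target vector width (in bits).
    std::optional<SCFVectorizeInfo> info =
        getLoopVectorizeInfo(loop, /*dim=*/0, vectorBitwidth);
    if (!info)
      continue;
    // Feed the analysis result back into the transform.
    SCFVectorizeParams params{info->dim, info->factor, info->masked};
    (void)vectorizeLoop(loop, params);
  }
}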

@Hardcode84 Hardcode84 marked this pull request as ready for review June 2, 2024 19:50
@Hardcode84 Hardcode84 requested a review from dcaballe June 2, 2024 19:59
@Hardcode84 Hardcode84 force-pushed the scf_vectorize_cont branch from 01bedd0 to cbebacc on June 3, 2024 01:39
@llvmbot llvmbot added the mlir:core (MLIR Core Infrastructure) and mlir labels on Jun 3, 2024
@llvmbot
Member

llvmbot commented Jun 3, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir-core

Author: Ivan Butygin (Hardcode84)

Changes

Add scf.parallel vectorizer utilities and a test pass.

Add 2 functions:

  • getLoopVectorizeInfo - collects scf.parallel loop vectorization info for the specified dimension and target vector size. Returns the number of ops that can potentially be vectorized, the vectorization factor, and whether masked mode can be used.
  • vectorizeLoop - unrolls the specified scf.parallel dimension factor times and vectorizes ops where possible. Non-vectorizable ops are replicated.

scf.reduce reductions are supported and will use a vector reduction where possible.
Ops with nested regions besides scf.reduce are not supported yet.

Vectorizer has 2 modes:

  • Masked - unroll the loop to ceildiv(count, factor) iterations and use masked vector ops to handle out-of-bounds accesses.
  • Non-masked - unroll to floordiv(count, factor) iterations and add a second loop to handle the remaining items.

Upstreaming from numba-mlir project https://github.com/numba/numba-mlir/blob/main/mlir/lib/Transforms/SCFVectorize.cpp

Some initial upstreaming work by @makslevental


Patch is 51.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94168.diff

7 Files Affected:

  • (added) mlir/include/mlir/Transforms/SCFVectorize.h (+70)
  • (modified) mlir/lib/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Transforms/SCFVectorize.cpp (+648)
  • (added) mlir/test/Transforms/test-scf-vectorize.mlir (+272)
  • (modified) mlir/test/lib/Transforms/CMakeLists.txt (+1)
  • (added) mlir/test/lib/Transforms/TestSCFVectorize.cpp (+110)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+24-22)
diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Transforms/SCFVectorize.h
new file mode 100644
index 0000000000000..d2a5e3085ae37
--- /dev/null
+++ b/mlir/include/mlir/Transforms/SCFVectorize.h
@@ -0,0 +1,70 @@
+//===- SCFVectorize.h - ------------------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SCFVECTORIZE_H_
+#define MLIR_TRANSFORMS_SCFVECTORIZE_H_
+
+#include <optional>
+
+namespace mlir {
+class DataLayout;
+struct LogicalResult;
+namespace scf {
+class ParallelOp;
+}
+} // namespace mlir
+
+namespace mlir {
+
+/// Loop vectorization info
+struct SCFVectorizeInfo {
+  /// Loop dimension on which to vectorize.
+  unsigned dim = 0;
+
+  /// Biggest vector width, in elements.
+  unsigned factor = 0;
+
+  /// Number of ops, which will be vectorized.
+  unsigned count = 0;
+
+  /// Can use masked vector ops for our of bounds memory accesses.
+  bool masked = false;
+};
+
+/// Collect vectorization statistics on specified `scf.parallel` dimension.
+/// Return `SCFVectorizeInfo` or `std::nullopt` if loop cannot be vectorized on
+/// specified dimension.
+///
+/// `vectorBitwidth` - maximum vector size, in bits.
+std::optional<SCFVectorizeInfo>
+getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
+                     unsigned vectorBitwidth, const DataLayout *DL = nullptr);
+
+/// Vectorization params
+struct SCFVectorizeParams {
+  /// Loop dimension on which to vectorize.
+  unsigned dim = 0;
+
+  /// Desired vector length, in elements
+  unsigned factor = 0;
+
+  /// Use masked vector ops for memory access outside loop bounds.
+  bool masked = false;
+};
+
+/// Vectorize loop on specified dimension with specified factor.
+///
+/// If `masked` is `true` and loop bound is not divisible by `factor`, instead
+/// of generating second loop to process remainig iterations, extend loop count
+/// and generate masked vector ops to handle out-of bounds memory accesses.
+mlir::LogicalResult vectorizeLoop(mlir::scf::ParallelOp loop,
+                                  const SCFVectorizeParams &params,
+                                  const DataLayout *DL = nullptr);
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 90c0298fb5e46..ed71c73c938ed 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_library(MLIRTransforms
   PrintIR.cpp
   RemoveDeadValues.cpp
   SCCP.cpp
+  SCFVectorize.cpp
   SROA.cpp
   StripDebugInfo.cpp
   SymbolDCE.cpp
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
new file mode 100644
index 0000000000000..29e184e584a56
--- /dev/null
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -0,0 +1,648 @@
+//===- SCFVectorize.cpp - SCF vectorization utilities ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/SCFVectorize.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // getCombinerOpKind
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/IRMapping.h"
+
+using namespace mlir;
+
+static bool isSupportedVecElem(Type type) { return type.isIntOrIndexOrFloat(); }
+
+/// Return type bitwidth for vectorization purposes or 0 if type cannot be
+/// vectorized.
+static unsigned getTypeBitWidth(Type type, const DataLayout *DL) {
+  if (!isSupportedVecElem(type))
+    return 0;
+
+  if (DL)
+    return DL->getTypeSizeInBits(type);
+
+  if (type.isIntOrFloat())
+    return type.getIntOrFloatBitWidth();
+
+  return 0;
+}
+
+static unsigned getArgsTypeWidth(Operation &op, const DataLayout *DL) {
+  unsigned ret = 0;
+  for (auto arg : op.getOperands())
+    ret = std::max(ret, getTypeBitWidth(arg.getType(), DL));
+
+  for (auto res : op.getResults())
+    ret = std::max(ret, getTypeBitWidth(res.getType(), DL));
+
+  return ret;
+}
+
+static bool isSupportedVectorOp(Operation &op) {
+  return op.hasTrait<OpTrait::Vectorizable>();
+}
+
+/// Check if one `ValueRange` is permutation of another, i.e. contains same
+/// values, potentially in different order.
+static bool isRangePermutation(ValueRange val1, ValueRange val2) {
+  if (val1.size() != val2.size())
+    return false;
+
+  for (auto v1 : val1) {
+    auto it = llvm::find(val2, v1);
+    if (it == val2.end())
+      return false;
+  }
+  return true;
+}
+
+template <typename Op>
+static std::optional<unsigned>
+cavTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
+                               const DataLayout *DL) {
+  auto loopIndexVars = loop.getInductionVars();
+  assert(dim < loopIndexVars.size());
+  auto memref = memOp.getMemRef();
+  auto type = cast<MemRefType>(memref.getType());
+  auto width = getTypeBitWidth(type.getElementType(), DL);
+  if (width == 0)
+    return std::nullopt;
+
+  if (!type.getLayout().isIdentity())
+    return std::nullopt;
+
+  if (!isRangePermutation(memOp.getIndices(), loopIndexVars))
+    return std::nullopt;
+
+  if (memOp.getIndices().back() != loopIndexVars[dim])
+    return std::nullopt;
+
+  DominanceInfo dom;
+  if (!dom.properlyDominates(memref, loop))
+    return std::nullopt;
+
+  return width;
+}
+
+/// Check if memref load/store can be converted into vectorized load/store
+///
+/// Returns memref element bitwidth or `std::nullopt` if access cannot be
+/// vectorized.
+static std::optional<unsigned>
+cavTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
+                           const DataLayout *DL) {
+  assert(dim < loop.getInductionVars().size());
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return cavTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL);
+
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return cavTriviallyVectorizeMemOpImpl(loop, dim, loadOp, DL);
+
+  return std::nullopt;
+}
+
+template <typename Op>
+static std::optional<unsigned> canGatherScatterImpl(scf::ParallelOp loop, Op op,
+                                                    const DataLayout *DL) {
+  auto memref = op.getMemRef();
+  auto memrefType = cast<MemRefType>(memref.getType());
+  auto width = getTypeBitWidth(memrefType.getElementType(), DL);
+  if (width == 0)
+    return std::nullopt;
+
+  DominanceInfo dom;
+  return dom.properlyDominates(memref, loop) && op.getIndices().size() == 1 &&
+         memrefType.getLayout().isIdentity();
+}
+
+// Check if memref access can be converted into gather/scatter.
+///
+/// Returns memref element bitwidth or `std::nullopt` if access cannot be
+/// vectorized.
+static std::optional<unsigned>
+canGatherScatter(scf::ParallelOp loop, Operation &op, const DataLayout *DL) {
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return canGatherScatterImpl(loop, storeOp, DL);
+
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return canGatherScatterImpl(loop, loadOp, DL);
+
+  return std::nullopt;
+}
+
+static std::optional<unsigned> cenVectorizeMemrefOp(scf::ParallelOp loop,
+                                                    unsigned dim, Operation &op,
+                                                    const DataLayout *DL) {
+  if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op, DL))
+    return w;
+
+  return canGatherScatter(loop, op, DL);
+}
+
+/// Returns `vector.reduce` kind for specified `scf.parallel` reduce op ot
+/// `std::nullopt` if reduction cannot be handled by `vector.reduce`.
+static std::optional<vector::CombiningKind> getReductionKind(Block &body) {
+  if (!llvm::hasSingleElement(body.without_terminator()))
+    return std::nullopt;
+
+  // TODO: Move getCombinerOpKind to vector dialect.
+  return linalg::getCombinerOpKind(&body.front());
+}
+
+std::optional<SCFVectorizeInfo>
+mlir::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
+                           unsigned vectorBitwidth, const DataLayout *DL) {
+  assert(dim < loop.getStep().size());
+  assert(vectorBitwidth > 0);
+  unsigned factor = vectorBitwidth / 8;
+  if (factor <= 1)
+    return std::nullopt;
+
+  /// Only step==1 is supported for now.
+  if (!isConstantIntValue(loop.getStep()[dim], 1))
+    return std::nullopt;
+
+  unsigned count = 0;
+  bool masked = true;
+
+  /// Check if `scf.reduce` can be handled by `vector.reduce`.
+  /// If not we still can vectorize the loop but we cannot use masked
+  /// vectorize.
+  auto reduce = cast<scf::ReduceOp>(loop.getBody()->getTerminator());
+  for (Region &reg : reduce.getReductions()) {
+    if (!getReductionKind(reg.front()))
+      masked = false;
+
+    continue;
+  }
+
+  for (Operation &op : loop.getBody()->without_terminator()) {
+    /// Ops with nested regions are not supported yet.
+    if (op.getNumRegions() > 0)
+      return std::nullopt;
+
+    /// Check mem ops.
+    if (auto w = cenVectorizeMemrefOp(loop, dim, op, DL)) {
+      auto newFactor = vectorBitwidth / *w;
+      if (newFactor > 1) {
+        factor = std::min(factor, newFactor);
+        ++count;
+      }
+      continue;
+    }
+
+    /// If met the op which cannot be vectorized, we can replicate it and still
+    /// potentially vectorize other ops, but we cannot use masked vectorize.
+    if (!isSupportedVectorOp(op)) {
+      masked = false;
+      continue;
+    }
+
+    auto width = getArgsTypeWidth(op, DL);
+    if (width == 0)
+      return std::nullopt;
+
+    auto newFactor = vectorBitwidth / width;
+    if (newFactor <= 1)
+      continue;
+
+    factor = std::min(factor, newFactor);
+
+    ++count;
+  }
+
+  /// No ops to vectorize.
+  if (count == 0)
+    return std::nullopt;
+
+  return SCFVectorizeInfo{dim, factor, count, masked};
+}
+
+/// Get fastmath flags if ops support them or default (none).
+static arith::FastMathFlags getFMF(Operation &op) {
+  if (auto fmf = dyn_cast<arith::ArithFastMathInterface>(op))
+    return fmf.getFastMathFlagsAttr().getValue();
+
+  return arith::FastMathFlags::none;
+}
+
+LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop,
+                                  const SCFVectorizeParams &params,
+                                  const DataLayout *DL) {
+  auto dim = params.dim;
+  auto factor = params.factor;
+  auto masked = params.masked;
+  assert(dim < loop.getStep().size());
+  assert(factor > 1);
+  assert(isConstantIntValue(loop.getStep()[dim], 1));
+
+  OpBuilder builder(loop);
+  auto lower = llvm::to_vector(loop.getLowerBound());
+  auto upper = llvm::to_vector(loop.getUpperBound());
+  auto step = llvm::to_vector(loop.getStep());
+
+  auto loc = loop.getLoc();
+
+  auto origIndexVar = loop.getInductionVars()[dim];
+
+  Value factorVal = builder.create<arith::ConstantIndexOp>(loc, factor);
+
+  auto origLower = lower[dim];
+  auto origUpper = upper[dim];
+  Value count = builder.createOrFold<arith::SubIOp>(loc, origUpper, origLower);
+  Value newCount;
+
+  // Compute new loop count, ceildiv if masked, floordiv otherwise.
+  if (masked) {
+    newCount = builder.createOrFold<arith::CeilDivSIOp>(loc, count, factorVal);
+  } else {
+    newCount = builder.createOrFold<arith::DivSIOp>(loc, count, factorVal);
+  }
+
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  lower[dim] = zero;
+  upper[dim] = newCount;
+
+  // Vectorized loop.
+  auto newLoop = builder.create<scf::ParallelOp>(loc, lower, upper, step,
+                                                 loop.getInitVals());
+  auto newIndexVar = newLoop.getInductionVars()[dim];
+
+  auto toVectorType = [&](Type elemType) -> VectorType {
+    int64_t f = factor;
+    return VectorType::get(f, elemType);
+  };
+
+  IRMapping mapping;
+  IRMapping scalarMapping;
+
+  auto createPosionVec = [&](VectorType vecType) -> Value {
+    return builder.create<ub::PoisonOp>(loc, vecType, nullptr);
+  };
+
+  Value indexVarMult;
+  auto getrIndexVarMult = [&]() -> Value {
+    if (indexVarMult)
+      return indexVarMult;
+
+    indexVarMult =
+        builder.createOrFold<arith::MulIOp>(loc, newIndexVar, factorVal);
+    return indexVarMult;
+  };
+
+  // Get vector value in new loop for provided `orig` value in source loop.
+  auto getVecVal = [&](Value orig) -> Value {
+    // Use cached value if present.
+    if (auto mapped = mapping.lookupOrNull(orig))
+      return mapped;
+
+    // Vectorized loop index, loop index is divided by factor, so for factorN
+    // vectorized index will looks like `splat(idx) + (0, 1, ..., N - 1)`
+    if (orig == origIndexVar) {
+      auto vecType = toVectorType(builder.getIndexType());
+      llvm::SmallVector<Attribute> elems(factor);
+      for (auto i : llvm::seq(0u, factor))
+        elems[i] = builder.getIndexAttr(i);
+      auto attr = DenseElementsAttr::get(vecType, elems);
+      Value vec = builder.create<arith::ConstantOp>(loc, vecType, attr);
+
+      Value idx = getrIndexVarMult();
+      idx = builder.createOrFold<arith::AddIOp>(loc, idx, origLower);
+      idx = builder.create<vector::SplatOp>(loc, idx, vecType);
+      vec = builder.createOrFold<arith::AddIOp>(loc, idx, vec);
+      mapping.map(orig, vec);
+      return vec;
+    }
+    auto type = orig.getType();
+    assert(isSupportedVecElem(type));
+
+    Value val = orig;
+    auto origIndexVars = loop.getInductionVars();
+    auto it = llvm::find(origIndexVars, orig);
+
+    // If loop index, but not on vectorized dimension, just take new loop index
+    // and splat it.
+    if (it != origIndexVars.end())
+      val = newLoop.getInductionVars()[it - origIndexVars.begin()];
+
+    // Values which are defined inside loop body are preemptively added to the
+    // mapper and not handled here. Values defined outside body are just
+    // splatted.
+
+    auto vecType = toVectorType(type);
+    Value vec = builder.create<vector::SplatOp>(loc, val, vecType);
+    mapping.map(orig, vec);
+    return vec;
+  };
+
+  llvm::DenseMap<Value, llvm::SmallVector<Value>> unpackedVals;
+
+  // Get unpacked values for provided `orig` value in source loop.
+  // Values are returned as `ValueRange` and not as vector value.
+  auto getUnpackedVals = [&](Value val) -> ValueRange {
+    // Use cached values if present.
+    auto it = unpackedVals.find(val);
+    if (it != unpackedVals.end())
+      return it->second;
+
+    // Values which are defined inside loop body are preemptively added to the
+    // cache and not handled here.
+
+    auto &ret = unpackedVals[val];
+    assert(ret.empty());
+    if (!isSupportedVecElem(val.getType())) {
+      // Non vectorizable value, it must be a value defined outside the loop,
+      // just replicate it.
+      ret.resize(factor, val);
+      return ret;
+    }
+
+    // Get vector value and extract elements from it.
+    auto vecVal = getVecVal(val);
+    ret.resize(factor);
+    for (auto i : llvm::seq(0u, factor)) {
+      Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
+      ret[i] = builder.create<vector::ExtractElementOp>(loc, vecVal, idx);
+    }
+    return ret;
+  };
+
+  // Add unpacked values to the cache.
+  auto setUnpackedVals = [&](Value origVal, ValueRange newVals) {
+    assert(newVals.size() == factor);
+    assert(unpackedVals.count(origVal) == 0);
+    unpackedVals[origVal].append(newVals.begin(), newVals.end());
+
+    auto type = origVal.getType();
+    if (!isSupportedVecElem(type))
+      return;
+
+    // If type is vectorizabale construct a vector add it to vector cache as
+    // well.
+    auto vecType = toVectorType(type);
+
+    Value vec = createPosionVec(vecType);
+    for (auto i : llvm::seq(0u, factor)) {
+      Value idx = builder.create<arith::ConstantIndexOp>(loc, i);
+      vec = builder.create<vector::InsertElementOp>(loc, newVals[i], vec, idx);
+    }
+    mapping.map(origVal, vec);
+  };
+
+  Value mask;
+
+  // Contruct mask value and cache it. If not a masked mode mask is always all
+  // 1s.
+  auto getMask = [&]() -> Value {
+    if (mask)
+      return mask;
+
+    OpFoldResult maskSize;
+    if (masked) {
+      Value size = getrIndexVarMult();
+      maskSize = builder.createOrFold<arith::SubIOp>(loc, count, size);
+    } else {
+      maskSize = builder.getIndexAttr(factor);
+    }
+    auto vecType = toVectorType(builder.getI1Type());
+    mask = builder.create<vector::CreateMaskOp>(loc, vecType, maskSize);
+
+    return mask;
+  };
+
+  auto canTriviallyVectorizeMemOp = [&](auto op) -> bool {
+    return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op, DL);
+  };
+
+  auto canGatherScatter = [&](auto op) {
+    return !!::canGatherScatterImpl(loop, op, DL);
+  };
+
+  // Get idices for vectorized memref load/store.
+  auto getMemrefVecIndices = [&](ValueRange indices) {
+    scalarMapping.clear();
+    scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars());
+
+    llvm::SmallVector<Value> ret(indices.size());
+    for (auto &&[i, val] : llvm::enumerate(indices)) {
+      if (val == origIndexVar) {
+        Value idx = getrIndexVarMult();
+        idx = builder.createOrFold<arith::AddIOp>(loc, idx, origLower);
+        ret[i] = idx;
+        continue;
+      }
+      ret[i] = scalarMapping.lookup(val);
+    }
+
+    return ret;
+  };
+
+  // Create vectorized memref load for specified non-vectorized load.
+  auto genLoad = [&](auto loadOp) {
+    auto indices = getMemrefVecIndices(loadOp.getIndices());
+    auto resType = toVectorType(loadOp.getResult().getType());
+    auto memref = loadOp.getMemRef();
+    Value vecLoad;
+    if (masked) {
+      auto mask = getMask();
+      auto init = createPosionVec(resType);
+      vecLoad = builder.create<vector::MaskedLoadOp>(loc, resType, memref,
+                                                     indices, mask, init);
+    } else {
+      vecLoad = builder.create<vector::LoadOp>(loc, resType, memref, indices);
+    }
+    mapping.map(loadOp.getResult(), vecLoad);
+  };
+
+  // Create vectorized memref store for specified non-vectorized store.
+  auto genStore = [&](auto storeOp) {
+    auto indices = getMemrefVecIndices(storeOp.getIndices());
+    auto value = getVecVal(storeOp.getValueToStore());
+    auto memref = storeOp.getMemRef();
+    if (masked) {
+      auto mask = getMask();
+      builder.create<vector::MaskedStoreOp>(loc, memref, indices, mask, value);
+    } else {
+      builder.create<vector::StoreOp>(loc, value, memref, indices);
+    }
+  };
+
+  llvm::SmallVector<Value> duplicatedArgs;
+  llvm::SmallVector<Value> duplicatedResults;
+
+  builder.setInsertionPointToStart(newLoop.getBody());
+  for (Operation &op : loop.getBody()->without_terminator()) {
+    loc = op.getLoc();
+    if (isSupportedVectorOp(op)) {
+      // If op can be vectorized, clone it with vectorized inputs and  update
+      // resuls to vectorized types.
+      for (auto arg : op.getOperands())
+        getVecVal(arg); // init mapper for op args
+
+      auto newOp = builder.clone(op, mapping);
+      for (auto res : newOp->getResults())
+        res.setType(toVectorType(res.getType()));
+
+      continue;
+    }
+
+    // Vectorize memref load/store ops, vector load/store are preffered over
+    // gather/scatter.
+    if (auto loadOp = dyn_cast<memref::LoadOp>(op)) {
+      if (canTriviallyVectorizeMemOp(loadOp)) {
+        genLoad(loadOp);
+        continue;
+      }
+      if (canGatherScatter(loadOp)) {
+        auto resType = toVectorType(loadOp.getResult().getType());
+        auto memref = loadOp.getMemRef();
+        auto mask = getMask();
+        auto indexVec = getVecVal(loadOp.getIndices()[0]);
+        auto init = createPosionVec(resType);
+
+        auto gather = builder.create<vector::GatherOp>(
+            loc, resType, memref, zero, indexVec, mask, init);
+        mapping.map(loadOp.getResult(), gather.getResult());
+        continue;
+      }
+    }
+
+    if (auto stor...
[truncated]

@llvmbot
Member

llvmbot commented Jun 3, 2024

@llvm/pr-subscribers-mlir


@@ -0,0 +1,648 @@
//===- SCFVectorize.cpp - SCF vectorization utilities ---------------------===//
Member

This should go to mlir/lib/Dialect/SCF/Transforms

Contributor Author

done

};

auto canTriviallyVectorizeMemOp = [&](auto op) -> bool {
return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op, DL);
Member

I'd explicitly use hasValue here. !! looks like a typo.

Member

cavTriviallyVectorizeMemOpImpl -> cav?? Is that a typo? And if yes, it means it's not being tested at all.

Contributor Author

@Hardcode84 Hardcode84 Jun 3, 2024

Switched to has_value() and fixed the cav typo.
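
For reference, the two spellings under discussion, written against plain std::optional (illustrative only, not code from the patch):

#include <optional>

// Terse form flagged in review: relies on operator bool and reads like a typo.
bool viaDoubleBang(std::optional<unsigned> w) { return !!w; }
// Explicit form the review asked for.
bool viaHasValue(std::optional<unsigned> w) { return w.has_value(); }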

return arith::FastMathFlags::none;
}

LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop,
Member

Do we have something similar for scf.for? If so, can some of the implementation be reused?

Contributor Author

@Hardcode84 Hardcode84 Jun 3, 2024

We have SuperVectorize operating on affine loops, but it's too entangled with the affine dialect. Nothing for scf.for directly, I believe. Also, doing it on scf.for will be more involved, as it will require dependency analysis, while scf.parallel semantics already guarantees that loop iterations are independent.

Contributor

@matthias-springer recently I tried to factor out all of the reusable parts of affine in order to implement some kind of naive vectorization for scf.for in cases where it's obvious there are no loop-carried dependencies. It basically boiled down to needing to extend the LoopLikeInterface to include things like getTripCount and also creating a new interface for memref|affine::store and memref|affine::load. I started to do it but then stopped because I figured people wouldn't go for it - I presumed the reaction would be "why, when we have structured codegen". I'm still keen on doing it because I think it would be useful, so if you'll support it/vouch for it then I can write up an RFC.

Contributor Author

I think, instead of trying to unify scf.for and scf.parallel vectorization under one pass, we can instead introduce an scf.for -> scf.parallel uplifting pass and encapsulate the dependency analysis there. Such a pass will also be useful outside of vectorization.

Comment on lines 73 to 75
auto memref = memOp.getMemRef();
auto type = cast<MemRefType>(memref.getType());
auto width = getTypeBitWidth(type.getElementType(), DL);
Contributor

nit: please spell out auto except in cases where the spelling is very verbose (this is a general convention/rule in mlir)

Contributor Author

done (although IMO it's annoying and doesn't bring anything useful in most cases)

Contributor

Sometimes I agree with you and sometimes I disagree with you (often clangd doesn't figure out the type fast enough or ever and I am stuck guessing). Either way it's pretty established convention 🤷

Comment on lines 279 to 291
auto toVectorType = [&](Type elemType) -> VectorType {
int64_t f = factor;
return VectorType::get(f, elemType);
};
Contributor

nit: I think this is short enough that it can be "inlined" at call sites.

Contributor Author

It's shorter, as you don't need to pass the factor value.

Contributor

fair enough

@makslevental
Contributor

Just some drive-by nits - will come back for a more thorough review of the algo/impl.

@MaheshRavishankar
Contributor

This seems like a major feature add. Is there an RFC that describes what is being done here, how it is structured, and how it can be extended?
I'll review it in the meantime, but this is fairly big, so I need a few days.

@Hardcode84
Contributor Author

This seems like a major feature add. Is there an RFC that describes what is being done here, how it is structured, and how it can be extended?

No RFC, but I can expand comments in code if needed.

@MaheshRavishankar
Contributor

This seems like a major feature add. Is there an RFC that describes what is being done here, how it is structured, and how it can be extended?

No RFC, but I can expand comments in code if needed.

This being a "new big feature" maybe an RFC is worth it. Comments in code would definitely help.

@Hardcode84 Hardcode84 force-pushed the scf_vectorize_cont branch from ab4302f to a4ea5d9 on June 8, 2024 22:15