From db95496d83bebfc1db2cbc1ac6c1d04d706b6499 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 5 Dec 2023 16:44:04 -0600 Subject: [PATCH 01/10] [mlir][scf] upstream numba's scf vectorizer --- mlir/include/mlir/Transforms/SCFVectorize.h | 49 ++ mlir/lib/Transforms/CMakeLists.txt | 1 + mlir/lib/Transforms/SCFVectorize.cpp | 661 ++++++++++++++++++++ 3 files changed, 711 insertions(+) create mode 100644 mlir/include/mlir/Transforms/SCFVectorize.h create mode 100644 mlir/lib/Transforms/SCFVectorize.cpp diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Transforms/SCFVectorize.h new file mode 100644 index 0000000000000..d754b38d5bc23 --- /dev/null +++ b/mlir/include/mlir/Transforms/SCFVectorize.h @@ -0,0 +1,49 @@ +//===- SCFVectorize.h - ------------------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_SCFVECTORIZE_H_ +#define MLIR_TRANSFORMS_SCFVECTORIZE_H_ + +#include +#include + +namespace mlir { +class OpBuilder; +class Pass; +struct LogicalResult; +namespace scf { +class ParallelOp; +} +} // namespace mlir + +namespace mlir { +struct SCFVectorizeInfo { + unsigned dim = 0; + unsigned factor = 0; + unsigned count = 0; + bool masked = false; +}; + +std::optional getLoopVectorizeInfo(mlir::scf::ParallelOp loop, + unsigned dim, + unsigned vectorBitWidth); + +struct SCFVectorizeParams { + unsigned dim = 0; + unsigned factor = 0; + bool masked = false; +}; + +mlir::LogicalResult vectorizeLoop(mlir::OpBuilder &builder, + mlir::scf::ParallelOp loop, + const SCFVectorizeParams ¶ms); + +std::unique_ptr createSCFVectorizePass(); +} // namespace mlir + +#endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_ \ No newline at end of file diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 90c0298fb5e46..ed71c73c938ed 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_library(MLIRTransforms PrintIR.cpp RemoveDeadValues.cpp SCCP.cpp + SCFVectorize.cpp SROA.cpp StripDebugInfo.cpp SymbolDCE.cpp diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp new file mode 100644 index 0000000000000..d7545ee30e29a --- /dev/null +++ b/mlir/lib/Transforms/SCFVectorize.cpp @@ -0,0 +1,661 @@ +//===- ControlFlowSink.cpp - Code to perform control-flow sinking ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/SCFVectorize.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Pass/Pass.h" + +static unsigned getTypeBitWidth(mlir::Type type) { + if (mlir::isa(type)) + return 64; // TODO: unhardcode + + if (type.isIntOrFloat()) + return type.getIntOrFloatBitWidth(); + + return 0; +} + +static unsigned getArgsTypeWidth(mlir::Operation &op) { + unsigned ret = 0; + for (auto arg : op.getOperands()) + ret = std::max(ret, getTypeBitWidth(arg.getType())); + + for (auto res : op.getResults()) + ret = std::max(ret, getTypeBitWidth(res.getType())); + + return ret; +} + +static bool isSupportedVectorOp(mlir::Operation &op) { + return op.hasTrait(); +} + +static bool isSupportedVecElem(mlir::Type type) { + return type.isIntOrIndexOrFloat(); +} + +static bool isRangePermutation(mlir::ValueRange val1, mlir::ValueRange val2) { + if (val1.size() != val2.size()) + return false; + + for (auto v1 : val1) { + auto it = llvm::find(val2, v1); + if (it == val2.end()) + return false; + } + return true; +} + +template +static std::optional +cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim, + Op memOp) { + auto loopIndexVars = loop.getInductionVars(); + assert(dim < loopIndexVars.size()); + auto memref = memOp.getMemRef(); + auto type = mlir::cast(memref.getType()); + auto width = getTypeBitWidth(type.getElementType()); + if (width == 0) + return std::nullopt; + + if (!type.getLayout().isIdentity()) + return std::nullopt; + + if (!isRangePermutation(memOp.getIndices(), loopIndexVars)) + return std::nullopt; + + if (memOp.getIndices().back() != loopIndexVars[dim]) + return std::nullopt; + + mlir::DominanceInfo dom; + if (!dom.properlyDominates(memref, loop)) + return std::nullopt; + + return width; +} + +static std::optional +cavTriviallyVectorizeMemOp(mlir::scf::ParallelOp loop, unsigned dim, + mlir::Operation &op) { + assert(dim < loop.getInductionVars().size()); + if (auto storeOp = mlir::dyn_cast(op)) + return cavTriviallyVectorizeMemOpImpl(loop, dim, storeOp); + + if (auto loadOp = mlir::dyn_cast(op)) + return cavTriviallyVectorizeMemOpImpl(loop, dim, loadOp); + + return std::nullopt; +} + +template +static bool isOp(mlir::Operation &op) { + return mlir::isa(op); +} + +static std::optional +getReductionKind(mlir::scf::ReduceOp op) { + mlir::Block &body = op.getReductionOperator().front(); + if (!llvm::hasSingleElement(body.without_terminator())) + return std::nullopt; + + mlir::Operation &redOp = body.front(); + + using fptr_t = bool (*)(mlir::Operation &); + using CC = mlir::vector::CombiningKind; + const std::pair handlers[] = { + // clang-format off + {&isOp, CC::ADD}, + {&isOp, CC::ADD}, + {&isOp, CC::MUL}, + {&isOp, CC::MUL}, + // clang-format on + }; + + for (auto &&[handler, cc] : handlers) { + if (handler(redOp)) + return cc; + } + + return std::nullopt; +} + +std::optional +mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim, + unsigned vectorBitwidth) { + assert(dim < loop.getStep().size()); + assert(vectorBitwidth > 0); + unsigned factor = vectorBitwidth / 8; + if (factor <= 1) + return std::nullopt; + + if 
(!mlir::isConstantIntValue(loop.getStep()[dim], 1)) + return std::nullopt; + + unsigned count = 0; + bool masked = true; + + for (mlir::Operation &op : loop.getBody()->without_terminator()) { + if (auto reduce = mlir::dyn_cast(op)) { + if (!getReductionKind(reduce)) + masked = false; + + continue; + } + + if (op.getNumRegions() > 0) + return std::nullopt; + + if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op)) { + auto newFactor = vectorBitwidth / *w; + if (newFactor > 1) { + factor = std::min(factor, newFactor); + ++count; + } + continue; + } + + if (!isSupportedVectorOp(op)) { + masked = false; + continue; + } + + auto width = getArgsTypeWidth(op); + if (width == 0) + return std::nullopt; + + auto newFactor = vectorBitwidth / width; + if (newFactor <= 1) + continue; + + factor = std::min(factor, newFactor); + + ++count; + } + + if (count == 0) + return std::nullopt; + + return SCFVectorizeInfo{dim, factor, count, masked}; +} + +static mlir::arith::FastMathFlags getFMF(mlir::Operation &op) { + if (auto fmf = mlir::dyn_cast(op)) + return fmf.getFastMathFlagsAttr().getValue(); + + return mlir::arith::FastMathFlags::none; +} + +mlir::LogicalResult +mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop, + const mlir::SCFVectorizeParams ¶ms) { + auto dim = params.dim; + auto factor = params.factor; + auto masked = params.masked; + assert(dim < loop.getStep().size()); + assert(factor > 1); + assert(mlir::isConstantIntValue(loop.getStep()[dim], 1)); + + mlir::OpBuilder::InsertionGuard g(builder); + builder.setInsertionPoint(loop); + + auto lower = llvm::to_vector(loop.getLowerBound()); + auto upper = llvm::to_vector(loop.getUpperBound()); + auto step = llvm::to_vector(loop.getStep()); + + auto loc = loop.getLoc(); + + auto origIndexVar = loop.getInductionVars()[dim]; + + mlir::Value factorVal = + builder.create(loc, factor); + + auto origLower = lower[dim]; + auto origUpper = upper[dim]; + mlir::Value count = + builder.create(loc, origUpper, origLower); + mlir::Value newCount; + if (masked) { + mlir::Value incCount = + builder.create(loc, count, factorVal); + mlir::Value one = builder.create(loc, 1); + incCount = builder.create(loc, incCount, one); + newCount = builder.create(loc, incCount, factorVal); + } else { + newCount = builder.create(loc, count, factorVal); + } + + mlir::Value zero = builder.create(loc, 0); + lower[dim] = zero; + upper[dim] = newCount; + + auto newLoop = builder.create(loc, lower, upper, step, + loop.getInitVals()); + auto newIndexVar = newLoop.getInductionVars()[dim]; + + auto toVectorType = [&](mlir::Type elemType) -> mlir::VectorType { + int64_t f = factor; + return mlir::VectorType::get(f, elemType); + }; + + mlir::IRMapping mapping; + mlir::IRMapping scalarMapping; + + auto createPosionVec = [&](mlir::VectorType vecType) -> mlir::Value { + return builder.create(loc, vecType, nullptr); + }; + + auto getVecVal = [&](mlir::Value orig) -> mlir::Value { + if (auto mapped = mapping.lookupOrNull(orig)) + return mapped; + + if (orig == origIndexVar) { + auto vecType = toVectorType(builder.getIndexType()); + llvm::SmallVector elems(factor); + for (auto i : llvm::seq(0u, factor)) + elems[i] = builder.getIndexAttr(i); + auto attr = mlir::DenseElementsAttr::get(vecType, elems); + mlir::Value vec = + builder.create(loc, vecType, attr); + + mlir::Value idx = + builder.create(loc, newIndexVar, factorVal); + idx = builder.create(loc, idx, origLower); + idx = builder.create(loc, idx, vecType); + vec = builder.create(loc, idx, vec); + mapping.map(orig, vec); + 
return vec; + } + auto type = orig.getType(); + assert(isSupportedVecElem(type)); + + mlir::Value val = orig; + auto origIndexVars = loop.getInductionVars(); + auto it = llvm::find(origIndexVars, orig); + if (it != origIndexVars.end()) + val = newLoop.getInductionVars()[it - origIndexVars.begin()]; + + auto vecType = toVectorType(type); + mlir::Value vec = builder.create(loc, val, vecType); + mapping.map(orig, vec); + return vec; + }; + + llvm::DenseMap> unpackedVals; + auto getUnpackedVals = [&](mlir::Value val) -> mlir::ValueRange { + auto it = unpackedVals.find(val); + if (it != unpackedVals.end()) + return it->second; + + auto &ret = unpackedVals[val]; + assert(ret.empty()); + if (!isSupportedVecElem(val.getType())) { + ret.resize(factor, val); + return ret; + } + + auto vecVal = getVecVal(val); + ret.resize(factor); + for (auto i : llvm::seq(0u, factor)) { + mlir::Value idx = builder.create(loc, i); + ret[i] = builder.create(loc, vecVal, idx); + } + return ret; + }; + + auto setUnpackedVals = [&](mlir::Value origVal, mlir::ValueRange newVals) { + assert(newVals.size() == factor); + assert(unpackedVals.count(origVal) == 0); + unpackedVals[origVal].append(newVals.begin(), newVals.end()); + + auto type = origVal.getType(); + if (!isSupportedVecElem(type)) + return; + + auto vecType = toVectorType(type); + + mlir::Value vec = createPosionVec(vecType); + for (auto i : llvm::seq(0u, factor)) { + mlir::Value idx = builder.create(loc, i); + vec = builder.create(loc, newVals[i], vec, + idx); + } + mapping.map(origVal, vec); + }; + + mlir::Value mask; + auto getMask = [&]() -> mlir::Value { + if (mask) + return mask; + + mlir::OpFoldResult maskSize; + if (masked) { + mlir::Value size = + builder.create(loc, factorVal, newIndexVar); + maskSize = + builder.create(loc, count, size).getResult(); + } else { + maskSize = builder.getIndexAttr(factor); + } + auto vecType = toVectorType(builder.getI1Type()); + mask = builder.create(loc, vecType, maskSize); + + return mask; + }; + + mlir::DominanceInfo dom; + + auto canTriviallyVectorizeMemOp = [&](auto op) -> bool { + return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op); + }; + + auto getMemrefVecIndices = [&](mlir::ValueRange indices) { + scalarMapping.clear(); + scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars()); + + llvm::SmallVector ret(indices.size()); + for (auto &&[i, val] : llvm::enumerate(indices)) { + if (val == origIndexVar) { + mlir::Value idx = + builder.create(loc, newIndexVar, factorVal); + idx = builder.create(loc, idx, origLower); + ret[i] = idx; + continue; + } + ret[i] = scalarMapping.lookup(val); + } + + return ret; + }; + + auto canGatherScatter = [&](auto op) { + auto memref = op.getMemRef(); + auto memrefType = mlir::cast(memref.getType()); + if (!isSupportedVecElem(memrefType.getElementType())) + return false; + + return dom.properlyDominates(memref, loop) && op.getIndices().size() == 1 && + memrefType.getLayout().isIdentity(); + }; + + auto genLoad = [&](auto loadOp) { + auto indices = getMemrefVecIndices(loadOp.getIndices()); + auto resType = toVectorType(loadOp.getResult().getType()); + auto memref = loadOp.getMemRef(); + mlir::Value vecLoad; + if (masked) { + auto mask = getMask(); + auto init = createPosionVec(resType); + vecLoad = builder.create(loc, resType, memref, + indices, mask, init); + } else { + vecLoad = + builder.create(loc, resType, memref, indices); + } + mapping.map(loadOp.getResult(), vecLoad); + }; + + auto genStore = [&](auto storeOp) { + auto indices = 
getMemrefVecIndices(storeOp.getIndices()); + auto value = getVecVal(storeOp.getValueToStore()); + auto memref = storeOp.getMemRef(); + if (masked) { + auto mask = getMask(); + builder.create(loc, memref, indices, mask, + value); + } else { + builder.create(loc, value, memref, indices); + } + }; + + llvm::SmallVector duplicatedArgs; + llvm::SmallVector duplicatedResults; + + builder.setInsertionPointToStart(newLoop.getBody()); + for (mlir::Operation &op : loop.getBody()->without_terminator()) { + loc = op.getLoc(); + if (isSupportedVectorOp(op)) { + for (auto arg : op.getOperands()) + getVecVal(arg); // init mapper for op args + + auto newOp = builder.clone(op, mapping); + for (auto res : newOp->getResults()) + res.setType(toVectorType(res.getType())); + + continue; + } + + if (auto reduceOp = mlir::dyn_cast(op)) { + scalarMapping.clear(); + auto &reduceBody = reduceOp.getReductionOperator().front(); + assert(reduceBody.getNumArguments() == 2); + + mlir::Value reduceVal; + if (auto redKind = getReductionKind(reduceOp)) { + mlir::Value redArg = getVecVal(reduceOp.getOperand()); + if (redArg) { + auto neutral = mlir::arith::getNeutralElement(&reduceBody.front()); + assert(neutral); + mlir::Value neutralVal = + builder.create(loc, *neutral); + mlir::Value neutralVec = builder.create( + loc, neutralVal, redArg.getType()); + auto mask = getMask(); + redArg = builder.create(loc, mask, redArg, + neutralVec); + } + + auto fmf = getFMF(reduceBody.front()); + reduceVal = builder.create(loc, *redKind, + redArg, fmf); + } else { + if (masked) + return op.emitError("Cannot vectorize op in masked mode"); + + auto reduceTerm = + mlir::cast(reduceBody.getTerminator()); + auto lhs = reduceBody.getArgument(0); + auto rhs = reduceBody.getArgument(1); + auto unpacked = getUnpackedVals(reduceOp.getOperand()); + assert(unpacked.size() == factor); + reduceVal = unpacked.front(); + for (auto i : llvm::seq(1u, factor)) { + mlir::Value val = unpacked[i]; + scalarMapping.map(lhs, reduceVal); + scalarMapping.map(rhs, val); + for (auto &redOp : reduceBody.without_terminator()) + builder.clone(redOp, scalarMapping); + + reduceVal = scalarMapping.lookupOrDefault(reduceTerm.getResult()); + } + } + scalarMapping.clear(); + scalarMapping.map(reduceOp.getOperand(), reduceVal); + builder.clone(op, scalarMapping); + continue; + } + + if (auto loadOp = mlir::dyn_cast(op)) { + if (canTriviallyVectorizeMemOp(loadOp)) { + genLoad(loadOp); + continue; + } + if (canGatherScatter(loadOp)) { + auto resType = toVectorType(loadOp.getResult().getType()); + auto memref = loadOp.getMemRef(); + auto mask = getMask(); + auto indexVec = getVecVal(loadOp.getIndices()[0]); + auto init = createPosionVec(resType); + + auto gather = builder.create( + loc, resType, memref, zero, indexVec, mask, init); + mapping.map(loadOp.getResult(), gather.getResult()); + continue; + } + } + + if (auto storeOp = mlir::dyn_cast(op)) { + if (canTriviallyVectorizeMemOp(storeOp)) { + genStore(storeOp); + continue; + } + if (canGatherScatter(storeOp)) { + auto memref = storeOp.getMemRef(); + auto value = getVecVal(storeOp.getValueToStore()); + auto mask = getMask(); + auto indexVec = getVecVal(storeOp.getIndices()[0]); + + builder.create(loc, memref, zero, indexVec, + mask, value); + } + } + + // Fallback: Failed to vectorize op, just duplicate it `factor` times + if (masked) + return op.emitError("Cannot vectorize op in masked mode"); + + scalarMapping.clear(); + + auto numArgs = op.getNumOperands(); + auto numResults = op.getNumResults(); + 
duplicatedArgs.resize(numArgs * factor); + duplicatedResults.resize(numResults * factor); + + for (auto &&[i, arg] : llvm::enumerate(op.getOperands())) { + auto unpacked = getUnpackedVals(arg); + assert(unpacked.size() == factor); + for (auto j : llvm::seq(0u, factor)) + duplicatedArgs[j * numArgs + i] = unpacked[j]; + } + + for (auto i : llvm::seq(0u, factor)) { + auto args = mlir::ValueRange(duplicatedArgs) + .drop_front(numArgs * i) + .take_front(numArgs); + scalarMapping.map(op.getOperands(), args); + auto results = builder.clone(op, scalarMapping)->getResults(); + + for (auto j : llvm::seq(0u, numResults)) + duplicatedResults[j * factor + i] = results[j]; + } + + for (auto i : llvm::seq(0u, numResults)) { + auto results = mlir::ValueRange(duplicatedResults) + .drop_front(factor * i) + .take_front(factor); + setUnpackedVals(op.getResult(i), results); + } + } + + if (masked) { + loop->replaceAllUsesWith(newLoop.getResults()); + loop->erase(); + } else { + builder.setInsertionPoint(loop); + mlir::Value newLower = + builder.create(loc, newCount, factorVal); + newLower = builder.create(loc, origLower, newLower); + + auto lowerCopy = llvm::to_vector(loop.getLowerBound()); + lowerCopy[dim] = newLower; + loop.getLowerBoundMutable().assign(lowerCopy); + loop.getInitValsMutable().assign(newLoop.getResults()); + } + + return mlir::success(); +} + +llvm::StringRef getVectorLengthName() { return "numba.vector_length"; } + +static std::optional getVectorLength(mlir::Operation *op) { + auto func = op->getParentOfType(); + if (!func) + return std::nullopt; + + auto attr = func->getAttrOfType(getVectorLengthName()); + if (!attr) + return std::nullopt; + + auto val = attr.getInt(); + if (val <= 0 || val > std::numeric_limits::max()) + return std::nullopt; + + return static_cast(val); +} + +namespace { +struct SCFVectorizePass + : public mlir::PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFVectorizePass) + + void getDependentDialects(mlir::DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + llvm::SmallVector< + std::pair> + toVectorize; + + auto getBenefit = [](const mlir::SCFVectorizeInfo &info) { + return info.factor * info.count * (int(info.masked) + 1); + }; + + getOperation()->walk([&](mlir::scf::ParallelOp loop) { + auto len = getVectorLength(loop); + if (!len) + return; + + std::optional best; + for (auto dim : llvm::seq(0u, loop.getNumLoops())) { + auto info = mlir::getLoopVectorizeInfo(loop, dim, *len); + if (!info) + continue; + + if (!best) { + best = *info; + continue; + } + + if (getBenefit(*info) > getBenefit(*best)) + best = *info; + } + + if (!best) + return; + + toVectorize.emplace_back( + loop, + mlir::SCFVectorizeParams{best->dim, best->factor, best->masked}); + }); + + if (toVectorize.empty()) + return markAllAnalysesPreserved(); + + mlir::OpBuilder builder(&getContext()); + for (auto &&[loop, params] : toVectorize) { + builder.setInsertionPoint(loop); + if (mlir::failed(mlir::vectorizeLoop(builder, loop, params))) + return signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr mlir::createSCFVectorizePass() { + return std::make_unique(); +} From c241fe2eb24f4c2dc21de412a68a371566d9bb17 Mon Sep 17 00:00:00 2001 From: max Date: Thu, 25 Apr 2024 10:57:05 -0500 Subject: [PATCH 02/10] get new stuff --- mlir/include/mlir/Transforms/SCFVectorize.h | 27 ++- mlir/lib/Transforms/SCFVectorize.cpp | 210 ++++++++++++++------ 2 files changed, 172 
insertions(+), 65 deletions(-)

diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Transforms/SCFVectorize.h
index d754b38d5bc23..93a7864b976ec 100644
--- a/mlir/include/mlir/Transforms/SCFVectorize.h
+++ b/mlir/include/mlir/Transforms/SCFVectorize.h
@@ -22,23 +22,48 @@ class ParallelOp;
 } // namespace mlir
 
 namespace mlir {
+
+/// Loop vectorization info
 struct SCFVectorizeInfo {
+  /// Loop dimension on which to vectorize.
   unsigned dim = 0;
+
+  /// Biggest vector width, in elements.
   unsigned factor = 0;
+
+  /// Number of ops which will be vectorized.
   unsigned count = 0;
+
+  /// Can use masked vector ops for out of bounds memory accesses.
   bool masked = false;
 };
 
+/// Collect vectorization statistics on specified `scf.parallel` dimension.
+/// Return `SCFVectorizeInfo` or `std::nullopt` if loop cannot be vectorized on
+/// specified dimension.
+///
+/// `vectorBitwidth` - maximum vector size, in bits.
 std::optional getLoopVectorizeInfo(mlir::scf::ParallelOp loop,
                                    unsigned dim,
-                                   unsigned vectorBitWidth);
+                                   unsigned vectorBitwidth);
 
+/// Vectorization params
 struct SCFVectorizeParams {
+  /// Loop dimension on which to vectorize.
   unsigned dim = 0;
+
+  /// Desired vector length, in elements.
   unsigned factor = 0;
+
+  /// Use masked vector ops for memory access outside loop bounds.
   bool masked = false;
 };
 
+/// Vectorize loop on specified dimension with specified factor.
+///
+/// If `masked` is `true` and loop bound is not divisible by `factor`, instead
+/// of generating a second loop to process remaining iterations, extend loop count
+/// and generate masked vector ops to handle out-of-bounds memory accesses.
 mlir::LogicalResult vectorizeLoop(mlir::OpBuilder &builder,
                                   mlir::scf::ParallelOp loop,
                                   const SCFVectorizeParams &params);
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
index d7545ee30e29a..13a9eca9cd2d3 100644
--- a/mlir/lib/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -16,7 +16,17 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Pass/Pass.h"
-
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+/// Return type bitwidth for vectorization purposes or 0 if type cannot be
+/// vectorized.
 static unsigned getTypeBitWidth(mlir::Type type) {
   if (mlir::isa(type))
     return 64; // TODO: unhardcode
@@ -46,6 +56,8 @@ static bool isSupportedVecElem(mlir::Type type) {
   return type.isIntOrIndexOrFloat();
 }
 
+/// Check if one `ValueRange` is permutation of another, i.e. contains same
+/// values, potentially in different order.
 static bool isRangePermutation(mlir::ValueRange val1, mlir::ValueRange val2) {
   if (val1.size() != val2.size())
     return false;
@@ -86,6 +98,10 @@ cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim,
   return width;
 }
 
+/// Check if memref load/store can be converted into vectorized load/store
+///
+/// Returns memref element bitwidth or `std::nullopt` if access cannot be
+/// vectorized.
 static std::optional
 cavTriviallyVectorizeMemOp(mlir::scf::ParallelOp loop, unsigned dim,
                            mlir::Operation &op) {
@@ -104,9 +120,10 @@ static bool isOp(mlir::Operation &op) {
   return mlir::isa(op);
 }
 
+/// Returns `vector.reduce` kind for specified `scf.parallel` reduce op or
+/// `std::nullopt` if reduction cannot be handled by `vector.reduce`.
 static std::optional
-getReductionKind(mlir::scf::ReduceOp op) {
-  mlir::Block &body = op.getReductionOperator().front();
+getReductionKind(mlir::Block &body) {
   if (!llvm::hasSingleElement(body.without_terminator()))
     return std::nullopt;
 
@@ -140,23 +157,31 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
   if (factor <= 1)
     return std::nullopt;
 
+  /// Only step==1 is supported for now.
   if (!mlir::isConstantIntValue(loop.getStep()[dim], 1))
     return std::nullopt;
 
   unsigned count = 0;
   bool masked = true;
 
-  for (mlir::Operation &op : loop.getBody()->without_terminator()) {
-    if (auto reduce = mlir::dyn_cast(op)) {
-      if (!getReductionKind(reduce))
-        masked = false;
+  /// Check if `scf.reduce` can be handled by `vector.reduce`.
+  /// If not, we can still vectorize the loop, but we cannot use masked
+  /// vectorization.
+  auto reduce =
+      mlir::cast(loop.getBody()->getTerminator());
+  for (mlir::Region &reg : reduce.getReductions()) {
+    if (!getReductionKind(reg.front()))
+      masked = false;
 
-      continue;
-    }
+    continue;
+  }
 
+  for (mlir::Operation &op : loop.getBody()->without_terminator()) {
+    /// Ops with nested regions are not supported yet.
     if (op.getNumRegions() > 0)
       return std::nullopt;
 
+    /// Check mem ops.
     if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op)) {
       auto newFactor = vectorBitwidth / *w;
       if (newFactor > 1) {
@@ -166,6 +191,8 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
       continue;
     }
 
+    /// If we meet an op which cannot be vectorized, we can replicate it and still
+    /// potentially vectorize other ops, but we cannot use masked vectorization.
     if (!isSupportedVectorOp(op)) {
       masked = false;
       continue;
@@ -184,12 +211,14 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
     ++count;
   }
 
+  /// No ops to vectorize.
   if (count == 0)
     return std::nullopt;
 
   return SCFVectorizeInfo{dim, factor, count, masked};
 }
 
+/// Get fastmath flags if the op supports them, or default (none).
 static mlir::arith::FastMathFlags getFMF(mlir::Operation &op) {
   if (auto fmf = mlir::dyn_cast(op))
     return fmf.getFastMathFlagsAttr().getValue();
@@ -226,6 +255,8 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
   mlir::Value count =
       builder.create(loc, origUpper, origLower);
   mlir::Value newCount;
+
+  // Compute new loop count, ceildiv if masked, floordiv otherwise.
   if (masked) {
     mlir::Value incCount =
         builder.create(loc, count, factorVal);
@@ -240,6 +271,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
   lower[dim] = zero;
   upper[dim] = newCount;
 
+  // Vectorized loop.
   auto newLoop = builder.create(loc, lower, upper, step,
                                 loop.getInitVals());
   auto newIndexVar = newLoop.getInductionVars()[dim];
@@ -256,10 +288,14 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     return builder.create(loc, vecType, nullptr);
   };
 
+  // Get vector value in new loop for provided `orig` value in source loop.
   auto getVecVal = [&](mlir::Value orig) -> mlir::Value {
+    // Use cached value if present.
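+    // E.g., with factor = 4, original lower bound %lb and new induction
+    // variable %i, the original induction variable maps to, roughly
+    // (the SSA names here are illustrative):
+    //   splat(%lb + %i * 4) + (0, 1, 2, 3)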
     if (auto mapped = mapping.lookupOrNull(orig))
       return mapped;
 
+    // Vectorized loop index, loop index is divided by factor, so for factor N
+    // vectorized index will look like `splat(idx) + (0, 1, ..., N - 1)`
     if (orig == origIndexVar) {
       auto vecType = toVectorType(builder.getIndexType());
       llvm::SmallVector elems(factor);
@@ -283,9 +319,16 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     mlir::Value val = orig;
     auto origIndexVars = loop.getInductionVars();
     auto it = llvm::find(origIndexVars, orig);
+
+    // If it is a loop index, but not on the vectorized dimension, just take
+    // the new loop index and splat it.
     if (it != origIndexVars.end())
       val = newLoop.getInductionVars()[it - origIndexVars.begin()];
 
+    // Values which are defined inside loop body are preemptively added to the
+    // mapper and not handled here. Values defined outside body are just
+    // splatted.
+
     auto vecType = toVectorType(type);
     mlir::Value vec = builder.create(loc, val, vecType);
     mapping.map(orig, vec);
@@ -293,18 +336,28 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
   };
 
   llvm::DenseMap> unpackedVals;
+
+  // Get unpacked values for provided `orig` value in source loop.
+  // Values are returned as `ValueRange` and not as vector value.
   auto getUnpackedVals = [&](mlir::Value val) -> mlir::ValueRange {
+    // Use cached values if present.
     auto it = unpackedVals.find(val);
     if (it != unpackedVals.end())
       return it->second;
 
+    // Values which are defined inside loop body are preemptively added to the
+    // cache and not handled here.
+
     auto &ret = unpackedVals[val];
     assert(ret.empty());
     if (!isSupportedVecElem(val.getType())) {
+      // Non-vectorizable value, it must be a value defined outside the loop,
+      // just replicate it.
       ret.resize(factor, val);
       return ret;
     }
 
+    // Get vector value and extract elements from it.
     auto vecVal = getVecVal(val);
     ret.resize(factor);
     for (auto i : llvm::seq(0u, factor)) {
@@ -314,6 +367,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     return ret;
   };
 
+  // Add unpacked values to the cache.
   auto setUnpackedVals = [&](mlir::Value origVal, mlir::ValueRange newVals) {
     assert(newVals.size() == factor);
     assert(unpackedVals.count(origVal) == 0);
@@ -323,6 +377,8 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     if (!isSupportedVecElem(type))
       return;
 
+    // If the type is vectorizable, construct a vector and add it to the
+    // vector cache as well.
     auto vecType = toVectorType(type);
 
     mlir::Value vec = createPosionVec(vecType);
@@ -335,6 +391,9 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
   };
 
   mlir::Value mask;
+
+  // Construct the mask value and cache it. If not in masked mode, the mask is
+  // always all 1s.
   auto getMask = [&]() -> mlir::Value {
     if (mask)
       return mask;
@@ -360,6 +419,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op);
   };
 
+  // Get indices for vectorized memref load/store.
   auto getMemrefVecIndices = [&](mlir::ValueRange indices) {
     scalarMapping.clear();
     scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars());
@@ -379,6 +439,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     return ret;
   };
 
+  // Check if memref access can be converted into gather/scatter.
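+  // E.g., a unit-rank access through a computed index (illustrative IR,
+  // roughly):
+  //   %x = memref.load %data[%idx] : memref<?xf32>
+  // becomes a vector.gather over the vectorized %idx when the access cannot
+  // be proven contiguous.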
   auto canGatherScatter = [&](auto op) {
     auto memref = op.getMemRef();
     auto memrefType = mlir::cast(memref.getType());
@@ -389,6 +450,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
            memrefType.getLayout().isIdentity();
   };
 
+  // Create vectorized memref load for specified non-vectorized load.
   auto genLoad = [&](auto loadOp) {
     auto indices = getMemrefVecIndices(loadOp.getIndices());
     auto resType = toVectorType(loadOp.getResult().getType());
@@ -406,6 +468,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     mapping.map(loadOp.getResult(), vecLoad);
   };
 
+  // Create vectorized memref store for specified non-vectorized store.
   auto genStore = [&](auto storeOp) {
     auto indices = getMemrefVecIndices(storeOp.getIndices());
     auto value = getVecVal(storeOp.getValueToStore());
@@ -426,6 +489,8 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
   for (mlir::Operation &op : loop.getBody()->without_terminator()) {
     loc = op.getLoc();
     if (isSupportedVectorOp(op)) {
+      // If op can be vectorized, clone it with vectorized inputs and update
+      // results to vectorized types.
       for (auto arg : op.getOperands())
         getVecVal(arg); // init mapper for op args
 
@@ -436,56 +501,8 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
       continue;
     }
 
-    if (auto reduceOp = mlir::dyn_cast(op)) {
-      scalarMapping.clear();
-      auto &reduceBody = reduceOp.getReductionOperator().front();
-      assert(reduceBody.getNumArguments() == 2);
-
-      mlir::Value reduceVal;
-      if (auto redKind = getReductionKind(reduceOp)) {
-        mlir::Value redArg = getVecVal(reduceOp.getOperand());
-        if (redArg) {
-          auto neutral = mlir::arith::getNeutralElement(&reduceBody.front());
-          assert(neutral);
-          mlir::Value neutralVal =
-              builder.create(loc, *neutral);
-          mlir::Value neutralVec = builder.create(
-              loc, neutralVal, redArg.getType());
-          auto mask = getMask();
-          redArg = builder.create(loc, mask, redArg,
-                                  neutralVec);
-        }
-
-        auto fmf = getFMF(reduceBody.front());
-        reduceVal = builder.create(loc, *redKind,
-                                   redArg, fmf);
-      } else {
-        if (masked)
-          return op.emitError("Cannot vectorize op in masked mode");
-
-        auto reduceTerm =
-            mlir::cast(reduceBody.getTerminator());
-        auto lhs = reduceBody.getArgument(0);
-        auto rhs = reduceBody.getArgument(1);
-        auto unpacked = getUnpackedVals(reduceOp.getOperand());
-        assert(unpacked.size() == factor);
-        reduceVal = unpacked.front();
-        for (auto i : llvm::seq(1u, factor)) {
-          mlir::Value val = unpacked[i];
-          scalarMapping.map(lhs, reduceVal);
-          scalarMapping.map(rhs, val);
-          for (auto &redOp : reduceBody.without_terminator())
-            builder.clone(redOp, scalarMapping);
-
-          reduceVal = scalarMapping.lookupOrDefault(reduceTerm.getResult());
-        }
-      }
-      scalarMapping.clear();
-      scalarMapping.map(reduceOp.getOperand(), reduceVal);
-      builder.clone(op, scalarMapping);
-      continue;
-    }
-
+    // Vectorize memref load/store ops, vector load/store are preferred over
+    // gather/scatter.
     if (auto loadOp = mlir::dyn_cast(op)) {
       if (canTriviallyVectorizeMemOp(loadOp)) {
         genLoad(loadOp);
@@ -558,6 +575,70 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     }
   }
 
+  // Vectorize `scf.reduce` op.
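+  // E.g., an addition reduction (illustrative, roughly):
+  //   scf.reduce(%arg) { ^bb0(%lhs, %rhs): arith.addf %lhs, %rhs ... }
+  // becomes a lane-wise reduction of the vectorized operand,
+  //   %r = vector.reduction <add>, %vecArg : vector<4xf32> into f32
+  // followed by a scalar scf.reduce over %r to combine loop iterations.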
+ auto reduceOp = + mlir::cast(loop.getBody()->getTerminator()); + llvm::SmallVector reduceVals; + reduceVals.reserve(reduceOp.getNumOperands()); + + for (auto &&[body, arg] : + llvm::zip(reduceOp.getReductions(), reduceOp.getOperands())) { + scalarMapping.clear(); + mlir::Block &reduceBody = body.front(); + assert(reduceBody.getNumArguments() == 2); + + mlir::Value reduceVal; + if (auto redKind = getReductionKind(reduceBody)) { + // Generate `vector.reduce` if possible. + mlir::Value redArg = getVecVal(arg); + if (redArg) { + auto neutral = mlir::arith::getNeutralElement(&reduceBody.front()); + assert(neutral); + mlir::Value neutralVal = + builder.create(loc, *neutral); + mlir::Value neutralVec = builder.create( + loc, neutralVal, redArg.getType()); + auto mask = getMask(); + redArg = builder.create(loc, mask, redArg, + neutralVec); + } + + auto fmf = getFMF(reduceBody.front()); + reduceVal = + builder.create(loc, *redKind, redArg, fmf); + } else { + if (masked) + return reduceOp.emitError("Cannot vectorize reduce op in masked mode"); + + // If `vector.reduce` cannot be used, unpack values and reduce them + // individually. + + auto reduceTerm = + mlir::cast(reduceBody.getTerminator()); + auto lhs = reduceBody.getArgument(0); + auto rhs = reduceBody.getArgument(1); + auto unpacked = getUnpackedVals(arg); + assert(unpacked.size() == factor); + reduceVal = unpacked.front(); + for (auto i : llvm::seq(1u, factor)) { + mlir::Value val = unpacked[i]; + scalarMapping.map(lhs, reduceVal); + scalarMapping.map(rhs, val); + for (auto &redOp : reduceBody.without_terminator()) + builder.clone(redOp, scalarMapping); + + reduceVal = scalarMapping.lookupOrDefault(reduceTerm.getResult()); + } + } + reduceVals.emplace_back(reduceVal); + } + + // Clone `scf.reduce` op to reduce across loop iterations. + if (!reduceVals.empty()) + builder.clone(*reduceOp)->setOperands(reduceVals); + + // If in masked mode remove old loop, otherwise update loop bounds to + // repurpose it for handling remaining values. if (masked) { loop->replaceAllUsesWith(newLoop.getResults()); loop->erase(); @@ -576,14 +657,12 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop, return mlir::success(); } -llvm::StringRef getVectorLengthName() { return "numba.vector_length"; } - static std::optional getVectorLength(mlir::Operation *op) { auto func = op->getParentOfType(); if (!func) return std::nullopt; - auto attr = func->getAttrOfType(getVectorLengthName()); + auto attr = func->getAttrOfType("mlir.vector_length"); if (!attr) return std::nullopt; @@ -599,7 +678,8 @@ struct SCFVectorizePass : public mlir::PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFVectorizePass) - void getDependentDialects(mlir::DialectRegistry ®istry) const override { + virtual void + getDependentDialects(mlir::DialectRegistry ®istry) const override { registry.insert(); registry.insert(); registry.insert(); @@ -611,6 +691,8 @@ struct SCFVectorizePass std::pair> toVectorize; + // Simple heuristic: total number of elements processed by vector ops, but + // prefer masked mode over non-masked. 
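+    // E.g., factor = 4 and count = 3 gives benefit 12 for a non-masked
+    // candidate and 24 for a masked one.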
     auto getBenefit = [](const mlir::SCFVectorizeInfo &info) {
       return info.factor * info.count * (int(info.masked) + 1);
     };
@@ -658,4 +740,4 @@ struct SCFVectorizePass
 
 std::unique_ptr mlir::createSCFVectorizePass() {
   return std::make_unique();
-}
+}
\ No newline at end of file

From c9a4bfe563013cfc64f0c44643c2c8e97b48757f Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Sun, 2 Jun 2024 17:07:51 +0200
Subject: [PATCH 03/10] working on pass

---
 mlir/include/mlir/Transforms/SCFVectorize.h   |  20 +-
 mlir/lib/Transforms/SCFVectorize.cpp          | 449 +++++++-----------
 mlir/test/Transforms/test-scf-vectorize.mlir  | 272 +++++++++++
 mlir/test/lib/Transforms/CMakeLists.txt       |   1 +
 mlir/test/lib/Transforms/TestSCFVectorize.cpp | 110 +++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 6 files changed, 570 insertions(+), 284 deletions(-)
 create mode 100644 mlir/test/Transforms/test-scf-vectorize.mlir
 create mode 100644 mlir/test/lib/Transforms/TestSCFVectorize.cpp

diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Transforms/SCFVectorize.h
index 93a7864b976ec..d2a5e3085ae37 100644
--- a/mlir/include/mlir/Transforms/SCFVectorize.h
+++ b/mlir/include/mlir/Transforms/SCFVectorize.h
@@ -9,12 +9,10 @@
 #ifndef MLIR_TRANSFORMS_SCFVECTORIZE_H_
 #define MLIR_TRANSFORMS_SCFVECTORIZE_H_
 
-#include
 #include
 
 namespace mlir {
-class OpBuilder;
-class Pass;
+class DataLayout;
 struct LogicalResult;
 namespace scf {
 class ParallelOp;
@@ -43,9 +41,9 @@ struct SCFVectorizeInfo {
 /// specified dimension.
 ///
 /// `vectorBitwidth` - maximum vector size, in bits.
-std::optional getLoopVectorizeInfo(mlir::scf::ParallelOp loop,
-                                   unsigned dim,
-                                   unsigned vectorBitwidth);
+std::optional
+getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
+                     unsigned vectorBitwidth, const DataLayout *DL = nullptr);
 
 /// Vectorization params
 struct SCFVectorizeParams {
@@ -64,11 +62,9 @@ struct SCFVectorizeParams {
 /// If `masked` is `true` and loop bound is not divisible by `factor`, instead
 /// of generating a second loop to process remaining iterations, extend loop count
 /// and generate masked vector ops to handle out-of-bounds memory accesses.
-mlir::LogicalResult vectorizeLoop(mlir::OpBuilder &builder,
-                                  mlir::scf::ParallelOp loop,
-                                  const SCFVectorizeParams &params);
-
-std::unique_ptr createSCFVectorizePass();
+mlir::LogicalResult vectorizeLoop(mlir::scf::ParallelOp loop,
                                  const SCFVectorizeParams &params,
+                                  const DataLayout *DL = nullptr);
 } // namespace mlir
 
-#endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_
\ No newline at end of file
+#endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_
diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp
index 13a9eca9cd2d3..29e184e584a56 100644
--- a/mlir/lib/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Transforms/SCFVectorize.cpp
@@ -1,4 +1,4 @@
-//===- ControlFlowSink.cpp - Code to perform control-flow sinking ---------===//
+//===- SCFVectorize.cpp - SCF vectorization utilities ---------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -9,27 +9,25 @@ #include "mlir/Transforms/SCFVectorize.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // getCombinerOpKind #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/IRMapping.h" -#include "mlir/Interfaces/FunctionInterfaces.h" -#include "mlir/Pass/Pass.h" -#include -#include -#include -#include -#include -#include -#include -#include + +using namespace mlir; + +static bool isSupportedVecElem(Type type) { return type.isIntOrIndexOrFloat(); } /// Return type bitwidth for vectorization purposes or 0 if type cannot be /// vectorized. -static unsigned getTypeBitWidth(mlir::Type type) { - if (mlir::isa(type)) - return 64; // TODO: unhardcode +static unsigned getTypeBitWidth(Type type, const DataLayout *DL) { + if (!isSupportedVecElem(type)) + return 0; + + if (DL) + return DL->getTypeSizeInBits(type); if (type.isIntOrFloat()) return type.getIntOrFloatBitWidth(); @@ -37,28 +35,24 @@ static unsigned getTypeBitWidth(mlir::Type type) { return 0; } -static unsigned getArgsTypeWidth(mlir::Operation &op) { +static unsigned getArgsTypeWidth(Operation &op, const DataLayout *DL) { unsigned ret = 0; for (auto arg : op.getOperands()) - ret = std::max(ret, getTypeBitWidth(arg.getType())); + ret = std::max(ret, getTypeBitWidth(arg.getType(), DL)); for (auto res : op.getResults()) - ret = std::max(ret, getTypeBitWidth(res.getType())); + ret = std::max(ret, getTypeBitWidth(res.getType(), DL)); return ret; } -static bool isSupportedVectorOp(mlir::Operation &op) { - return op.hasTrait(); -} - -static bool isSupportedVecElem(mlir::Type type) { - return type.isIntOrIndexOrFloat(); +static bool isSupportedVectorOp(Operation &op) { + return op.hasTrait(); } /// Check if one `ValueRange` is permutation of another, i.e. contains same /// values, potentially in different order. -static bool isRangePermutation(mlir::ValueRange val1, mlir::ValueRange val2) { +static bool isRangePermutation(ValueRange val1, ValueRange val2) { if (val1.size() != val2.size()) return false; @@ -72,13 +66,13 @@ static bool isRangePermutation(mlir::ValueRange val1, mlir::ValueRange val2) { template static std::optional -cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim, - Op memOp) { +cavTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp, + const DataLayout *DL) { auto loopIndexVars = loop.getInductionVars(); assert(dim < loopIndexVars.size()); auto memref = memOp.getMemRef(); - auto type = mlir::cast(memref.getType()); - auto width = getTypeBitWidth(type.getElementType()); + auto type = cast(memref.getType()); + auto width = getTypeBitWidth(type.getElementType(), DL); if (width == 0) return std::nullopt; @@ -91,7 +85,7 @@ cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim, if (memOp.getIndices().back() != loopIndexVars[dim]) return std::nullopt; - mlir::DominanceInfo dom; + DominanceInfo dom; if (!dom.properlyDominates(memref, loop)) return std::nullopt; @@ -103,54 +97,69 @@ cavTriviallyVectorizeMemOpImpl(mlir::scf::ParallelOp loop, unsigned dim, /// Returns memref element bitwidth or `std::nullopt` if access cannot be /// vectorized. 
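 ///
 /// E.g., when vectorizing dim 1 of an `scf.parallel` with induction vars
 /// (%i, %j), an access like `memref.load %m[%i, %j]` is contiguous in %j and
 /// can use plain vector load/store ops (names here are illustrative).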
 static std::optional
-cavTriviallyVectorizeMemOp(mlir::scf::ParallelOp loop, unsigned dim,
-                           mlir::Operation &op) {
+cavTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
+                           const DataLayout *DL) {
   assert(dim < loop.getInductionVars().size());
-  if (auto storeOp = mlir::dyn_cast(op))
-    return cavTriviallyVectorizeMemOpImpl(loop, dim, storeOp);
+  if (auto storeOp = dyn_cast(op))
+    return cavTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL);
+
+  if (auto loadOp = dyn_cast(op))
+    return cavTriviallyVectorizeMemOpImpl(loop, dim, loadOp, DL);
+
+  return std::nullopt;
+}
+
+template
+static std::optional canGatherScatterImpl(scf::ParallelOp loop, Op op,
+                                          const DataLayout *DL) {
+  auto memref = op.getMemRef();
+  auto memrefType = cast(memref.getType());
+  auto width = getTypeBitWidth(memrefType.getElementType(), DL);
+  if (width == 0)
+    return std::nullopt;
+
+  DominanceInfo dom;
+  return dom.properlyDominates(memref, loop) && op.getIndices().size() == 1 &&
+         memrefType.getLayout().isIdentity();
+}
 
-  if (auto loadOp = mlir::dyn_cast(op))
-    return cavTriviallyVectorizeMemOpImpl(loop, dim, loadOp);
+// Check if memref access can be converted into gather/scatter.
+///
+/// Returns memref element bitwidth or `std::nullopt` if access cannot be
+/// vectorized.
+static std::optional
+canGatherScatter(scf::ParallelOp loop, Operation &op, const DataLayout *DL) {
+  if (auto storeOp = dyn_cast(op))
+    return canGatherScatterImpl(loop, storeOp, DL);
+
+  if (auto loadOp = dyn_cast(op))
+    return canGatherScatterImpl(loop, loadOp, DL);
 
   return std::nullopt;
 }
 
-template
-static bool isOp(mlir::Operation &op) {
-  return mlir::isa(op);
+static std::optional cenVectorizeMemrefOp(scf::ParallelOp loop,
+                                          unsigned dim, Operation &op,
+                                          const DataLayout *DL) {
+  if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op, DL))
+    return w;
+
+  return canGatherScatter(loop, op, DL);
 }
 
 /// Returns `vector.reduce` kind for specified `scf.parallel` reduce op or
 /// `std::nullopt` if reduction cannot be handled by `vector.reduce`.
-static std::optional
-getReductionKind(mlir::Block &body) {
+static std::optional getReductionKind(Block &body) {
   if (!llvm::hasSingleElement(body.without_terminator()))
     return std::nullopt;
 
-  mlir::Operation &redOp = body.front();
-
-  using fptr_t = bool (*)(mlir::Operation &);
-  using CC = mlir::vector::CombiningKind;
-  const std::pair handlers[] = {
-      // clang-format off
-      {&isOp, CC::ADD},
-      {&isOp, CC::ADD},
-      {&isOp, CC::MUL},
-      {&isOp, CC::MUL},
-      // clang-format on
-  };
-
-  for (auto &&[handler, cc] : handlers) {
-    if (handler(redOp))
-      return cc;
-  }
-
-  return std::nullopt;
+  // TODO: Move getCombinerOpKind to vector dialect.
+  return linalg::getCombinerOpKind(&body.front());
 }
 
-std::optional
-mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
-                           unsigned vectorBitwidth) {
+std::optional
+mlir::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
+                           unsigned vectorBitwidth, const DataLayout *DL) {
   assert(dim < loop.getStep().size());
   assert(vectorBitwidth > 0);
   unsigned factor = vectorBitwidth / 8;
@@ -158,7 +167,7 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
     return std::nullopt;
 
   /// Only step==1 is supported for now.
-  if (!mlir::isConstantIntValue(loop.getStep()[dim], 1))
+  if (!isConstantIntValue(loop.getStep()[dim], 1))
     return std::nullopt;
 
   unsigned count = 0;
@@ -167,22 +176,21 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
 
   /// Check if `scf.reduce` can be handled by `vector.reduce`.
   /// If not, we can still vectorize the loop, but we cannot use masked
   /// vectorization.
-  auto reduce =
-      mlir::cast(loop.getBody()->getTerminator());
-  for (mlir::Region &reg : reduce.getReductions()) {
+  auto reduce = cast(loop.getBody()->getTerminator());
+  for (Region &reg : reduce.getReductions()) {
     if (!getReductionKind(reg.front()))
       masked = false;
 
     continue;
   }
 
-  for (mlir::Operation &op : loop.getBody()->without_terminator()) {
+  for (Operation &op : loop.getBody()->without_terminator()) {
     /// Ops with nested regions are not supported yet.
     if (op.getNumRegions() > 0)
       return std::nullopt;
 
     /// Check mem ops.
-    if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op)) {
+    if (auto w = cenVectorizeMemrefOp(loop, dim, op, DL)) {
       auto newFactor = vectorBitwidth / *w;
       if (newFactor > 1) {
         factor = std::min(factor, newFactor);
@@ -198,7 +206,7 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
       continue;
     }
 
-    auto width = getArgsTypeWidth(op);
+    auto width = getArgsTypeWidth(op, DL);
     if (width == 0)
       return std::nullopt;
 
@@ -219,26 +227,24 @@ mlir::getLoopVectorizeInfo(mlir::scf::ParallelOp loop, unsigned dim,
 }
 
 /// Get fastmath flags if the op supports them, or default (none).
-static mlir::arith::FastMathFlags getFMF(mlir::Operation &op) {
-  if (auto fmf = mlir::dyn_cast(op))
+static arith::FastMathFlags getFMF(Operation &op) {
+  if (auto fmf = dyn_cast(op))
     return fmf.getFastMathFlagsAttr().getValue();
 
-  return mlir::arith::FastMathFlags::none;
+  return arith::FastMathFlags::none;
 }
 
-mlir::LogicalResult
-mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
-                    const mlir::SCFVectorizeParams &params) {
+LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop,
+                                  const SCFVectorizeParams &params,
+                                  const DataLayout *DL) {
   auto dim = params.dim;
   auto factor = params.factor;
   auto masked = params.masked;
   assert(dim < loop.getStep().size());
   assert(factor > 1);
-  assert(mlir::isConstantIntValue(loop.getStep()[dim], 1));
-
-  mlir::OpBuilder::InsertionGuard g(builder);
-  builder.setInsertionPoint(loop);
+  assert(isConstantIntValue(loop.getStep()[dim], 1));
 
+  OpBuilder builder(loop);
   auto lower = llvm::to_vector(loop.getLowerBound());
   auto upper = llvm::to_vector(loop.getUpperBound());
   auto step = llvm::to_vector(loop.getStep());
@@ -247,49 +253,53 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
 
   auto origIndexVar = loop.getInductionVars()[dim];
 
-  mlir::Value factorVal =
-      builder.create(loc, factor);
+  Value factorVal = builder.create(loc, factor);
 
   auto origLower = lower[dim];
   auto origUpper = upper[dim];
-  mlir::Value count =
-      builder.create(loc, origUpper, origLower);
-  mlir::Value newCount;
+  Value count = builder.createOrFold(loc, origUpper, origLower);
+  Value newCount;
 
   // Compute new loop count, ceildiv if masked, floordiv otherwise.
   if (masked) {
-    mlir::Value incCount =
-        builder.create(loc, count, factorVal);
-    mlir::Value one = builder.create(loc, 1);
-    incCount = builder.create(loc, incCount, one);
-    newCount = builder.create(loc, incCount, factorVal);
+    newCount = builder.createOrFold(loc, count, factorVal);
   } else {
-    newCount = builder.create(loc, count, factorVal);
+    newCount = builder.createOrFold(loc, count, factorVal);
   }
 
-  mlir::Value zero = builder.create(loc, 0);
+  Value zero = builder.create(loc, 0);
   lower[dim] = zero;
   upper[dim] = newCount;
 
   // Vectorized loop.
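   // E.g., with ub - lb = 10 and factor = 4: masked mode runs 3 vector
   // iterations (12 lanes, the last 2 masked off), while non-masked mode runs
   // 2 vector iterations and leaves 2 iterations for the scalar remainder.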
-  auto newLoop = builder.create(loc, lower, upper, step,
-                                loop.getInitVals());
+  auto newLoop = builder.create(loc, lower, upper, step,
+                                loop.getInitVals());
   auto newIndexVar = newLoop.getInductionVars()[dim];
 
-  auto toVectorType = [&](mlir::Type elemType) -> mlir::VectorType {
+  auto toVectorType = [&](Type elemType) -> VectorType {
     int64_t f = factor;
-    return mlir::VectorType::get(f, elemType);
+    return VectorType::get(f, elemType);
+  };
+
+  IRMapping mapping;
+  IRMapping scalarMapping;
+
+  auto createPosionVec = [&](VectorType vecType) -> Value {
+    return builder.create(loc, vecType, nullptr);
   };
 
-  mlir::IRMapping mapping;
-  mlir::IRMapping scalarMapping;
+  Value indexVarMult;
+  auto getrIndexVarMult = [&]() -> Value {
+    if (indexVarMult)
+      return indexVarMult;
 
-  auto createPosionVec = [&](mlir::VectorType vecType) -> mlir::Value {
-    return builder.create(loc, vecType, nullptr);
+    indexVarMult =
+        builder.createOrFold(loc, newIndexVar, factorVal);
+    return indexVarMult;
   };
 
   // Get vector value in new loop for provided `orig` value in source loop.
-  auto getVecVal = [&](mlir::Value orig) -> mlir::Value {
+  auto getVecVal = [&](Value orig) -> Value {
     // Use cached value if present.
     if (auto mapped = mapping.lookupOrNull(orig))
       return mapped;
@@ -298,25 +308,23 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     // vectorized index will look like `splat(idx) + (0, 1, ..., N - 1)`
     if (orig == origIndexVar) {
       auto vecType = toVectorType(builder.getIndexType());
-      llvm::SmallVector elems(factor);
+      llvm::SmallVector elems(factor);
       for (auto i : llvm::seq(0u, factor))
         elems[i] = builder.getIndexAttr(i);
-      auto attr = mlir::DenseElementsAttr::get(vecType, elems);
-      mlir::Value vec =
-          builder.create(loc, vecType, attr);
-
-      mlir::Value idx =
-          builder.create(loc, newIndexVar, factorVal);
-      idx = builder.create(loc, idx, origLower);
-      idx = builder.create(loc, idx, vecType);
-      vec = builder.create(loc, idx, vec);
+      auto attr = DenseElementsAttr::get(vecType, elems);
+      Value vec = builder.create(loc, vecType, attr);
+
+      Value idx = getrIndexVarMult();
+      idx = builder.createOrFold(loc, idx, origLower);
+      idx = builder.create(loc, idx, vecType);
+      vec = builder.createOrFold(loc, idx, vec);
       mapping.map(orig, vec);
       return vec;
     }
     auto type = orig.getType();
     assert(isSupportedVecElem(type));
 
-    mlir::Value val = orig;
+    Value val = orig;
     auto origIndexVars = loop.getInductionVars();
     auto it = llvm::find(origIndexVars, orig);
@@ -330,16 +338,16 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     // splatted.
 
     auto vecType = toVectorType(type);
-    mlir::Value vec = builder.create(loc, val, vecType);
+    Value vec = builder.create(loc, val, vecType);
     mapping.map(orig, vec);
     return vec;
   };
 
-  llvm::DenseMap> unpackedVals;
+  llvm::DenseMap> unpackedVals;
 
   // Get unpacked values for provided `orig` value in source loop.
   // Values are returned as `ValueRange` and not as vector value.
-  auto getUnpackedVals = [&](mlir::Value val) -> mlir::ValueRange {
+  auto getUnpackedVals = [&](Value val) -> ValueRange {
     // Use cached values if present.
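     // E.g., with factor = 4, a vector value unpacks into 4 scalars via
     // vector.extractelement; values that cannot be vectorized are simply
     // replicated 4 times (illustrative).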
     auto it = unpackedVals.find(val);
     if (it != unpackedVals.end())
       return it->second;
@@ -361,14 +369,14 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     auto vecVal = getVecVal(val);
     ret.resize(factor);
     for (auto i : llvm::seq(0u, factor)) {
-      mlir::Value idx = builder.create(loc, i);
-      ret[i] = builder.create(loc, vecVal, idx);
+      Value idx = builder.create(loc, i);
+      ret[i] = builder.create(loc, vecVal, idx);
     }
     return ret;
   };
 
   // Add unpacked values to the cache.
-  auto setUnpackedVals = [&](mlir::Value origVal, mlir::ValueRange newVals) {
+  auto setUnpackedVals = [&](Value origVal, ValueRange newVals) {
     assert(newVals.size() == factor);
     assert(unpackedVals.count(origVal) == 0);
@@ -381,55 +389,53 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     if (!isSupportedVecElem(type))
       return;
 
     // If the type is vectorizable, construct a vector and add it to the
     // vector cache as well.
     auto vecType = toVectorType(type);
 
-    mlir::Value vec = createPosionVec(vecType);
+    Value vec = createPosionVec(vecType);
     for (auto i : llvm::seq(0u, factor)) {
-      mlir::Value idx = builder.create(loc, i);
-      vec = builder.create(loc, newVals[i], vec,
-                           idx);
+      Value idx = builder.create(loc, i);
+      vec = builder.create(loc, newVals[i], vec, idx);
     }
     mapping.map(origVal, vec);
   };
 
-  mlir::Value mask;
+  Value mask;
 
   // Construct the mask value and cache it. If not in masked mode, the mask is
   // always all 1s.
-  auto getMask = [&]() -> mlir::Value {
+  auto getMask = [&]() -> Value {
     if (mask)
       return mask;
 
-    mlir::OpFoldResult maskSize;
+    OpFoldResult maskSize;
    if (masked) {
-      mlir::Value size =
-          builder.create(loc, factorVal, newIndexVar);
-      maskSize =
-          builder.create(loc, count, size).getResult();
+      Value size = getrIndexVarMult();
+      maskSize = builder.createOrFold(loc, count, size);
     } else {
       maskSize = builder.getIndexAttr(factor);
     }
     auto vecType = toVectorType(builder.getI1Type());
-    mask = builder.create(loc, vecType, maskSize);
+    mask = builder.create(loc, vecType, maskSize);
 
     return mask;
   };
 
-  mlir::DominanceInfo dom;
-
   auto canTriviallyVectorizeMemOp = [&](auto op) -> bool {
-    return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op);
+    return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op, DL);
+  };
+
+  auto canGatherScatter = [&](auto op) {
+    return !!::canGatherScatterImpl(loop, op, DL);
   };
 
   // Get indices for vectorized memref load/store.
-  auto getMemrefVecIndices = [&](mlir::ValueRange indices) {
+  auto getMemrefVecIndices = [&](ValueRange indices) {
     scalarMapping.clear();
     scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars());
 
-    llvm::SmallVector ret(indices.size());
+    llvm::SmallVector ret(indices.size());
     for (auto &&[i, val] : llvm::enumerate(indices)) {
       if (val == origIndexVar) {
-        mlir::Value idx =
-            builder.create(loc, newIndexVar, factorVal);
-        idx = builder.create(loc, idx, origLower);
+        Value idx = getrIndexVarMult();
+        idx = builder.createOrFold(loc, idx, origLower);
         ret[i] = idx;
         continue;
       }
       ret[i] = scalarMapping.lookup(val);
     }
 
     return ret;
   };
 
-  // Check if memref access can be converted into gather/scatter.
-  auto canGatherScatter = [&](auto op) {
-    auto memref = op.getMemRef();
-    auto memrefType = mlir::cast(memref.getType());
-    if (!isSupportedVecElem(memrefType.getElementType()))
-      return false;
-
-    return dom.properlyDominates(memref, loop) && op.getIndices().size() == 1 &&
-           memrefType.getLayout().isIdentity();
-  };
-
   // Create vectorized memref load for specified non-vectorized load.
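   // E.g., a unit-stride access on the vectorized dimension (illustrative):
   //   %x = memref.load %m[%i] : memref<?xf32>
   // becomes roughly
   //   %v = vector.load %m[%base] : memref<?xf32>, vector<4xf32>
   // or a vector.maskedload when the loop tail is handled with a mask.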
   auto genLoad = [&](auto loadOp) {
     auto indices = getMemrefVecIndices(loadOp.getIndices());
     auto resType = toVectorType(loadOp.getResult().getType());
     auto memref = loadOp.getMemRef();
-    mlir::Value vecLoad;
+    Value vecLoad;
     if (masked) {
       auto mask = getMask();
       auto init = createPosionVec(resType);
-      vecLoad = builder.create(loc, resType, memref,
-                               indices, mask, init);
+      vecLoad = builder.create(loc, resType, memref,
+                               indices, mask, init);
     } else {
-      vecLoad =
-          builder.create(loc, resType, memref, indices);
+      vecLoad = builder.create(loc, resType, memref, indices);
     }
     mapping.map(loadOp.getResult(), vecLoad);
   };
@@ -475,18 +469,17 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     auto memref = storeOp.getMemRef();
     if (masked) {
       auto mask = getMask();
-      builder.create(loc, memref, indices, mask,
-                     value);
+      builder.create(loc, memref, indices, mask, value);
     } else {
-      builder.create(loc, value, memref, indices);
+      builder.create(loc, value, memref, indices);
     }
   };
 
-  llvm::SmallVector duplicatedArgs;
-  llvm::SmallVector duplicatedResults;
+  llvm::SmallVector duplicatedArgs;
+  llvm::SmallVector duplicatedResults;
 
   builder.setInsertionPointToStart(newLoop.getBody());
-  for (mlir::Operation &op : loop.getBody()->without_terminator()) {
+  for (Operation &op : loop.getBody()->without_terminator()) {
     loc = op.getLoc();
     if (isSupportedVectorOp(op)) {
       // If op can be vectorized, clone it with vectorized inputs and update
       // results to vectorized types.
@@ -503,7 +496,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
 
     // Vectorize memref load/store ops, vector load/store are preferred over
     // gather/scatter.
-    if (auto loadOp = mlir::dyn_cast(op)) {
+    if (auto loadOp = dyn_cast(op)) {
       if (canTriviallyVectorizeMemOp(loadOp)) {
         genLoad(loadOp);
         continue;
@@ -515,14 +508,14 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
         auto indexVec = getVecVal(loadOp.getIndices()[0]);
         auto init = createPosionVec(resType);
 
-        auto gather = builder.create(
+        auto gather = builder.create(
             loc, resType, memref, zero, indexVec, mask, init);
         mapping.map(loadOp.getResult(), gather.getResult());
         continue;
       }
     }
 
-    if (auto storeOp = mlir::dyn_cast(op)) {
+    if (auto storeOp = dyn_cast(op)) {
       if (canTriviallyVectorizeMemOp(storeOp)) {
         genStore(storeOp);
         continue;
@@ -533,8 +526,9 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
         auto mask = getMask();
         auto indexVec = getVecVal(storeOp.getIndices()[0]);
 
-        builder.create(loc, memref, zero, indexVec,
-                       mask, value);
+        builder.create(loc, memref, zero, indexVec, mask,
+                       value);
+        continue;
       }
     }
 
@@ -557,7 +551,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     }
 
     for (auto i : llvm::seq(0u, factor)) {
-      auto args = mlir::ValueRange(duplicatedArgs)
+      auto args = ValueRange(duplicatedArgs)
                       .drop_front(numArgs * i)
                       .take_front(numArgs);
       scalarMapping.map(op.getOperands(), args);
@@ -568,7 +562,7 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
     }
 
     for (auto i : llvm::seq(0u, numResults)) {
-      auto results = mlir::ValueRange(duplicatedResults)
+      auto results = ValueRange(duplicatedResults)
                          .drop_front(factor * i)
                          .take_front(factor);
       setUnpackedVals(op.getResult(i), results);
@@ -576,36 +570,33 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop,
   }
 
   // Vectorize `scf.reduce` op.
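   // E.g., for an add reduction with factor = 4 and mask (1, 1, 1, 0), the
   // masked-off lane is replaced with the neutral element 0 via arith.select
   // before vector.reduction, so it does not affect the result (sketch).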
- auto reduceOp = - mlir::cast(loop.getBody()->getTerminator()); - llvm::SmallVector reduceVals; + auto reduceOp = cast(loop.getBody()->getTerminator()); + llvm::SmallVector reduceVals; reduceVals.reserve(reduceOp.getNumOperands()); for (auto &&[body, arg] : llvm::zip(reduceOp.getReductions(), reduceOp.getOperands())) { scalarMapping.clear(); - mlir::Block &reduceBody = body.front(); + Block &reduceBody = body.front(); assert(reduceBody.getNumArguments() == 2); - mlir::Value reduceVal; + Value reduceVal; if (auto redKind = getReductionKind(reduceBody)) { // Generate `vector.reduce` if possible. - mlir::Value redArg = getVecVal(arg); + Value redArg = getVecVal(arg); if (redArg) { - auto neutral = mlir::arith::getNeutralElement(&reduceBody.front()); + auto neutral = arith::getNeutralElement(&reduceBody.front()); assert(neutral); - mlir::Value neutralVal = - builder.create(loc, *neutral); - mlir::Value neutralVec = builder.create( - loc, neutralVal, redArg.getType()); + Value neutralVal = builder.create(loc, *neutral); + Value neutralVec = + builder.create(loc, neutralVal, redArg.getType()); auto mask = getMask(); - redArg = builder.create(loc, mask, redArg, - neutralVec); + redArg = builder.create(loc, mask, redArg, neutralVec); } auto fmf = getFMF(reduceBody.front()); reduceVal = - builder.create(loc, *redKind, redArg, fmf); + builder.create(loc, *redKind, redArg, fmf); } else { if (masked) return reduceOp.emitError("Cannot vectorize reduce op in masked mode"); @@ -613,15 +604,14 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop, // If `vector.reduce` cannot be used, unpack values and reduce them // individually. - auto reduceTerm = - mlir::cast(reduceBody.getTerminator()); + auto reduceTerm = cast(reduceBody.getTerminator()); auto lhs = reduceBody.getArgument(0); auto rhs = reduceBody.getArgument(1); auto unpacked = getUnpackedVals(arg); assert(unpacked.size() == factor); reduceVal = unpacked.front(); for (auto i : llvm::seq(1u, factor)) { - mlir::Value val = unpacked[i]; + Value val = unpacked[i]; scalarMapping.map(lhs, reduceVal); scalarMapping.map(rhs, val); for (auto &redOp : reduceBody.without_terminator()) @@ -644,9 +634,9 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop, loop->erase(); } else { builder.setInsertionPoint(loop); - mlir::Value newLower = - builder.create(loc, newCount, factorVal); - newLower = builder.create(loc, origLower, newLower); + Value newLower = + builder.createOrFold(loc, newCount, factorVal); + newLower = builder.createOrFold(loc, origLower, newLower); auto lowerCopy = llvm::to_vector(loop.getLowerBound()); lowerCopy[dim] = newLower; @@ -654,90 +644,5 @@ mlir::vectorizeLoop(mlir::OpBuilder &builder, mlir::scf::ParallelOp loop, loop.getInitValsMutable().assign(newLoop.getResults()); } - return mlir::success(); -} - -static std::optional getVectorLength(mlir::Operation *op) { - auto func = op->getParentOfType(); - if (!func) - return std::nullopt; - - auto attr = func->getAttrOfType("mlir.vector_length"); - if (!attr) - return std::nullopt; - - auto val = attr.getInt(); - if (val <= 0 || val > std::numeric_limits::max()) - return std::nullopt; - - return static_cast(val); + return success(); } - -namespace { -struct SCFVectorizePass - : public mlir::PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SCFVectorizePass) - - virtual void - getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - } - - void 
runOnOperation() override { - llvm::SmallVector< - std::pair> - toVectorize; - - // Simple heuristic: total number of elements processed by vector ops, but - // prefer masked mode over non-masked. - auto getBenefit = [](const mlir::SCFVectorizeInfo &info) { - return info.factor * info.count * (int(info.masked) + 1); - }; - - getOperation()->walk([&](mlir::scf::ParallelOp loop) { - auto len = getVectorLength(loop); - if (!len) - return; - - std::optional best; - for (auto dim : llvm::seq(0u, loop.getNumLoops())) { - auto info = mlir::getLoopVectorizeInfo(loop, dim, *len); - if (!info) - continue; - - if (!best) { - best = *info; - continue; - } - - if (getBenefit(*info) > getBenefit(*best)) - best = *info; - } - - if (!best) - return; - - toVectorize.emplace_back( - loop, - mlir::SCFVectorizeParams{best->dim, best->factor, best->masked}); - }); - - if (toVectorize.empty()) - return markAllAnalysesPreserved(); - - mlir::OpBuilder builder(&getContext()); - for (auto &&[loop, params] : toVectorize) { - builder.setInsertionPoint(loop); - if (mlir::failed(mlir::vectorizeLoop(builder, loop, params))) - return signalPassFailure(); - } - } -}; -} // namespace - -std::unique_ptr mlir::createSCFVectorizePass() { - return std::make_unique(); -} \ No newline at end of file diff --git a/mlir/test/Transforms/test-scf-vectorize.mlir b/mlir/test/Transforms/test-scf-vectorize.mlir new file mode 100644 index 0000000000000..f4a817b44aa39 --- /dev/null +++ b/mlir/test/Transforms/test-scf-vectorize.mlir @@ -0,0 +1,272 @@ +// RUN: mlir-opt %s --test-scf-vectorize=vector-bitwidth=128 -split-input-file | FileCheck %s + +// CHECK-LABEL: @test +// CHECK-SAME: (%[[A:.*]]: memref, %[[B:.*]]: memref, %[[C:.*]]: memref) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[COUNT:.*]] = arith.ceildivsi %[[DIM]], %[[C4]] : index +// CHECK: scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) { +// CHECK: %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index +// CHECK: %[[M:.*]] = arith.subi %[[DIM]], %[[MULT]] : index +// CHECK: %[[MASK:.*]] = vector.create_mask %[[M]] : vector<4xi1> +// CHECK: %[[P:.*]] = ub.poison : vector<4xi32> +// CHECK: %[[A_VAL:.*]] = vector.maskedload %[[A]][%[[MULT]]], %[[MASK]], %[[P]] : memref, vector<4xi1>, vector<4xi32> into vector<4xi32> +// CHECK: %[[P:.*]] = ub.poison : vector<4xi32> +// CHECK: %[[B_VAL:.*]] = vector.maskedload %[[B]][%[[MULT]]], %[[MASK]], %[[P]] : memref, vector<4xi1>, vector<4xi32> into vector<4xi32> +// CHECK: %[[RES:.*]] = arith.addi %[[A_VAL]], %[[B_VAL]] : vector<4xi32> +// CHECK: vector.maskedstore %[[C]][%1], %[[MASK]], %[[RES]] : memref, vector<4xi1>, vector<4xi32> +// CHECK: scf.reduce +func.func @test(%A: memref, %B: memref, %C: memref) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %count = memref.dim %A, %c0 : memref + scf.parallel (%i) = (%c0) to (%count) step (%c1) { + %1 = memref.load %A[%i] : memref + %2 = memref.load %B[%i] : memref + %3 = arith.addi %1, %2 : i32 + memref.store %3, %C[%i] : memref + } + return +} + +// ----- + +// CHECK-LABEL: @test +// CHECK-SAME: (%[[A:.*]]: memref, %[[B:.*]]: memref, %[[C:.*]]: memref) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[COUNT:.*]] = arith.ceildivsi %[[DIM]], %[[C4]] : index +// CHECK: scf.parallel (%[[I:.*]]) = (%{{.*}}) to 
(%[[COUNT]]) step (%{{.*}}) { +// CHECK: %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index +// CHECK: %[[M:.*]] = arith.subi %[[DIM]], %[[MULT]] : index +// CHECK: %[[MASK:.*]] = vector.create_mask %[[M]] : vector<4xi1> +// CHECK: %[[P:.*]] = ub.poison : vector<4xindex> +// CHECK: %[[A_VAL:.*]] = vector.maskedload %[[A]][%[[MULT]]], %[[MASK]], %[[P]] : memref, vector<4xi1>, vector<4xindex> into vector<4xindex> +// CHECK: %[[P:.*]] = ub.poison : vector<4xindex> +// CHECK: %[[B_VAL:.*]] = vector.maskedload %[[B]][%[[MULT]]], %[[MASK]], %[[P]] : memref, vector<4xi1>, vector<4xindex> into vector<4xindex> +// CHECK: %[[RES:.*]] = arith.addi %[[A_VAL]], %[[B_VAL]] : vector<4xindex> +// CHECK: vector.maskedstore %[[C]][%1], %[[MASK]], %[[RES]] : memref, vector<4xi1>, vector<4xindex> +// CHECK: scf.reduce + +module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry> } { +func.func @test(%A: memref, %B: memref, %C: memref) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %count = memref.dim %A, %c0 : memref + scf.parallel (%i) = (%c0) to (%count) step (%c1) { + %1 = memref.load %A[%i] : memref + %2 = memref.load %B[%i] : memref + %3 = arith.addi %1, %2 : index + memref.store %3, %C[%i] : memref + } + return +} +} + +// ----- + +func.func private @non_vectorizable(i32) -> (i32) + +// CHECK-LABEL: @test +// CHECK-SAME: (%[[A:.*]]: memref, %[[B:.*]]: memref, %[[C:.*]]: memref) { +// CHECK: %[[C00:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = memref.dim %[[A]], %[[C00]] : memref +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[COUNT:.*]] = arith.divsi %[[DIM]], %[[C4]] : index +// CHECK: scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) { +// CHECK: %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index +// CHECK: %[[A_VAL:.*]] = vector.load %[[A]][%[[MULT]]] : memref, vector<4xi32> +// CHECK: %[[B_VAL:.*]] = vector.load %[[B]][%[[MULT]]] : memref, vector<4xi32> +// CHECK: %[[R1:.*]] = arith.addi %[[A_VAL]], %[[B_VAL]] : vector<4xi32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[E0:.*]] = vector.extractelement %[[R1]][%[[C0]] : index] : vector<4xi32> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[E1:.*]] = vector.extractelement %[[R1]][%[[C1]] : index] : vector<4xi32> +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[E2:.*]] = vector.extractelement %[[R1]][%[[C2]] : index] : vector<4xi32> +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[E3:.*]] = vector.extractelement %[[R1]][%[[C3]] : index] : vector<4xi32> +// CHECK: %[[R2:.*]] = func.call @non_vectorizable(%[[E0]]) : (i32) -> i32 +// CHECK: %[[R3:.*]] = func.call @non_vectorizable(%[[E1]]) : (i32) -> i32 +// CHECK: %[[R4:.*]] = func.call @non_vectorizable(%[[E2]]) : (i32) -> i32 +// CHECK: %[[R5:.*]] = func.call @non_vectorizable(%[[E3]]) : (i32) -> i32 +// CHECK: %[[RES1:.*]] = ub.poison : vector<4xi32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[RES2:.*]] = vector.insertelement %[[R2]], %[[RES1]][%[[C0]] : index] : vector<4xi32> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[RES3:.*]] = vector.insertelement %[[R3]], %[[RES2]][%[[C1]] : index] : vector<4xi32> +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[RES4:.*]] = vector.insertelement %[[R4]], %[[RES3]][%[[C2]] : index] : vector<4xi32> +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[RES5:.*]] = vector.insertelement %[[R5]], %[[RES4]][%[[C3]] : index] : vector<4xi32> +// CHECK: vector.store %[[RES5]], 
%[[C]][%[[MULT]]] : memref, vector<4xi32> +// CHECK: scf.reduce +// CHECK: } +// CHECK: %[[UB1:.*]] = arith.muli %[[COUNT]], %[[C4]] : index +// CHECK: %[[UB2:.*]] = arith.addi %[[UB1]], %[[C00]] : index +// CHECK: scf.parallel (%[[I:.*]]) = (%[[UB2]]) to (%[[DIM]]) step (%{{.*}}) { +// CHECK: %[[A_VAL:.*]] = memref.load %[[A]][%[[I]]] : memref +// CHECK: %[[B_VAL:.*]] = memref.load %[[B]][%[[I]]] : memref +// CHECK: %[[R1:.*]] = arith.addi %[[A_VAL:.*]], %[[B_VAL:.*]] : i32 +// CHECK: %[[R2:.*]] = func.call @non_vectorizable(%[[R1]]) : (i32) -> i32 +// CHECK: memref.store %[[R2]], %[[C]][%[[I]]] : memref +// CHECK: scf.reduce +// CHECK: } +func.func @test(%A: memref, %B: memref, %C: memref) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %count = memref.dim %A, %c0 : memref + scf.parallel (%i) = (%c0) to (%count) step (%c1) { + %1 = memref.load %A[%i] : memref + %2 = memref.load %B[%i] : memref + %3 = arith.addi %1, %2 : i32 + %4 = func.call @non_vectorizable(%3) : (i32) -> (i32) + memref.store %4, %C[%i] : memref + } + return +} + +// ----- + +// CHECK-LABEL: @test +// CHECK-SAME: (%[[A:.*]]: memref, %[[B:.*]]: memref, %[[C:.*]]: memref) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[COUNT:.*]] = arith.ceildivsi %[[DIM]], %[[C4]] : index +// CHECK: scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) { +// CHECK: %[[OFFSETS:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> +// CHECK: %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index +// CHECK: %[[O1:.*]] = vector.splat %[[MULT]] : vector<4xindex> +// CHECK: %[[O2:.*]] = arith.addi %[[O1]], %[[OFFSETS]] : vector<4xindex> +// CHECK: %[[O3:.*]] = vector.splat %[[C2]] : vector<4xindex> +// CHECK: %[[O4:.*]] = arith.muli %[[O2]], %[[O3]] : vector<4xindex> +// CHECK: %[[M:.*]] = arith.subi %[[DIM]], %[[MULT]] : index +// CHECK: %[[MASK:.*]] = vector.create_mask %[[M]] : vector<4xi1> +// CHECK: %[[P:.*]] = ub.poison : vector<4xindex> +// CHECK: %[[A_VAL:.*]] = vector.gather %arg0[%{{.*}}] [%[[O4]]], %[[MASK]], %[[P]] : memref, vector<4xindex>, vector<4xi1>, vector<4xindex> into vector<4xindex> +// CHECK: %[[P:.*]] = ub.poison : vector<4xindex> +// CHECK: %[[B_VAL:.*]] = vector.maskedload %[[B]][%[[MULT]]], %[[MASK]], %[[P]] : memref, vector<4xi1>, vector<4xindex> into vector<4xindex> +// CHECK: %[[RES:.*]] = arith.addi %[[A_VAL]], %[[B_VAL]] : vector<4xindex> +// CHECK: vector.scatter %[[C]][%{{.*}}] [%[[O4]]], %[[MASK]], %[[RES]] : memref, vector<4xindex>, vector<4xi1>, vector<4xindex> +// CHECK: scf.reduce + +module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry> } { +func.func @test(%A: memref, %B: memref, %C: memref) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %count = memref.dim %A, %c0 : memref + scf.parallel (%i) = (%c0) to (%count) step (%c1) { + %0 = arith.muli %i, %c2 : index + %1 = memref.load %A[%0] : memref + %2 = memref.load %B[%i] : memref + %3 = arith.addi %1, %2 : index + memref.store %3, %C[%0] : memref + } + return +} +} + +// ----- + +// CHECK-LABEL: @test +// CHECK-SAME: (%[[A:.*]]: memref) -> f32 { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[INIT:.*]] = arith.constant 0.0{{.*}} : f32 +// CHECK: %[[DIM:.*]] = memref.dim %[[A]], %[[C0]] : memref +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[COUNT:.*]] 
= arith.ceildivsi %[[DIM]], %[[C4]] : index +// CHECK: %[[RES:.*]] = scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) init (%[[INIT]]) -> f32 { +// CHECK: %[[MULT:.*]] = arith.muli %[[I]], %[[C4]] : index +// CHECK: %[[M:.*]] = arith.subi %[[DIM]], %[[MULT]] : index +// CHECK: %[[MASK:.*]] = vector.create_mask %[[M]] : vector<4xi1> +// CHECK: %[[P:.*]] = ub.poison : vector<4xf32> +// CHECK: %[[A_VAL:.*]] = vector.maskedload %[[A]][%[[MULT]]], %[[MASK]], %[[P]] : memref, vector<4xi1>, vector<4xf32> into vector<4xf32> +// CHECK: %[[N:.*]] = arith.constant 0.0{{.*}} : f32 +// CHECK: %[[N_SPLAT:.*]] = vector.splat %[[N]] : vector<4xf32> +// CHECK: %[[RED1:.*]] = arith.select %[[MASK]], %[[A_VAL]], %[[N_SPLAT]] : vector<4xi1>, vector<4xf32> +// CHECK: %[[RED2:.*]] = vector.reduction , %[[RED1]] : vector<4xf32> into f32 +// CHECK: scf.reduce(%[[RED2]] : f32) { +// CHECK: ^bb0(%[[R_ARG1:.*]]: f32, %[[R_ARG2:.*]]: f32): +// CHECK: %[[R:.*]] = arith.addf %[[R_ARG1]], %[[R_ARG2]] : f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: return %[[RES]] : f32 +func.func @test(%A: memref) -> f32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %init = arith.constant 0.0 : f32 + %count = memref.dim %A, %c0 : memref + %res = scf.parallel (%i) = (%c0) to (%count) step (%c1) init (%init) -> f32 { + %1 = memref.load %A[%i] : memref + scf.reduce(%1 : f32) { + ^bb0(%lhs: f32, %rhs: f32): + %2 = arith.addf %lhs, %rhs : f32 + scf.reduce.return %2 : f32 + } + } + return %res : f32 +} + + +// ----- + +func.func private @combine(f32, f32) -> (f32) + +// CHECK-LABEL: @test +// CHECK-SAME: (%[[A:.*]]: memref) -> f32 { +// CHECK: %[[C00:.*]] = arith.constant 0 : index +// CHECK: %[[INIT:.*]] = arith.constant 0.0{{.*}} : f32 +// CHECK: %[[DIM:.*]] = memref.dim %[[A]], %[[C00]] : memref +// CHECK: %[[C4:.*]] = arith.constant 4 : index +// CHECK: %[[COUNT:.*]] = arith.divsi %[[DIM]], %[[C4]] : index +// CHECK: %[[RES:.*]] = scf.parallel (%[[I:.*]]) = (%{{.*}}) to (%[[COUNT]]) step (%{{.*}}) init (%[[INIT]]) -> f32 { +// CHECK: %[[MULT:.*]] = arith.muli %arg1, %c4 : index +// CHECK: %[[A_VAL:.*]] = vector.load %[[A]][%[[MULT]]] : memref, vector<4xf32> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[E0:.*]] = vector.extractelement %[[A_VAL]][%[[C0]] : index] : vector<4xf32> +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[E1:.*]] = vector.extractelement %[[A_VAL]][%[[C1]] : index] : vector<4xf32> +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[E2:.*]] = vector.extractelement %[[A_VAL]][%[[C2]] : index] : vector<4xf32> +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[E3:.*]] = vector.extractelement %[[A_VAL]][%[[C3]] : index] : vector<4xf32> +// CHECK: %[[R0:.*]] = func.call @combine(%[[E0]], %[[E1]]) : (f32, f32) -> f32 +// CHECK: %[[R1:.*]] = func.call @combine(%[[R0]], %[[E2]]) : (f32, f32) -> f32 +// CHECK: %[[R2:.*]] = func.call @combine(%[[R1]], %[[E3]]) : (f32, f32) -> f32 +// CHECK: scf.reduce(%[[R2]] : f32) { +// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[R:.*]] = func.call @combine(%[[LHS]], %[[RHS]]) : (f32, f32) -> f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: } +// CHECK: %[[UB1:.*]] = arith.muli %[[COUNT]], %[[C4]] : index +// CHECK: %[[UB2:.*]] = arith.addi %[[UB1]], %[[C00]] : index +// CHECK: %[[RES1:.*]] = scf.parallel (%[[I:.*]]) = (%[[UB2]]) to (%[[DIM]]) step (%{{.*}}) init (%[[RES]]) -> f32 { +// CHECK: %[[A_VAL:.*]] = memref.load %[[A]][%[[I]]] : 
memref +// CHECK: scf.reduce(%[[A_VAL]] : f32) { +// CHECK: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[R:.*]] = func.call @combine(%[[LHS]], %[[RHS]]) : (f32, f32) -> f32 +// CHECK: scf.reduce.return %[[R]] : f32 +// CHECK: } +// CHECK: } +// CHECK: return %[[RES1]] : f32 +func.func @test(%A: memref) -> f32 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %init = arith.constant 0.0 : f32 + %count = memref.dim %A, %c0 : memref + %res = scf.parallel (%i) = (%c0) to (%count) step (%c1) init (%init) -> f32 { + %1 = memref.load %A[%i] : memref + scf.reduce(%1 : f32) { + ^bb0(%lhs: f32, %rhs: f32): + %2 = func.call @combine(%lhs, %rhs) : (f32, f32) -> (f32) + scf.reduce.return %2 : f32 + } + } + return %res : f32 +} diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index 975a41ac3d5fe..01c92199b6f3a 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -26,6 +26,7 @@ add_mlir_library(MLIRTestTransforms TestInlining.cpp TestIntRangeInference.cpp TestMakeIsolatedFromAbove.cpp + TestSCFVectorize.cpp ${MLIRTestTransformsPDLSrc} EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Transforms/TestSCFVectorize.cpp b/mlir/test/lib/Transforms/TestSCFVectorize.cpp new file mode 100644 index 0000000000000..84ea190a33de2 --- /dev/null +++ b/mlir/test/lib/Transforms/TestSCFVectorize.cpp @@ -0,0 +1,110 @@ +//===- TestSCFVectorize.cpp - SCF vectorization test pass -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/SCFVectorize.h" + +#include "mlir/Analysis/DataLayoutAnalysis.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" + +using namespace mlir; + +namespace { +struct TestSCFVectorizePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFVectorizePass) + + TestSCFVectorizePass() = default; + TestSCFVectorizePass(const TestSCFVectorizePass &pass) : PassWrapper(pass) {} + + Option vectorBitwidth{*this, "vector-bitwidth", + llvm::cl::desc("Target vector bitwidth "), + llvm::cl::init(128)}; + + StringRef getArgument() const final { return "test-scf-vectorize"; } + StringRef getDescription() const final { return "Test SCF vectorization"; } + + virtual void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + LogicalResult initializeOptions( + StringRef options, + function_ref errorHandler) override { + if (failed(PassWrapper::initializeOptions(options, errorHandler))) + return failure(); + + if (vectorBitwidth <= 0) + return errorHandler("Invalid vector bitwidth: " + + llvm::Twine(vectorBitwidth)); + + return success(); + } + + void runOnOperation() override { + Operation *root = getOperation(); + auto &DLAnalysis = getAnalysis(); + + llvm::SmallVector> + toVectorize; + + // Simple heuristic: total number of elements processed by vector ops, but + // prefer masked mode over non-masked. 
+ auto getBenefit = [](const SCFVectorizeInfo &info) { + return info.factor * info.count * (int(info.masked) + 1); + }; + + root->walk([&](scf::ParallelOp loop) { + const DataLayout &DL = DLAnalysis.getAbove(loop); + std::optional best; + for (auto dim : llvm::seq(0u, loop.getNumLoops())) { + auto info = getLoopVectorizeInfo(loop, dim, vectorBitwidth, &DL); + if (!info) + continue; + + if (!best) { + best = *info; + continue; + } + + if (getBenefit(*info) > getBenefit(*best)) + best = *info; + } + + if (!best) + return; + + toVectorize.emplace_back( + loop, SCFVectorizeParams{best->dim, best->factor, best->masked}); + }); + + if (toVectorize.empty()) + return markAllAnalysesPreserved(); + + for (auto &&[loop, params] : toVectorize) { + const DataLayout &DL = DLAnalysis.getAbove(loop); + if (failed(vectorizeLoop(loop, params, &DL))) + return signalPassFailure(); + } + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerTesSCFVectorize() { PassRegistration(); } +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index d2ba3d06835fb..1ddf437233326 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -75,6 +75,7 @@ void registerInliner(); void registerMemRefBoundCheck(); void registerPatternsTestPass(); void registerSimpleParametricTilingPass(); +void registerTesSCFVectorize(); void registerTestAffineLoopParametricTilingPass(); void registerTestAliasAnalysisPass(); void registerTestArithEmulateWideIntPass(); @@ -204,6 +205,7 @@ void registerTestPasses() { mlir::test::registerMemRefBoundCheck(); mlir::test::registerPatternsTestPass(); mlir::test::registerSimpleParametricTilingPass(); + mlir::test::registerTesSCFVectorize(); mlir::test::registerTestAffineLoopParametricTilingPass(); mlir::test::registerTestAliasAnalysisPass(); mlir::test::registerTestArithEmulateWideIntPass(); From e58d2924b0109a371d677d8b9aa6d6bc9cfff00c Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 3 Jun 2024 12:19:00 +0200 Subject: [PATCH 04/10] fix typo --- mlir/lib/Transforms/SCFVectorize.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp index 29e184e584a56..c74cfa4abf80d 100644 --- a/mlir/lib/Transforms/SCFVectorize.cpp +++ b/mlir/lib/Transforms/SCFVectorize.cpp @@ -66,7 +66,7 @@ static bool isRangePermutation(ValueRange val1, ValueRange val2) { template static std::optional -cavTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp, +canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp, const DataLayout *DL) { auto loopIndexVars = loop.getInductionVars(); assert(dim < loopIndexVars.size()); @@ -97,14 +97,14 @@ cavTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp, /// Returns memref element bitwidth or `std::nullopt` if access cannot be /// vectorized. 
static std::optional -cavTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op, +canTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op, const DataLayout *DL) { assert(dim < loop.getInductionVars().size()); if (auto storeOp = dyn_cast(op)) - return cavTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL); + return canTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL); if (auto loadOp = dyn_cast(op)) - return cavTriviallyVectorizeMemOpImpl(loop, dim, loadOp, DL); + return canTriviallyVectorizeMemOpImpl(loop, dim, loadOp, DL); return std::nullopt; } @@ -141,7 +141,7 @@ canGatherScatter(scf::ParallelOp loop, Operation &op, const DataLayout *DL) { static std::optional cenVectorizeMemrefOp(scf::ParallelOp loop, unsigned dim, Operation &op, const DataLayout *DL) { - if (auto w = cavTriviallyVectorizeMemOp(loop, dim, op, DL)) + if (auto w = canTriviallyVectorizeMemOp(loop, dim, op, DL)) return w; return canGatherScatter(loop, op, DL); @@ -419,7 +419,7 @@ LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop, }; auto canTriviallyVectorizeMemOp = [&](auto op) -> bool { - return !!::cavTriviallyVectorizeMemOpImpl(loop, dim, op, DL); + return !!::canTriviallyVectorizeMemOpImpl(loop, dim, op, DL); }; auto canGatherScatter = [&](auto op) { From d5167004f47c47354847e276d75aa8918f582072 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 3 Jun 2024 12:20:14 +0200 Subject: [PATCH 05/10] use has_value() --- mlir/lib/Transforms/SCFVectorize.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Transforms/SCFVectorize.cpp index c74cfa4abf80d..7b907b655976a 100644 --- a/mlir/lib/Transforms/SCFVectorize.cpp +++ b/mlir/lib/Transforms/SCFVectorize.cpp @@ -419,11 +419,11 @@ LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop, }; auto canTriviallyVectorizeMemOp = [&](auto op) -> bool { - return !!::canTriviallyVectorizeMemOpImpl(loop, dim, op, DL); + return ::canTriviallyVectorizeMemOpImpl(loop, dim, op, DL).has_value(); }; auto canGatherScatter = [&](auto op) { - return !!::canGatherScatterImpl(loop, op, DL); + return ::canGatherScatterImpl(loop, op, DL).has_value(); }; // Get idices for vectorized memref load/store. 
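Patches 04 and 05 are mechanical (a typo fix and `.has_value()`), but the predicates they touch encode the central classification rule: an access is trivially vectorizable (unit stride, lowered to `vector.load`/`vector.store`) when the memref has an identity layout, the indices are a permutation of the loop induction variables, and the vectorized dimension is the fastest-varying index; an access whose memref dominates the loop and uses a single index can still fall back to `vector.gather`/`vector.scatter`. Condensed into one predicate for illustration (`isUnitStride` is not a helper in the patch):

  static bool isUnitStride(scf::ParallelOp loop, unsigned dim,
                           MemRefType type, ValueRange indices) {
    ValueRange ivs = loop.getInductionVars();
    return type.getLayout().isIdentity() &&
           isRangePermutation(indices, ivs) && indices.back() == ivs[dim];
  }

The `%0 = arith.muli %i, %c2` subscript in the gather/scatter test above fails the permutation check, which is exactly why that case lowers to `vector.gather`.
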
From c33fb1ad508bf326cfc73aafe2625cdde6c6e486 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 3 Jun 2024 12:44:54 +0200 Subject: [PATCH 06/10] move files to scf dialect --- .../{ => Dialect/SCF}/Transforms/SCFVectorize.h | 5 ++--- mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt | 3 ++- .../{ => Dialect/SCF}/Transforms/SCFVectorize.cpp | 14 +++++++------- mlir/lib/Transforms/CMakeLists.txt | 1 - .../SCF}/test-scf-vectorize.mlir | 0 mlir/test/lib/Dialect/SCF/CMakeLists.txt | 1 + .../SCF}/TestSCFVectorize.cpp | 14 +++++++------- mlir/test/lib/Transforms/CMakeLists.txt | 1 - 8 files changed, 19 insertions(+), 20 deletions(-) rename mlir/include/mlir/{ => Dialect/SCF}/Transforms/SCFVectorize.h (98%) rename mlir/lib/{ => Dialect/SCF}/Transforms/SCFVectorize.cpp (97%) rename mlir/test/{Transforms => Dialect/SCF}/test-scf-vectorize.mlir (100%) rename mlir/test/lib/{Transforms => Dialect/SCF}/TestSCFVectorize.cpp (87%) diff --git a/mlir/include/mlir/Transforms/SCFVectorize.h b/mlir/include/mlir/Dialect/SCF/Transforms/SCFVectorize.h similarity index 98% rename from mlir/include/mlir/Transforms/SCFVectorize.h rename to mlir/include/mlir/Dialect/SCF/Transforms/SCFVectorize.h index d2a5e3085ae37..ebaa1edac531c 100644 --- a/mlir/include/mlir/Transforms/SCFVectorize.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/SCFVectorize.h @@ -17,9 +17,7 @@ struct LogicalResult; namespace scf { class ParallelOp; } -} // namespace mlir - -namespace mlir { +namespace scf { /// Loop vectorization info struct SCFVectorizeInfo { @@ -65,6 +63,7 @@ struct SCFVectorizeParams { mlir::LogicalResult vectorizeLoop(mlir::scf::ParallelOp loop, const SCFVectorizeParams ¶ms, const DataLayout *DL = nullptr); +} // namespace scf } // namespace mlir #endif // MLIR_TRANSFORMS_SCFVECTORIZE_H_ diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt index d363ffe941fce..898f20efa7078 100644 --- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt @@ -13,10 +13,11 @@ add_mlir_dialect_library(MLIRSCFTransforms ParallelLoopCollapsing.cpp ParallelLoopFusion.cpp ParallelLoopTiling.cpp + SCFVectorize.cpp StructuralTypeConversions.cpp TileUsingInterface.cpp - WrapInZeroTripCheck.cpp UpliftWhileToFor.cpp + WrapInZeroTripCheck.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF diff --git a/mlir/lib/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp similarity index 97% rename from mlir/lib/Transforms/SCFVectorize.cpp rename to mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp index 7b907b655976a..7bc7fa544f286 100644 --- a/mlir/lib/Transforms/SCFVectorize.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Transforms/SCFVectorize.h" +#include "mlir/Dialect/SCF/Transforms/SCFVectorize.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" // getCombinerOpKind @@ -157,9 +157,9 @@ static std::optional getReductionKind(Block &body) { return linalg::getCombinerOpKind(&body.front()); } -std::optional -mlir::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim, - unsigned vectorBitwidth, const DataLayout *DL) { +std::optional +mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim, + unsigned vectorBitwidth, const DataLayout *DL) { assert(dim < loop.getStep().size()); assert(vectorBitwidth > 0); unsigned factor = 
vectorBitwidth / 8; @@ -234,9 +234,9 @@ static arith::FastMathFlags getFMF(Operation &op) { return arith::FastMathFlags::none; } -LogicalResult mlir::vectorizeLoop(scf::ParallelOp loop, - const SCFVectorizeParams ¶ms, - const DataLayout *DL) { +LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, + const scf::SCFVectorizeParams ¶ms, + const DataLayout *DL) { auto dim = params.dim; auto factor = params.factor; auto masked = params.masked; diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index ed71c73c938ed..90c0298fb5e46 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -14,7 +14,6 @@ add_mlir_library(MLIRTransforms PrintIR.cpp RemoveDeadValues.cpp SCCP.cpp - SCFVectorize.cpp SROA.cpp StripDebugInfo.cpp SymbolDCE.cpp diff --git a/mlir/test/Transforms/test-scf-vectorize.mlir b/mlir/test/Dialect/SCF/test-scf-vectorize.mlir similarity index 100% rename from mlir/test/Transforms/test-scf-vectorize.mlir rename to mlir/test/Dialect/SCF/test-scf-vectorize.mlir diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt index 792430cc84b65..9af1459d17df9 100644 --- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt +++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_library(MLIRSCFTestPasses TestLoopParametricTiling.cpp TestLoopUnrolling.cpp TestSCFUtils.cpp + TestSCFVectorize.cpp TestSCFWrapInZeroTripCheck.cpp TestUpliftWhileToFor.cpp TestWhileOpBuilder.cpp diff --git a/mlir/test/lib/Transforms/TestSCFVectorize.cpp b/mlir/test/lib/Dialect/SCF/TestSCFVectorize.cpp similarity index 87% rename from mlir/test/lib/Transforms/TestSCFVectorize.cpp rename to mlir/test/lib/Dialect/SCF/TestSCFVectorize.cpp index 84ea190a33de2..3f92dd03438bd 100644 --- a/mlir/test/lib/Transforms/TestSCFVectorize.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFVectorize.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Transforms/SCFVectorize.h" +#include "mlir/Dialect/SCF/Transforms/SCFVectorize.h" #include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -58,20 +58,20 @@ struct TestSCFVectorizePass Operation *root = getOperation(); auto &DLAnalysis = getAnalysis(); - llvm::SmallVector> + llvm::SmallVector> toVectorize; // Simple heuristic: total number of elements processed by vector ops, but // prefer masked mode over non-masked. 
- auto getBenefit = [](const SCFVectorizeInfo &info) { + auto getBenefit = [](const scf::SCFVectorizeInfo &info) { return info.factor * info.count * (int(info.masked) + 1); }; root->walk([&](scf::ParallelOp loop) { const DataLayout &DL = DLAnalysis.getAbove(loop); - std::optional best; + std::optional best; for (auto dim : llvm::seq(0u, loop.getNumLoops())) { - auto info = getLoopVectorizeInfo(loop, dim, vectorBitwidth, &DL); + auto info = scf::getLoopVectorizeInfo(loop, dim, vectorBitwidth, &DL); if (!info) continue; @@ -88,7 +88,7 @@ struct TestSCFVectorizePass return; toVectorize.emplace_back( - loop, SCFVectorizeParams{best->dim, best->factor, best->masked}); + loop, scf::SCFVectorizeParams{best->dim, best->factor, best->masked}); }); if (toVectorize.empty()) @@ -96,7 +96,7 @@ struct TestSCFVectorizePass for (auto &&[loop, params] : toVectorize) { const DataLayout &DL = DLAnalysis.getAbove(loop); - if (failed(vectorizeLoop(loop, params, &DL))) + if (failed(scf::vectorizeLoop(loop, params, &DL))) return signalPassFailure(); } } diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt index 01c92199b6f3a..975a41ac3d5fe 100644 --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -26,7 +26,6 @@ add_mlir_library(MLIRTestTransforms TestInlining.cpp TestIntRangeInference.cpp TestMakeIsolatedFromAbove.cpp - TestSCFVectorize.cpp ${MLIRTestTransformsPDLSrc} EXCLUDE_FROM_LIBMLIR From d8427d7f672c8be8695c71b2e76c04dae4285286 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 3 Jun 2024 19:10:39 +0200 Subject: [PATCH 07/10] getTypeBitWidth std::optional --- .../Dialect/SCF/Transforms/SCFVectorize.cpp | 46 +++++++++++-------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp index 7bc7fa544f286..7e7add925dbb9 100644 --- a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp @@ -20,11 +20,12 @@ using namespace mlir; static bool isSupportedVecElem(Type type) { return type.isIntOrIndexOrFloat(); } -/// Return type bitwidth for vectorization purposes or 0 if type cannot be +/// Return type bitwidth for vectorization purposes or empty if type cannot be /// vectorized. 
-static unsigned getTypeBitWidth(Type type, const DataLayout *DL) { +static std::optional getTypeBitWidth(Type type, + const DataLayout *DL) { if (!isSupportedVecElem(type)) - return 0; + return std::nullopt; if (DL) return DL->getTypeSizeInBits(type); @@ -32,16 +33,21 @@ static unsigned getTypeBitWidth(Type type, const DataLayout *DL) { if (type.isIntOrFloat()) return type.getIntOrFloatBitWidth(); - return 0; + return std::nullopt; } -static unsigned getArgsTypeWidth(Operation &op, const DataLayout *DL) { +static std::optional getArgsTypeWidth(Operation &op, + const DataLayout *DL) { unsigned ret = 0; - for (auto arg : op.getOperands()) - ret = std::max(ret, getTypeBitWidth(arg.getType(), DL)); + for (auto r : {ValueRange(op.getOperands()), ValueRange(op.getResults())}) { + for (auto arg : op.getOperands()) { + std::optional w = getTypeBitWidth(arg.getType(), DL); + if (!w) + return std::nullopt; - for (auto res : op.getResults()) - ret = std::max(ret, getTypeBitWidth(res.getType(), DL)); + ret = std::max(ret, *w); + } + } return ret; } @@ -72,8 +78,8 @@ canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp, assert(dim < loopIndexVars.size()); auto memref = memOp.getMemRef(); auto type = cast(memref.getType()); - auto width = getTypeBitWidth(type.getElementType(), DL); - if (width == 0) + std::optional width = getTypeBitWidth(type.getElementType(), DL); + if (!width) return std::nullopt; if (!type.getLayout().isIdentity()) @@ -114,13 +120,17 @@ static std::optional canGatherScatterImpl(scf::ParallelOp loop, Op op, const DataLayout *DL) { auto memref = op.getMemRef(); auto memrefType = cast(memref.getType()); - auto width = getTypeBitWidth(memrefType.getElementType(), DL); - if (width == 0) + std::optional width = + getTypeBitWidth(memrefType.getElementType(), DL); + if (!width) return std::nullopt; DominanceInfo dom; - return dom.properlyDominates(memref, loop) && op.getIndices().size() == 1 && - memrefType.getLayout().isIdentity(); + if (!dom.properlyDominates(memref, loop) || op.getIndices().size() != 1 || + !memrefType.getLayout().isIdentity()) + return std::nullopt; + + return width; } // Check if memref access can be converted into gather/scatter. 
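With this change every width query funnels through the optional-returning `getTypeBitWidth`, and a provided DataLayout decides the size of `index` instead of any hardcoded value; without one, the helper now answers `std::nullopt` for `index` rather than assuming 64 bits. For illustration, one way a caller could build such a layout; `module` and its dlti spec are assumptions mirroring the tests above:

  // `module` is a ModuleOp assumed to carry a dlti.dl_spec attribute.
  DataLayout DL(module);
  auto indexBits = DL.getTypeSizeInBits(IndexType::get(module.getContext()));
  // Yields 32 under the index-narrowing spec used in the vector<4xindex> tests.
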
@@ -206,11 +216,11 @@ mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
       continue;
     }
 
-    auto width = getArgsTypeWidth(op, DL);
-    if (width == 0)
+    std::optional<unsigned> width = getArgsTypeWidth(op, DL);
+    if (!width)
       return std::nullopt;
 
-    auto newFactor = vectorBitwidth / width;
+    auto newFactor = vectorBitwidth / *width;
     if (newFactor <= 1)
       continue;
 

From f8da459b45d300ef080ac27dc7ba308ccda65c4b Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 3 Jun 2024 19:17:05 +0200
Subject: [PATCH 08/10] update assert messages

---
 .../Dialect/SCF/Transforms/SCFVectorize.cpp   | 30 +++++++++----------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
index 7e7add925dbb9..536efc72a0305 100644
--- a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
@@ -75,7 +75,7 @@ static std::optional<unsigned>
 canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
                                const DataLayout *DL) {
   auto loopIndexVars = loop.getInductionVars();
-  assert(dim < loopIndexVars.size());
+  assert(dim < loopIndexVars.size() && "Invalid loop dimension");
   auto memref = memOp.getMemRef();
   auto type = cast<MemRefType>(memref.getType());
   std::optional<unsigned> width = getTypeBitWidth(type.getElementType(), DL);
@@ -105,7 +105,7 @@ canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
 static std::optional<unsigned>
 canTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op,
                            const DataLayout *DL) {
-  assert(dim < loop.getInductionVars().size());
+  assert(dim < loop.getInductionVars().size() && "Invalid loop dimension");
   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
     return canTriviallyVectorizeMemOpImpl(loop, dim, storeOp, DL);
 
@@ -170,8 +170,8 @@ static std::optional<vector::CombiningKind> getReductionKind(Block &body) {
 std::optional<SCFVectorizeInfo>
 mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim,
                                 unsigned vectorBitwidth, const DataLayout *DL) {
-  assert(dim < loop.getStep().size());
-  assert(vectorBitwidth > 0);
+  assert(dim < loop.getStep().size() && "Invalid loop dimension");
+  assert(vectorBitwidth > 0 && "Invalid vector bitwidth");
   unsigned factor = vectorBitwidth / 8;
   if (factor <= 1)
     return std::nullopt;
@@ -250,9 +250,9 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
   auto dim = params.dim;
   auto factor = params.factor;
   auto masked = params.masked;
-  assert(dim < loop.getStep().size());
-  assert(factor > 1);
-  assert(isConstantIntValue(loop.getStep()[dim], 1));
+  assert(dim < loop.getStep().size() && "Invalid loop dimension");
+  assert(factor > 1 && "Invalid vectorize factor");
+  assert(isConstantIntValue(loop.getStep()[dim], 1) && "Loop step must be 1");
 
   OpBuilder builder(loop);
   auto lower = llvm::to_vector(loop.getLowerBound());
@@ -332,7 +332,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
       return vec;
     }
     auto type = orig.getType();
-    assert(isSupportedVecElem(type));
+    assert(isSupportedVecElem(type) && "Unsupported vector element type");
 
     Value val = orig;
     auto origIndexVars = loop.getInductionVars();
@@ -367,7 +367,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
     // cache and not handled here.
     auto &ret = unpackedVals[val];
-    assert(ret.empty());
+    assert(ret.empty() && "Invalid unpackedVals state");
 
     if (!isSupportedVecElem(val.getType())) {
       // Non vectorizable value, it must be a value defined outside the loop,
       // just replicate it.
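A reminder of what the replication path above produces: a value that cannot live in a vector is materialized `factor` times, and vector values handed to scalarized ops are split with `vector.extractelement`. The extraction side in isolation, as a sketch (`unpack` is a local name, not a helper in the patch):

  static SmallVector<Value> unpack(OpBuilder &b, Location loc, Value vec,
                                   unsigned factor) {
    SmallVector<Value> ret(factor);
    for (unsigned i = 0; i < factor; ++i) {
      Value idx = b.create<arith::ConstantIndexOp>(loc, i);
      ret[i] = b.create<vector::ExtractElementOp>(loc, vec, idx);
    }
    return ret;
  }
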
@@ -387,8 +387,8 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
 
   // Add unpacked values to the cache.
   auto setUnpackedVals = [&](Value origVal, ValueRange newVals) {
-    assert(newVals.size() == factor);
-    assert(unpackedVals.count(origVal) == 0);
+    assert(newVals.size() == factor && "Invalid values count");
+    assert(unpackedVals.count(origVal) == 0 && "Invalid unpackedVals state");
     unpackedVals[origVal].append(newVals.begin(), newVals.end());
 
     auto type = origVal.getType();
@@ -555,7 +555,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
 
     for (auto &&[i, arg] : llvm::enumerate(op.getOperands())) {
       auto unpacked = getUnpackedVals(arg);
-      assert(unpacked.size() == factor);
+      assert(unpacked.size() == factor && "Invalid unpacked size");
       for (auto j : llvm::seq(0u, factor))
         duplicatedArgs[j * numArgs + i] = unpacked[j];
     }
@@ -588,7 +588,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
        llvm::zip(reduceOp.getReductions(), reduceOp.getOperands())) {
     scalarMapping.clear();
     Block &reduceBody = body.front();
-    assert(reduceBody.getNumArguments() == 2);
+    assert(reduceBody.getNumArguments() == 2 && "Malformed scf.reduce");
 
     Value reduceVal;
     if (auto redKind = getReductionKind(reduceBody)) {
@@ -596,7 +596,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
       Value redArg = getVecVal(arg);
       if (redArg) {
         auto neutral = arith::getNeutralElement(&reduceBody.front());
-        assert(neutral);
+        assert(neutral && "getNeutralElement has unexpectedly failed");
         Value neutralVal = builder.create<arith::ConstantOp>(loc, *neutral);
         Value neutralVec =
             builder.create<vector::SplatOp>(loc, neutralVal, redArg.getType());
@@ -618,7 +618,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop,
     auto lhs = reduceBody.getArgument(0);
     auto rhs = reduceBody.getArgument(1);
     auto unpacked = getUnpackedVals(arg);
-    assert(unpacked.size() == factor);
+    assert(unpacked.size() == factor && "Invalid unpacked size");
     reduceVal = unpacked.front();
     for (auto i : llvm::seq(1u, factor)) {
       Value val = unpacked[i];

From c4f9d1e7f55320a279594ffe83f7d2c07eca5860 Mon Sep 17 00:00:00 2001
From: Ivan Butygin
Date: Mon, 3 Jun 2024 19:42:01 +0200
Subject: [PATCH 09/10] remove auto

---
 .../Dialect/SCF/Transforms/SCFVectorize.cpp   | 139 +++++++++---------
 1 file changed, 69 insertions(+), 70 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
index 536efc72a0305..b7d1281fb20ca 100644
--- a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp
@@ -40,7 +40,7 @@ static std::optional<unsigned> getArgsTypeWidth(Operation &op,
                                                 const DataLayout *DL) {
   unsigned ret = 0;
   for (auto r : {ValueRange(op.getOperands()), ValueRange(op.getResults())}) {
-    for (auto arg : op.getOperands()) {
+    for (Value arg : op.getOperands()) {
       std::optional<unsigned> w = getTypeBitWidth(arg.getType(), DL);
       if (!w)
         return std::nullopt;
@@ -62,7 +62,7 @@ static bool isRangePermutation(ValueRange val1, ValueRange val2) {
   if (val1.size() != val2.size())
     return false;
 
-  for (auto v1 : val1) {
+  for (Value v1 : val1) {
     auto it = llvm::find(val2, v1);
     if (it == val2.end())
       return false;
@@ -74,9 +74,9 @@ template <typename Op>
 static std::optional<unsigned>
 canTriviallyVectorizeMemOpImpl(scf::ParallelOp loop, unsigned dim, Op memOp,
                                const DataLayout *DL) {
-  auto loopIndexVars = loop.getInductionVars();
+  ValueRange loopIndexVars = loop.getInductionVars();
   assert(dim < loopIndexVars.size() && "Invalid loop dimension");
-  auto memref = memOp.getMemRef();
+  Value memref = 
memOp.getMemRef(); auto type = cast(memref.getType()); std::optional width = getTypeBitWidth(type.getElementType(), DL); if (!width) @@ -118,7 +118,7 @@ canTriviallyVectorizeMemOp(scf::ParallelOp loop, unsigned dim, Operation &op, template static std::optional canGatherScatterImpl(scf::ParallelOp loop, Op op, const DataLayout *DL) { - auto memref = op.getMemRef(); + Value memref = op.getMemRef(); auto memrefType = cast(memref.getType()); std::optional width = getTypeBitWidth(memrefType.getElementType(), DL); @@ -151,7 +151,7 @@ canGatherScatter(scf::ParallelOp loop, Operation &op, const DataLayout *DL) { static std::optional cenVectorizeMemrefOp(scf::ParallelOp loop, unsigned dim, Operation &op, const DataLayout *DL) { - if (auto w = canTriviallyVectorizeMemOp(loop, dim, op, DL)) + if (std::optional w = canTriviallyVectorizeMemOp(loop, dim, op, DL)) return w; return canGatherScatter(loop, op, DL); @@ -200,8 +200,8 @@ mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim, return std::nullopt; /// Check mem ops. - if (auto w = cenVectorizeMemrefOp(loop, dim, op, DL)) { - auto newFactor = vectorBitwidth / *w; + if (std::optional w = cenVectorizeMemrefOp(loop, dim, op, DL)) { + unsigned newFactor = vectorBitwidth / *w; if (newFactor > 1) { factor = std::min(factor, newFactor); ++count; @@ -220,7 +220,7 @@ mlir::scf::getLoopVectorizeInfo(scf::ParallelOp loop, unsigned dim, if (!width) return std::nullopt; - auto newFactor = vectorBitwidth / *width; + unsigned newFactor = vectorBitwidth / *width; if (newFactor <= 1) continue; @@ -247,26 +247,26 @@ static arith::FastMathFlags getFMF(Operation &op) { LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, const scf::SCFVectorizeParams ¶ms, const DataLayout *DL) { - auto dim = params.dim; - auto factor = params.factor; - auto masked = params.masked; + unsigned dim = params.dim; + unsigned factor = params.factor; + bool masked = params.masked; assert(dim < loop.getStep().size() && "Invalid loop dimension"); assert(factor > 1 && "Invalid vectorize factor"); assert(isConstantIntValue(loop.getStep()[dim], 1) && "Loop stepust be 1"); OpBuilder builder(loop); - auto lower = llvm::to_vector(loop.getLowerBound()); - auto upper = llvm::to_vector(loop.getUpperBound()); - auto step = llvm::to_vector(loop.getStep()); + SmallVector lower = llvm::to_vector(loop.getLowerBound()); + SmallVector upper = llvm::to_vector(loop.getUpperBound()); + SmallVector step = llvm::to_vector(loop.getStep()); - auto loc = loop.getLoc(); + Location loc = loop.getLoc(); - auto origIndexVar = loop.getInductionVars()[dim]; + Value origIndexVar = loop.getInductionVars()[dim]; Value factorVal = builder.create(loc, factor); - auto origLower = lower[dim]; - auto origUpper = upper[dim]; + Value origLower = lower[dim]; + Value origUpper = upper[dim]; Value count = builder.createOrFold(loc, origUpper, origLower); Value newCount; @@ -284,10 +284,10 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, // Vectorized loop. auto newLoop = builder.create(loc, lower, upper, step, loop.getInitVals()); - auto newIndexVar = newLoop.getInductionVars()[dim]; + Value newIndexVar = newLoop.getInductionVars()[dim]; auto toVectorType = [&](Type elemType) -> VectorType { - int64_t f = factor; + auto f = static_cast(factor); return VectorType::get(f, elemType); }; @@ -311,14 +311,14 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, // Get vector value in new loop for provided `orig` value in source loop. 
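One case in the lookup below deserves spelling out: when `orig` is the vectorized induction variable itself, the result is `splat(i * factor) + <0, 1, ..., factor-1>`, with the lower-bound offset folded in where the index is consumed. A standalone sketch of that lane-index construction; all names here are local assumptions:

  static Value laneIndices(OpBuilder &b, Location loc, Value newIv,
                           unsigned factor) {
    VectorType vt = VectorType::get(factor, b.getIndexType());
    SmallVector<Attribute> steps;
    for (unsigned i = 0; i < factor; ++i)
      steps.push_back(b.getIndexAttr(i));
    Value stepVec =
        b.create<arith::ConstantOp>(loc, DenseElementsAttr::get(vt, steps));
    Value factorVal = b.create<arith::ConstantIndexOp>(loc, factor);
    Value base = b.create<arith::MulIOp>(loc, newIv, factorVal);
    Value baseVec = b.create<vector::SplatOp>(loc, base, vt);
    return b.create<arith::AddIOp>(loc, baseVec, stepVec);
  }

This is the `vector.splat` plus `arith.addi` pair visible in the CHECK lines of the gather/scatter test.
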
auto getVecVal = [&](Value orig) -> Value { // Use cached value if present. - if (auto mapped = mapping.lookupOrNull(orig)) + if (Value mapped = mapping.lookupOrNull(orig)) return mapped; // Vectorized loop index, loop index is divided by factor, so for factorN // vectorized index will looks like `splat(idx) + (0, 1, ..., N - 1)` if (orig == origIndexVar) { - auto vecType = toVectorType(builder.getIndexType()); - llvm::SmallVector elems(factor); + VectorType vecType = toVectorType(builder.getIndexType()); + SmallVector elems(factor); for (auto i : llvm::seq(0u, factor)) elems[i] = builder.getIndexAttr(i); auto attr = DenseElementsAttr::get(vecType, elems); @@ -331,11 +331,11 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, mapping.map(orig, vec); return vec; } - auto type = orig.getType(); + Type type = orig.getType(); assert(isSupportedVecElem(type) && "Unsupported vector element type"); Value val = orig; - auto origIndexVars = loop.getInductionVars(); + ValueRange origIndexVars = loop.getInductionVars(); auto it = llvm::find(origIndexVars, orig); // If loop index, but not on vectorized dimension, just take new loop index @@ -346,14 +346,12 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, // Values which are defined inside loop body are preemptively added to the // mapper and not handled here. Values defined outside body are just // splatted. - - auto vecType = toVectorType(type); - Value vec = builder.create(loc, val, vecType); + Value vec = builder.create(loc, val, toVectorType(type)); mapping.map(orig, vec); return vec; }; - llvm::DenseMap> unpackedVals; + llvm::DenseMap> unpackedVals; // Get unpacked values for provided `orig` value in source loop. // Values are returned as `ValueRange` and not as vector value. @@ -376,7 +374,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, } // Get vector value and extract elements from it. - auto vecVal = getVecVal(val); + Value vecVal = getVecVal(val); ret.resize(factor); for (auto i : llvm::seq(0u, factor)) { Value idx = builder.create(loc, i); @@ -391,13 +389,13 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, assert(unpackedVals.count(origVal) == 0 && "Invalid unpackedVals state"); unpackedVals[origVal].append(newVals.begin(), newVals.end()); - auto type = origVal.getType(); + Type type = origVal.getType(); if (!isSupportedVecElem(type)) return; // If type is vectorizabale construct a vector add it to vector cache as // well. - auto vecType = toVectorType(type); + VectorType vecType = toVectorType(type); Value vec = createPosionVec(vecType); for (auto i : llvm::seq(0u, factor)) { @@ -422,7 +420,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, } else { maskSize = builder.getIndexAttr(factor); } - auto vecType = toVectorType(builder.getI1Type()); + VectorType vecType = toVectorType(builder.getI1Type()); mask = builder.create(loc, vecType, maskSize); return mask; @@ -437,11 +435,11 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, }; // Get idices for vectorized memref load/store. 
- auto getMemrefVecIndices = [&](ValueRange indices) { + auto getMemrefVecIndices = [&](ValueRange indices) -> SmallVector { scalarMapping.clear(); scalarMapping.map(loop.getInductionVars(), newLoop.getInductionVars()); - llvm::SmallVector ret(indices.size()); + SmallVector ret(indices.size()); for (auto &&[i, val] : llvm::enumerate(indices)) { if (val == origIndexVar) { Value idx = getrIndexVarMult(); @@ -457,13 +455,13 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, // Create vectorized memref load for specified non-vectorized load. auto genLoad = [&](auto loadOp) { - auto indices = getMemrefVecIndices(loadOp.getIndices()); - auto resType = toVectorType(loadOp.getResult().getType()); - auto memref = loadOp.getMemRef(); + SmallVector indices = getMemrefVecIndices(loadOp.getIndices()); + VectorType resType = toVectorType(loadOp.getResult().getType()); + Value memref = loadOp.getMemRef(); Value vecLoad; if (masked) { - auto mask = getMask(); - auto init = createPosionVec(resType); + Value mask = getMask(); + Value init = createPosionVec(resType); vecLoad = builder.create(loc, resType, memref, indices, mask, init); } else { @@ -474,19 +472,19 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, // Create vectorized memref store for specified non-vectorized store. auto genStore = [&](auto storeOp) { - auto indices = getMemrefVecIndices(storeOp.getIndices()); - auto value = getVecVal(storeOp.getValueToStore()); - auto memref = storeOp.getMemRef(); + SmallVector indices = getMemrefVecIndices(storeOp.getIndices()); + Value value = getVecVal(storeOp.getValueToStore()); + Value memref = storeOp.getMemRef(); if (masked) { - auto mask = getMask(); + Value mask = getMask(); builder.create(loc, memref, indices, mask, value); } else { builder.create(loc, value, memref, indices); } }; - llvm::SmallVector duplicatedArgs; - llvm::SmallVector duplicatedResults; + SmallVector duplicatedArgs; + SmallVector duplicatedResults; builder.setInsertionPointToStart(newLoop.getBody()); for (Operation &op : loop.getBody()->without_terminator()) { @@ -494,11 +492,11 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, if (isSupportedVectorOp(op)) { // If op can be vectorized, clone it with vectorized inputs and update // resuls to vectorized types. 
- for (auto arg : op.getOperands()) + for (Value arg : op.getOperands()) getVecVal(arg); // init mapper for op args - auto newOp = builder.clone(op, mapping); - for (auto res : newOp->getResults()) + Operation *newOp = builder.clone(op, mapping); + for (Value res : newOp->getResults()) res.setType(toVectorType(res.getType())); continue; @@ -512,11 +510,11 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, continue; } if (canGatherScatter(loadOp)) { - auto resType = toVectorType(loadOp.getResult().getType()); - auto memref = loadOp.getMemRef(); - auto mask = getMask(); - auto indexVec = getVecVal(loadOp.getIndices()[0]); - auto init = createPosionVec(resType); + VectorType resType = toVectorType(loadOp.getResult().getType()); + Value memref = loadOp.getMemRef(); + Value mask = getMask(); + Value indexVec = getVecVal(loadOp.getIndices()[0]); + Value init = createPosionVec(resType); auto gather = builder.create( loc, resType, memref, zero, indexVec, mask, init); @@ -531,10 +529,10 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, continue; } if (canGatherScatter(storeOp)) { - auto memref = storeOp.getMemRef(); - auto value = getVecVal(storeOp.getValueToStore()); - auto mask = getMask(); - auto indexVec = getVecVal(storeOp.getIndices()[0]); + Value memref = storeOp.getMemRef(); + Value value = getVecVal(storeOp.getValueToStore()); + Value mask = getMask(); + Value indexVec = getVecVal(storeOp.getIndices()[0]); builder.create(loc, memref, zero, indexVec, mask, value); @@ -554,7 +552,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, duplicatedResults.resize(numResults * factor); for (auto &&[i, arg] : llvm::enumerate(op.getOperands())) { - auto unpacked = getUnpackedVals(arg); + ValueRange unpacked = getUnpackedVals(arg); assert(unpacked.size() == factor && "Invalid unpacked size"); for (auto j : llvm::seq(0u, factor)) duplicatedArgs[j * numArgs + i] = unpacked[j]; @@ -565,7 +563,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, .drop_front(numArgs * i) .take_front(numArgs); scalarMapping.map(op.getOperands(), args); - auto results = builder.clone(op, scalarMapping)->getResults(); + ValueRange results = builder.clone(op, scalarMapping)->getResults(); for (auto j : llvm::seq(0u, numResults)) duplicatedResults[j * factor + i] = results[j]; @@ -581,7 +579,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, // Vectorize `scf.reduce` op. auto reduceOp = cast(loop.getBody()->getTerminator()); - llvm::SmallVector reduceVals; + SmallVector reduceVals; reduceVals.reserve(reduceOp.getNumOperands()); for (auto &&[body, arg] : @@ -595,16 +593,17 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, // Generate `vector.reduce` if possible. Value redArg = getVecVal(arg); if (redArg) { - auto neutral = arith::getNeutralElement(&reduceBody.front()); + std::optional neutral = + arith::getNeutralElement(&reduceBody.front()); assert(neutral && "getNeutralElement has unepectedly failed"); Value neutralVal = builder.create(loc, *neutral); Value neutralVec = builder.create(loc, neutralVal, redArg.getType()); - auto mask = getMask(); + Value mask = getMask(); redArg = builder.create(loc, mask, redArg, neutralVec); } - auto fmf = getFMF(reduceBody.front()); + arith::FastMathFlags fmf = getFMF(reduceBody.front()); reduceVal = builder.create(loc, *redKind, redArg, fmf); } else { @@ -615,16 +614,16 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, // individually. 
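The fallback that follows serializes the reduction: the `scf.reduce` body is cloned `factor - 1` times, with the running accumulator and the next unpacked scalar as the two block arguments, exactly the chained `@combine` calls in the last test case. Schematically (with `cloneReduceBody` a hypothetical helper, not patch code):

  Value acc = unpacked.front();
  for (unsigned i = 1; i < factor; ++i)
    // Clone the scf.reduce body with (acc, unpacked[i]) as its arguments.
    acc = cloneReduceBody(builder, reduceBody, acc, unpacked[i]);
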
auto reduceTerm = cast(reduceBody.getTerminator()); - auto lhs = reduceBody.getArgument(0); - auto rhs = reduceBody.getArgument(1); - auto unpacked = getUnpackedVals(arg); + Value lhs = reduceBody.getArgument(0); + Value rhs = reduceBody.getArgument(1); + ValueRange unpacked = getUnpackedVals(arg); assert(unpacked.size() == factor && "Invalid unpacked size"); reduceVal = unpacked.front(); for (auto i : llvm::seq(1u, factor)) { Value val = unpacked[i]; scalarMapping.map(lhs, reduceVal); scalarMapping.map(rhs, val); - for (auto &redOp : reduceBody.without_terminator()) + for (Operation &redOp : reduceBody.without_terminator()) builder.clone(redOp, scalarMapping); reduceVal = scalarMapping.lookupOrDefault(reduceTerm.getResult()); @@ -648,7 +647,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, builder.createOrFold(loc, newCount, factorVal); newLower = builder.createOrFold(loc, origLower, newLower); - auto lowerCopy = llvm::to_vector(loop.getLowerBound()); + SmallVector lowerCopy = llvm::to_vector(loop.getLowerBound()); lowerCopy[dim] = newLower; loop.getLowerBoundMutable().assign(lowerCopy); loop.getInitValsMutable().assign(newLoop.getResults()); From a4ea5d981ac3a01e5ef5ceb21175388f02820d21 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 3 Jun 2024 19:48:27 +0200 Subject: [PATCH 10/10] remove tmp var --- mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp index b7d1281fb20ca..d441dffc6c58a 100644 --- a/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/SCFVectorize.cpp @@ -287,8 +287,7 @@ LogicalResult mlir::scf::vectorizeLoop(scf::ParallelOp loop, Value newIndexVar = newLoop.getInductionVars()[dim]; auto toVectorType = [&](Type elemType) -> VectorType { - auto f = static_cast(factor); - return VectorType::get(f, elemType); + return VectorType::get(factor, elemType); }; IRMapping mapping;
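Taken together, the series leaves two public entry points, `scf::getLoopVectorizeInfo` and `scf::vectorizeLoop`. A minimal driver sketch assuming the final (post patch 06) header location; `vectorizeAllLoops` and the fixed dimension choice are illustrations mirroring the test pass, not part of the series:

  #include "mlir/Dialect/SCF/Transforms/SCFVectorize.h"

  void vectorizeAllLoops(ModuleOp module, unsigned vectorBitwidth) {
    module.walk([&](scf::ParallelOp loop) {
      // Only probe the outermost dimension; the test pass scans all of them.
      auto info = scf::getLoopVectorizeInfo(loop, /*dim=*/0, vectorBitwidth,
                                            /*DL=*/nullptr);
      if (!info)
        return;
      scf::SCFVectorizeParams params{info->dim, info->factor, info->masked};
      (void)scf::vectorizeLoop(loop, params, /*DL=*/nullptr);
    });
  }
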