Skip to content

Commit 1fe405f

Browse files
committed
[mlir] Add loop bound normalization pass
1 parent 1034b4d commit 1fe405f

File tree

16 files changed

+796
-63
lines changed

16 files changed

+796
-63
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

+97
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def ForOp : SCF_Op<"for",
140140
"getSingleUpperBound", "getYieldedValuesMutable",
141141
"promoteIfSingleIteration", "replaceWithAdditionalYields",
142142
"yieldTiledValuesAndReplace"]>,
143+
LoopLikeWithInductionVarsOpInterface,
143144
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
144145
ConditionallySpeculatable,
145146
DeclareOpInterfaceMethods<RegionBranchOpInterface,
@@ -267,6 +268,74 @@ def ForOp : SCF_Op<"for",
267268
return getBody()->getArguments().drop_front(getNumInductionVars())[index];
268269
}
269270

271+
/// Return the induction variables.
272+
::mlir::ValueRange getInductionVars() {
273+
return getBody()->getArguments().take_front(getNumInductionVars());
274+
}
275+
276+
/// Get lower bounds as `OpFoldResult`.
277+
SmallVector<OpFoldResult> getMixedLowerBound() {
278+
return {getAsOpFoldResult(getLowerBound())};
279+
}
280+
281+
/// Get upper bounds as `OpFoldResult`.
282+
SmallVector<OpFoldResult> getMixedUpperBound() {
283+
return {getAsOpFoldResult(getUpperBound())};
284+
}
285+
286+
// Get steps as `OpFoldResult`.
287+
SmallVector<OpFoldResult> getMixedStep() {
288+
return {getAsOpFoldResult(getStep())};
289+
}
290+
291+
/// Get lower bounds as values.
292+
SmallVector<Value> getLowerBound(OpBuilder &b) {
293+
return ValueRange{getLowerBound()};
294+
}
295+
296+
/// Get upper bounds as values.
297+
SmallVector<Value> getUpperBound(OpBuilder &b) {
298+
return ValueRange{getUpperBound()};
299+
}
300+
301+
/// Get steps as values.
302+
SmallVector<Value> getStep(OpBuilder &b) {
303+
return ValueRange{getStep()};
304+
}
305+
306+
/// Set the lower bounds from `OpFoldResult`.
307+
void setMixedLowerBounds(OpBuilder &b, ArrayRef<OpFoldResult> lbs) {
308+
setLowerBound(getValueOrCreateConstantIndexOp(b, getLoc(), lbs[0]));
309+
}
310+
311+
/// Set the upper bounds from `OpFoldResult`.
312+
void setMixedUpperBounds(OpBuilder &b, ArrayRef<OpFoldResult> ubs) {
313+
setUpperBound(getValueOrCreateConstantIndexOp(b, getLoc(), ubs[0]));
314+
}
315+
316+
/// Set the steps from `OpFoldResult`.
317+
void setMixedSteps(OpBuilder &b, ArrayRef<OpFoldResult> steps) {
318+
setStep(getValueOrCreateConstantIndexOp(b, getLoc(), steps[0]));
319+
}
320+
321+
/// Set the lower bounds from values.
322+
void setLowerBounds(ArrayRef<Value> lbs) {
323+
assert(lbs.size() == 1 && "expected a single lower bound");
324+
setLowerBound(lbs[0]);
325+
}
326+
327+
/// Set the upper bounds from values.
328+
void setUpperBounds(ArrayRef<Value> ubs) {
329+
assert(ubs.size() == 1 && "expected a single upper bound");
330+
setUpperBound(ubs[0]);
331+
}
332+
333+
/// Set the steps from values.
334+
void setSteps(ArrayRef<Value> steps) {
335+
assert(steps.size() == 1 && "expected a single step");
336+
setStep(steps[0]);
337+
}
338+
270339
void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
271340
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
272341
void setStep(Value step) { getOperation()->setOperand(2, step); }
@@ -304,6 +373,7 @@ def ForallOp : SCF_Op<"forall", [
304373
["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar",
305374
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
306375
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
376+
LoopLikeWithInductionVarsOpInterface,
307377
RecursiveMemoryEffects,
308378
SingleBlockImplicitTerminator<"scf::InParallelOp">,
309379
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -543,6 +613,33 @@ def ForallOp : SCF_Op<"forall", [
543613
return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedStep());
544614
}
545615

616+
/// Set the lower bounds from `OpFoldResult`.
617+
void setMixedLowerBounds(OpBuilder &b, ArrayRef<OpFoldResult> lbs);
618+
619+
/// Set the upper bounds from `OpFoldResult`.
620+
void setMixedUpperBounds(OpBuilder &b, ArrayRef<OpFoldResult> ubs);
621+
622+
/// Set the steps from `OpFoldResult`.
623+
void setMixedSteps(OpBuilder &b, ArrayRef<OpFoldResult> steps);
624+
625+
/// Set the lower bounds from values.
626+
void setLowerBounds(ArrayRef<Value> lbs) {
627+
OpBuilder b(getOperation()->getContext());
628+
return setMixedLowerBounds(b, getAsOpFoldResult(lbs));
629+
}
630+
631+
/// Set the upper bounds from values.
632+
void setUpperBounds(ArrayRef<Value> ubs) {
633+
OpBuilder b(getOperation()->getContext());
634+
return setMixedUpperBounds(b, getAsOpFoldResult(ubs));
635+
}
636+
637+
/// Set the steps from values.
638+
void setSteps(ArrayRef<Value> steps) {
639+
OpBuilder b(getOperation()->getContext());
640+
return setMixedSteps(b, getAsOpFoldResult(steps));
641+
}
642+
546643
int64_t getRank() { return getStaticLowerBound().size(); }
547644

548645
/// Number of operands controlling the loop: lbs, ubs, steps
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===- LoopUtils.h - Helpers related to loop operations ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header file defines utilities for loop operations.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/IR/PatternMatch.h"
14+
15+
namespace mlir {
16+
17+
// This structure is to pass and return sets of loop parameters without
18+
// confusing the order.
19+
struct LoopParams {
20+
Value lowerBound;
21+
Value upperBound;
22+
Value step;
23+
};
24+
25+
/// Calculate the normalized loop upper bounds with lower bound equal to zero
26+
/// and step equal to one.
27+
LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
28+
Value lb, Value ub, Value step);
29+
30+
} // namespace mlir

mlir/include/mlir/Interfaces/LoopLikeInterface.h

+4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
1414
#define MLIR_INTERFACES_LOOPLIKEINTERFACE_H_
1515

16+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1617
#include "mlir/IR/OpDefinition.h"
1718

1819
namespace mlir {
@@ -28,6 +29,9 @@ using NewYieldValuesFn = std::function<SmallVector<Value>(
2829
namespace detail {
2930
/// Verify invariants of the LoopLikeOpInterface.
3031
LogicalResult verifyLoopLikeOpInterface(Operation *op);
32+
33+
/// Verify invariants of the LoopLikeWithInductionVarsOpInterface.
34+
LogicalResult verifyLoopLikeWithInductionVarsOpInterface(Operation *op);
3135
} // namespace detail
3236

3337
//===----------------------------------------------------------------------===//

mlir/include/mlir/Interfaces/LoopLikeInterface.td

+126
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,132 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
375375
}];
376376
}
377377

378+
def LoopLikeWithInductionVarsOpInterface
379+
: OpInterface<"LoopLikeWithInductionVarsOpInterface"> {
380+
let description = [{
381+
Interface for loop-like operations with one or more induction variables.
382+
This interface contains helper functions for retrieving and updating the
383+
lower bound, upper bound and step size for each induction variable and
384+
provides a utility function to check whether the loop is normalized., i.e.
385+
all lower bounds are equal to zero and steps are equal to one.
386+
}];
387+
let cppNamespace = "::mlir";
388+
389+
let methods = [
390+
InterfaceMethod<[{
391+
Return the induction variables if they exist, otherwise return
392+
std::nullopt.
393+
}],
394+
/*retTy=*/"::mlir::ValueRange",
395+
/*methodName=*/"getInductionVars"
396+
>,
397+
InterfaceMethod<[{
398+
Return the lower bound values or attributes as OpFoldResult.
399+
}],
400+
/*retTy=*/"SmallVector<::mlir::OpFoldResult>",
401+
/*methodName=*/"getMixedLowerBound"
402+
>,
403+
InterfaceMethod<[{
404+
Return the step values or attributes if they exist as OpFoldResult.
405+
}],
406+
/*retTy=*/"SmallVector<::mlir::OpFoldResult>",
407+
/*methodName=*/"getMixedStep"
408+
>,
409+
InterfaceMethod<[{
410+
Return the upper bound values or attributes as OpFoldResult.
411+
}],
412+
/*retTy=*/"SmallVector<::mlir::OpFoldResult>",
413+
/*methodName=*/"getMixedUpperBound"
414+
>,
415+
InterfaceMethod<[{
416+
Return the lower bounds as values.
417+
}],
418+
/*retTy=*/"SmallVector<Value>",
419+
/*methodName=*/"getLowerBound",
420+
/*args=*/(ins "OpBuilder &":$b)
421+
>,
422+
InterfaceMethod<[{
423+
Return the steps as values.
424+
}],
425+
/*retTy=*/"SmallVector<Value>",
426+
/*methodName=*/"getStep",
427+
/*args=*/(ins "OpBuilder &":$b)
428+
>,
429+
InterfaceMethod<[{
430+
Return the upper bounds as values.
431+
}],
432+
/*retTy=*/"SmallVector<Value>",
433+
/*methodName=*/"getUpperBound",
434+
/*args=*/(ins "OpBuilder &":$b)
435+
>,
436+
InterfaceMethod<[{
437+
Set the lower bounds from an array of `OpFoldResult`.
438+
}],
439+
/*retTy=*/"void",
440+
/*methodName=*/"setMixedLowerBounds",
441+
/*args=*/(ins "OpBuilder &":$b, "ArrayRef<OpFoldResult>":$lbs)
442+
>,
443+
InterfaceMethod<[{
444+
Set the steps from an array of `OpFoldResult`.
445+
}],
446+
/*retTy=*/"void",
447+
/*methodName=*/"setMixedSteps",
448+
/*args=*/(ins "OpBuilder &":$b, "ArrayRef<OpFoldResult>":$lbs)
449+
>,
450+
InterfaceMethod<[{
451+
Set the upper bounds from an array of `OpFoldResult`.
452+
}],
453+
/*retTy=*/"void",
454+
/*methodName=*/"setMixedUpperBounds",
455+
/*args=*/(ins "OpBuilder &":$b, "ArrayRef<OpFoldResult>":$lbs)
456+
>,
457+
InterfaceMethod<[{
458+
Set the lower bounds from an array of values.
459+
}],
460+
/*retTy=*/"void",
461+
/*methodName=*/"setLowerBounds",
462+
/*args=*/(ins "ArrayRef<Value>":$lbs)
463+
>,
464+
InterfaceMethod<[{
465+
Set the steps from an array of values.
466+
}],
467+
/*retTy=*/"void",
468+
/*methodName=*/"setSteps",
469+
/*args=*/(ins "ArrayRef<Value>":$lbs)
470+
>,
471+
InterfaceMethod<[{
472+
Set the upper bounds from an array of values.
473+
}],
474+
/*retTy=*/"void",
475+
/*methodName=*/"setUpperBounds",
476+
/*args=*/(ins "ArrayRef<Value>":$lbs)
477+
>,
478+
InterfaceMethod<[{
479+
Checks if the lower bounds are zeros and steps are ones.
480+
}],
481+
/*retTy=*/"bool",
482+
/*methodName=*/"isNormalized",
483+
/*args=*/(ins),
484+
/*methodBody=*/"",
485+
/*defaultImplementation=*/[{
486+
auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
487+
return llvm::all_of(results, [&](OpFoldResult ofr) {
488+
auto intValue = getConstantIntValue(ofr);
489+
return intValue.has_value() && intValue == val;
490+
});
491+
};
492+
SmallVector<::mlir::OpFoldResult> lbs = $_op.getMixedLowerBound();
493+
SmallVector<::mlir::OpFoldResult> steps = $_op.getMixedStep();
494+
return allEqual(lbs, 0) && allEqual(steps, 1);
495+
}]
496+
>
497+
];
498+
499+
let verify = [{
500+
return detail::verifyLoopLikeWithInductionVarsOpInterface($_op);
501+
}];
502+
}
503+
378504
//===----------------------------------------------------------------------===//
379505
// Traits
380506
//===----------------------------------------------------------------------===//

mlir/include/mlir/Transforms/Passes.h

+4
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
8282
/// Creates a pass that hoists loop-invariant subset ops.
8383
std::unique_ptr<Pass> createLoopInvariantSubsetHoistingPass();
8484

85+
/// Create a pass that normalizes the loop bounds of loop-like operations with
86+
/// induction variables.
87+
std::unique_ptr<Pass> createNormalizeLoopBoundsPass();
88+
8589
/// Creates a pass to strip debug information from a function.
8690
std::unique_ptr<Pass> createStripDebugInfoPass();
8791

mlir/include/mlir/Transforms/Passes.td

+6
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,12 @@ def Mem2Reg : Pass<"mem2reg"> {
377377
];
378378
}
379379

380+
def NormalizeLoopBounds : Pass<"normalize-loop-bounds"> {
381+
let summary = "Normalize the loop bounds of loop-like operations with "
382+
"induction variables.";
383+
let constructor = "mlir::createNormalizeLoopBoundsPass()";
384+
}
385+
380386
def PrintOpStats : Pass<"print-op-stats"> {
381387
let summary = "Print statistics of operations";
382388
let constructor = "mlir::createPrintOpStatsPass()";

mlir/lib/Dialect/SCF/IR/SCF.cpp

+60
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,66 @@ void ForallOp::build(
13871387
build(b, result, lbs, ubs, steps, outputs, mapping, bodyBuilderFn);
13881388
}
13891389

1390+
/// Set the lower bounds from `OpFoldResult`.
1391+
void ForallOp::setMixedLowerBounds(OpBuilder &b, ArrayRef<OpFoldResult> lbs) {
1392+
SmallVector<int64_t> staticLbs;
1393+
SmallVector<Value> dynamicLbs;
1394+
dispatchIndexOpFoldResults(lbs, dynamicLbs, staticLbs);
1395+
getOperation()->setOperands(0, getDynamicLowerBound().size(), dynamicLbs);
1396+
(*this)->setAttr(getStaticLowerBoundAttrName(),
1397+
b.getDenseI64ArrayAttr(staticLbs));
1398+
ArrayRef<int32_t> segmentSizes =
1399+
(*this)
1400+
->getAttrOfType<DenseI32ArrayAttr>("operandSegmentSizes")
1401+
.asArrayRef();
1402+
SmallVector<int32_t> newSegmentSizes(segmentSizes.begin(),
1403+
segmentSizes.end());
1404+
newSegmentSizes[0] = dynamicLbs.size();
1405+
(*this)->setAttr("operandSegmentSizes",
1406+
b.getDenseI32ArrayAttr(newSegmentSizes));
1407+
}
1408+
1409+
/// Set the upper bounds from `OpFoldResult`.
1410+
void ForallOp::setMixedUpperBounds(OpBuilder &b, ArrayRef<OpFoldResult> ubs) {
1411+
SmallVector<int64_t> staticUbs;
1412+
SmallVector<Value> dynamicUbs;
1413+
dispatchIndexOpFoldResults(ubs, dynamicUbs, staticUbs);
1414+
size_t offset = getDynamicLowerBound().size();
1415+
getOperation()->setOperands(offset, getDynamicUpperBound().size(),
1416+
dynamicUbs);
1417+
(*this)->setAttr(getStaticUpperBoundAttrName(),
1418+
b.getDenseI64ArrayAttr(staticUbs));
1419+
ArrayRef<int32_t> segmentSizes =
1420+
(*this)
1421+
->getAttrOfType<DenseI32ArrayAttr>("operandSegmentSizes")
1422+
.asArrayRef();
1423+
SmallVector<int32_t> newSegmentSizes(segmentSizes.begin(),
1424+
segmentSizes.end());
1425+
newSegmentSizes[1] = dynamicUbs.size();
1426+
(*this)->setAttr("operandSegmentSizes",
1427+
b.getDenseI32ArrayAttr(newSegmentSizes));
1428+
}
1429+
1430+
/// Set the steps from `OpFoldResult`.
1431+
void ForallOp::setMixedSteps(OpBuilder &b, ArrayRef<OpFoldResult> steps) {
1432+
SmallVector<int64_t> staticSteps;
1433+
SmallVector<Value> dynamicSteps;
1434+
dispatchIndexOpFoldResults(steps, dynamicSteps, staticSteps);
1435+
size_t offset = getDynamicLowerBound().size() + getDynamicUpperBound().size();
1436+
getOperation()->setOperands(offset, getDynamicStep().size(), dynamicSteps);
1437+
(*this)->setAttr(getStaticStepAttrName(),
1438+
b.getDenseI64ArrayAttr(staticSteps));
1439+
ArrayRef<int32_t> segmentSizes =
1440+
(*this)
1441+
->getAttrOfType<DenseI32ArrayAttr>("operandSegmentSizes")
1442+
.asArrayRef();
1443+
SmallVector<int32_t> newSegmentSizes(segmentSizes.begin(),
1444+
segmentSizes.end());
1445+
newSegmentSizes[2] = dynamicSteps.size();
1446+
(*this)->setAttr("operandSegmentSizes",
1447+
b.getDenseI32ArrayAttr(newSegmentSizes));
1448+
}
1449+
13901450
// Checks if the lbs are zeros and steps are ones.
13911451
bool ForallOp::isNormalized() {
13921452
auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {

0 commit comments

Comments
 (0)