Skip to content

Commit 36d8e70

Browse files
authored
[mlir] Enable LICM for ops with only read side effects in scf.for (#120302)
Enable ops with only read side effects in scf.for to be hoisted with an scf.if guard that checks against the trip count. This patch takes a step towards a less conservative LICM in MLIR as discussed in the following discourse thread: [Speculative LICM?](https://discourse.llvm.org/t/speculative-licm/80977) This patch in particular does the following: 1. Relaxes the original constraint for hoisting, which only hoists ops without any side effects. This patch additionally allows ops with only read side effects to be hoisted into an scf.if guard, but only if every op in the loop or its nested regions is side-effect free or has only read side effects. This scf.if guard wraps the original scf.for and checks for **trip_count > 0**. 2. To support this, two new interface methods are added to **LoopLikeInterface**: _wrapInTripCountCheck_ and _unwrapTripCountCheck_. The implementation starts by wrapping the scf.for loop into an scf.if guard using _wrapInTripCountCheck_; if no op has been hoisted into this guard after the worklist has been processed, it unwraps the guard by calling _unwrapTripCountCheck_.
1 parent a6a5507 commit 36d8e70

File tree

9 files changed

+293
-17
lines changed

9 files changed

+293
-17
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

+3-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def SCF_Dialect : Dialect {
4040
and then lowered to some final target like LLVM or SPIR-V.
4141
}];
4242

43-
let dependentDialects = ["arith::ArithDialect"];
43+
let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
4444
}
4545

4646
// Base class for SCF dialect ops.
@@ -138,7 +138,9 @@ def ForOp : SCF_Op<"for",
138138
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
139139
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
140140
"getLoopUpperBounds", "getYieldedValuesMutable",
141+
"moveOutOfLoopWithGuard",
141142
"promoteIfSingleIteration", "replaceWithAdditionalYields",
143+
"wrapInTripCountCheck", "unwrapTripCountCheck",
142144
"yieldTiledValuesAndReplace"]>,
143145
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
144146
ConditionallySpeculatable,

mlir/include/mlir/Interfaces/LoopLikeInterface.td

+12
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,18 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
7979
/*methodBody=*/"",
8080
/*defaultImplementation=*/"op->moveBefore($_op);"
8181
>,
82+
InterfaceMethod<[{
83+
Moves the given loop-invariant operation out of the loop with a
84+
trip-count guard.
85+
}],
86+
/*retTy=*/"void",
87+
/*methodName=*/"moveOutOfLoopWithGuard",
88+
/*args=*/(ins "::mlir::Operation *":$op),
89+
/*methodBody=*/"",
90+
/*defaultImplementation=*/[{
91+
return;
92+
}]
93+
>,
8294
InterfaceMethod<[{
8395
Promotes the loop body to its containing block if the loop is known to
8496
have a single iteration. Returns "success" if the promotion was

mlir/include/mlir/Interfaces/SideEffectInterfaces.h

+4
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,10 @@ bool wouldOpBeTriviallyDead(Operation *op);
433433
/// conditions are satisfied.
434434
bool isMemoryEffectFree(Operation *op);
435435

436+
/// Returns true if the given operation is free of memory effects or has only
437+
/// read effect.
438+
bool isMemoryEffectFreeOrOnlyRead(Operation *op);
439+
436440
/// Returns the side effects of an operation. If the operation has
437441
/// RecursiveMemoryEffects, include all side effects of child operations.
438442
///

mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h

+10-6
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,19 @@ class Value;
4747
/// }
4848
/// }
4949
/// ```
50-
///
51-
/// Users must supply three callbacks.
50+
/// Users must supply four callbacks.
5251
///
5352
/// - `isDefinedOutsideRegion` returns true if the given value is invariant with
5453
/// respect to the given region. A common implementation might be:
5554
/// `value.getParentRegion()->isProperAncestor(region)`.
5655
/// - `shouldMoveOutOfRegion` returns true if the provided operation can be
57-
/// moved of the given region, e.g. if it is side-effect free.
58-
/// - `moveOutOfRegion` moves the operation out of the given region. A common
59-
/// implementation might be: `op->moveBefore(region->getParentOp())`.
56+
/// moved out of the given region, e.g. if it is side-effect free or has only read
57+
/// side effects.
58+
/// - `moveOutOfRegionWithoutGuard` moves the operation out of the given region
59+
/// without a guard. A common implementation might be:
60+
/// `op->moveBefore(region->getParentOp())`.
61+
/// - `moveOutOfRegionWithGuard` moves the operation out of the given region
62+
/// with a guard.
6063
///
6164
/// An operation is moved if all of its operands satisfy
6265
/// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
@@ -66,7 +69,8 @@ size_t moveLoopInvariantCode(
6669
ArrayRef<Region *> regions,
6770
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
6871
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
69-
function_ref<void(Operation *, Region *)> moveOutOfRegion);
72+
function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
73+
function_ref<void(Operation *)> moveOutOfRegionWithGuard);
7074

7175
/// Move side-effect free loop invariant code out of a loop-like op using
7276
/// methods provided by the interface.

mlir/lib/Dialect/SCF/IR/SCF.cpp

+37-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1616
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
1717
#include "mlir/Dialect/Tensor/IR/Tensor.h"
18+
#include "mlir/Dialect/UB/IR/UBOps.h"
1819
#include "mlir/IR/BuiltinAttributes.h"
1920
#include "mlir/IR/IRMapping.h"
2021
#include "mlir/IR/Matchers.h"
@@ -395,6 +396,40 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
395396

396397
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
397398

399+
/// Moves the op out of the loop with a guard that checks if the loop has at
400+
/// least one iteration.
401+
void ForOp::moveOutOfLoopWithGuard(Operation *op) {
402+
IRRewriter rewriter(this->getContext());
403+
OpBuilder::InsertionGuard insertGuard(rewriter);
404+
rewriter.setInsertionPoint(this->getOperation());
405+
Location loc = this->getLoc();
406+
arith::CmpIOp cmpIOp = rewriter.create<arith::CmpIOp>(
407+
loc, arith::CmpIPredicate::ult, this->getLowerBound(),
408+
this->getUpperBound());
409+
// Create the trip-count check.
410+
scf::YieldOp thenYield;
411+
scf::IfOp ifOp = rewriter.create<scf::IfOp>(
412+
loc, cmpIOp,
413+
[&](OpBuilder &builder, Location loc) {
414+
thenYield = builder.create<scf::YieldOp>(loc, op->getResults());
415+
},
416+
[&](OpBuilder &builder, Location loc) {
417+
SmallVector<Value> poisonResults;
418+
poisonResults.reserve(op->getResults().size());
419+
for (Type type : op->getResults().getTypes()) {
420+
ub::PoisonOp poisonOp =
421+
rewriter.create<ub::PoisonOp>(loc, type, nullptr);
422+
poisonResults.push_back(poisonOp);
423+
}
424+
builder.create<scf::YieldOp>(loc, poisonResults);
425+
});
426+
for (auto [opResult, ifOpResult] :
427+
llvm::zip_equal(op->getResults(), ifOp->getResults()))
428+
rewriter.replaceAllUsesExcept(opResult, ifOpResult, thenYield);
429+
// Move the op into the then block.
430+
rewriter.moveOpBefore(op, thenYield);
431+
}
432+
398433
/// Promotes the loop body of a forOp to its containing block if the forOp
399434
/// it can be determined that the loop has a single iteration.
400435
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
@@ -3394,9 +3429,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
33943429

33953430
if (functionType.getNumInputs() != operands.size()) {
33963431
return parser.emitError(typeLoc)
3397-
<< "expected as many input types as operands "
3398-
<< "(expected " << operands.size() << " got "
3399-
<< functionType.getNumInputs() << ")";
3432+
<< "expected as many input types as operands " << "(expected "
3433+
<< operands.size() << " got " << functionType.getNumInputs() << ")";
34003434
}
34013435

34023436
// Resolve input operands.

mlir/lib/Interfaces/SideEffectInterfaces.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Interfaces/SideEffectInterfaces.h"
1010

1111
#include "mlir/IR/SymbolTable.h"
12+
#include "llvm/ADT/STLExtras.h"
1213
#include "llvm/ADT/SmallPtrSet.h"
1314
#include <utility>
1415

@@ -370,6 +371,17 @@ mlir::getEffectsRecursively(Operation *rootOp) {
370371
return effects;
371372
}
372373

374+
bool mlir::isMemoryEffectFreeOrOnlyRead(Operation *op) {
375+
std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
376+
getEffectsRecursively(op);
377+
if (!effects)
378+
return false;
379+
return llvm::all_of(*effects,
380+
[&](const MemoryEffects::EffectInstance &effect) {
381+
return isa<MemoryEffects::Read>(effect.getEffect());
382+
});
383+
}
384+
373385
bool mlir::isSpeculatable(Operation *op) {
374386
auto conditionallySpeculatable = dyn_cast<ConditionallySpeculatable>(op);
375387
if (!conditionallySpeculatable)

mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp

+27-7
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,18 @@ size_t mlir::moveLoopInvariantCode(
6060
ArrayRef<Region *> regions,
6161
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
6262
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
63-
function_ref<void(Operation *, Region *)> moveOutOfRegion) {
63+
function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
64+
function_ref<void(Operation *)> moveOutOfRegionWithGuard) {
6465
size_t numMoved = 0;
6566

6667
for (Region *region : regions) {
6768
LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
6869
<< *region->getParentOp() << "\n");
6970

71+
bool anyOpHoistedWithGuard = false;
72+
bool loopSideEffectFreeOrHasOnlyReadSideEffect =
73+
isMemoryEffectFreeOrOnlyRead(region->getParentOp());
74+
7075
std::queue<Operation *> worklist;
7176
// Add top-level operations in the loop body to the worklist.
7277
for (Operation &op : region->getOps())
@@ -84,12 +89,26 @@ size_t mlir::moveLoopInvariantCode(
8489
continue;
8590

8691
LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");
92+
8793
if (!shouldMoveOutOfRegion(op, region) ||
8894
!canBeHoisted(op, definedOutside))
8995
continue;
96+
// Can only hoist pure ops (side-effect free) when there is an op with
97+
// write and/or unknown side effects in the loop.
98+
if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
99+
continue;
90100

91-
LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
92-
moveOutOfRegion(op, region);
101+
bool moveWithoutGuard = !anyOpHoistedWithGuard && isMemoryEffectFree(op);
102+
if (moveWithoutGuard) {
103+
LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op
104+
<< " without guard\n");
105+
moveOutOfRegionWithoutGuard(op);
106+
} else {
107+
LLVM_DEBUG(llvm::dbgs()
108+
<< "Moving loop-invariant op: " << *op << " with guard\n");
109+
moveOutOfRegionWithGuard(op);
110+
anyOpHoistedWithGuard = true;
111+
}
93112
++numMoved;
94113

95114
// Since the op has been moved, we need to check its users within the
@@ -106,13 +125,14 @@ size_t mlir::moveLoopInvariantCode(
106125
size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
107126
return moveLoopInvariantCode(
108127
loopLike.getLoopRegions(),
109-
[&](Value value, Region *) {
110-
return loopLike.isDefinedOutsideOfLoop(value);
128+
[&](Value value, Region *region) {
129+
return !region->isAncestor(value.getParentRegion());
111130
},
112131
[&](Operation *op, Region *) {
113-
return isMemoryEffectFree(op) && isSpeculatable(op);
132+
return isSpeculatable(op) && isMemoryEffectFreeOrOnlyRead(op);
114133
},
115-
[&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); });
134+
[&](Operation *op) { loopLike.moveOutOfLoop(op); },
135+
[&](Operation *op) { loopLike.moveOutOfLoopWithGuard(op); });
116136
}
117137

118138
namespace {

mlir/test/Transforms/loop-invariant-code-motion.mlir

+144
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,150 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
714714
return
715715
}
716716

717+
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
718+
func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
719+
// CHECK: test.always_speculatable_op
720+
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
721+
// CHECK-NEXT: scf.if %[[CMP]]
722+
// CHECK-NEXT: test.speculatable_op_with_memread
723+
// CHECK: else
724+
// CHECK-NEXT: ub.poison : i32
725+
// CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
726+
// CHECK-NOT: test.always_speculatable_op
727+
// CHECK-NOT: test.speculatable_op_with_memread
728+
%cst_0 = arith.constant 0 : i32
729+
%cst_42 = arith.constant dense<42> : tensor<64xi32>
730+
%ind_42 = arith.constant 42 : index
731+
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
732+
%always_speculate = "test.always_speculatable_op"() : () -> i32
733+
%only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
734+
%i_cast = arith.index_cast %i: index to i32
735+
%add = arith.addi %acc, %i_cast : i32
736+
%sum = arith.addi %add, %only_read : i32
737+
scf.yield %sum : i32
738+
}
739+
return %sum_result : i32
740+
}
741+
742+
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_multiple_result_success
743+
func.func @test_speculatable_op_with_read_side_effect_multiple_result_success(%lb: index, %ub: index, %step: index) -> i32 {
744+
// CHECK: test.always_speculatable_op
745+
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
746+
// CHECK-NEXT: scf.if %[[CMP]]
747+
// CHECK-NEXT: test.speculatable_op_with_memread
748+
// CHECK: else
749+
// CHECK-NEXT: ub.poison : i32
750+
// CHECK-NEXT: ub.poison : f32
751+
// CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
752+
// CHECK-NOT: test.always_speculatable_op
753+
// CHECK-NOT: test.speculatable_op_with_memread
754+
%cst_0 = arith.constant 0 : i32
755+
%cst_42 = arith.constant dense<42> : tensor<64xi32>
756+
%ind_42 = arith.constant 42 : index
757+
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
758+
%always_speculate = "test.always_speculatable_op"() : () -> i32
759+
%only_read:2 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> (i32, f32)
760+
%i_cast = arith.index_cast %i: index to i32
761+
%add = arith.addi %acc, %i_cast : i32
762+
%sum = arith.addi %add, %only_read#0 : i32
763+
scf.yield %sum : i32
764+
}
765+
return %sum_result : i32
766+
}
767+
768+
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
769+
func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
770+
// CHECK: %[[ALWAYS:.*]] = "test.always_speculatable_op"
771+
// CHECK-NEXT: %[[CMP0:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
772+
// CHECK-NEXT: %[[IF0:.*]] = scf.if %[[CMP0]]
773+
// CHECK-NEXT: test.speculatable_op_with_memread
774+
// CHECK: else
775+
// CHECK-NEXT: ub.poison : i32
776+
// CHECK: %[[CMP1:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
777+
// CHECK-NEXT: %[[IF1:.*]] = scf.if %[[CMP1]]
778+
// CHECK-NEXT: arith.addi %[[ALWAYS]], %[[IF0]]
779+
// CHECK: else
780+
// CHECK-NEXT: ub.poison : i32
781+
// CHECK: %[[CMP2:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
782+
// CHECK-NEXT: %[[IF2:.*]] = scf.if %[[CMP2]]
783+
// CHECK-NEXT: test.speculatable_op_with_memread
784+
// CHECK: else
785+
// CHECK-NEXT: ub.poison : i32
786+
// CHECK: %[[CMP3:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
787+
// CHECK-NEXT: %{{.*}} = scf.if %[[CMP3]]
788+
// CHECK-NEXT: arith.addi %[[IF1]], %[[IF2]]
789+
// CHECK: else
790+
// CHECK-NEXT: ub.poison : i32
791+
// CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]]
792+
// CHECK-NOT: test.always_speculatable_op
793+
// CHECK-NOT: test.speculatable_op_with_memread
794+
%cst_0 = arith.constant 0 : i32
795+
%cst_42 = arith.constant dense<42> : tensor<64xi32>
796+
%ind_42 = arith.constant 42 : index
797+
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
798+
%always_speculate = "test.always_speculatable_op"() : () -> i32
799+
%only_read_0 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
800+
%add_0 = arith.addi %always_speculate, %only_read_0 : i32
801+
%only_read_1 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
802+
%add_1 = arith.addi %add_0, %only_read_1 : i32
803+
%i_cast = arith.index_cast %i: index to i32
804+
%sum = arith.addi %add_1, %i_cast : i32
805+
scf.yield %sum : i32
806+
}
807+
return %sum_result : i32
808+
}
809+
810+
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
811+
func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
812+
// CHECK: test.always_speculatable_op
813+
// CHECK-NEXT: scf.for
814+
// CHECK-NOT: test.always_speculatable_op
815+
// CHECK: test.speculatable_op_with_memread
816+
// CHECK: test.speculatable_op_with_memwrite
817+
%cst_0 = arith.constant 0 : i32
818+
%cst_42 = arith.constant dense<42> : tensor<64xi32>
819+
%ind_42 = arith.constant 42 : index
820+
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
821+
%always_speculate = "test.always_speculatable_op"() : () -> i32
822+
%only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
823+
%i_cast = arith.index_cast %i: index to i32
824+
%add = arith.addi %acc, %i_cast : i32
825+
%sum = arith.addi %add, %only_read : i32
826+
%write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
827+
scf.yield %sum : i32
828+
}
829+
return %sum_result : i32
830+
}
831+
832+
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
833+
func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
834+
// CHECK: test.always_speculatable_op
835+
// CHECK-NEXT: scf.for
836+
// CHECK-NOT: test.always_speculatable_op
837+
// CHECK: test.speculatable_op_with_memread
838+
// CHECK: scf.for
839+
// CHECK: scf.if
840+
// CHECK: test.speculatable_op_with_memwrite
841+
%cst_0 = arith.constant 0 : i32
842+
%cst_42 = arith.constant dense<42> : tensor<64xi32>
843+
%ind_42 = arith.constant 42 : index
844+
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
845+
%always_speculate = "test.always_speculatable_op"() : () -> i32
846+
%only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
847+
%i_cast = arith.index_cast %i: index to i32
848+
%add = arith.addi %acc, %i_cast : i32
849+
%sum = arith.addi %add, %only_read : i32
850+
scf.for %j = %lb to %ub step %step {
851+
%eq42 = arith.cmpi eq, %j, %ind_42 : index
852+
scf.if %eq42 {
853+
%always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
854+
}
855+
}
856+
scf.yield %sum : i32
857+
}
858+
return %sum_result : i32
859+
}
860+
717861
// -----
718862

719863
func.func @speculate_tensor_dim_unknown_rank_unknown_dim(

0 commit comments

Comments
 (0)