@@ -140,6 +140,7 @@ def ForOp : SCF_Op<"for",
140
140
"getSingleUpperBound", "getYieldedValuesMutable",
141
141
"promoteIfSingleIteration", "replaceWithAdditionalYields",
142
142
"yieldTiledValuesAndReplace"]>,
143
+ LoopLikeWithInductionVarsOpInterface,
143
144
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
144
145
ConditionallySpeculatable,
145
146
DeclareOpInterfaceMethods<RegionBranchOpInterface,
@@ -267,6 +268,74 @@ def ForOp : SCF_Op<"for",
267
268
return getBody()->getArguments().drop_front(getNumInductionVars())[index];
268
269
}
269
270
271
+ /// Return the induction variables.
272
+ ::mlir::ValueRange getInductionVars() {
273
+ return getBody()->getArguments().take_front(getNumInductionVars());
274
+ }
275
+
276
+ /// Get lower bounds as `OpFoldResult`.
277
+ SmallVector<OpFoldResult> getMixedLowerBound() {
278
+ return {getAsOpFoldResult(getLowerBound())};
279
+ }
280
+
281
+ /// Get upper bounds as `OpFoldResult`.
282
+ SmallVector<OpFoldResult> getMixedUpperBound() {
283
+ return {getAsOpFoldResult(getUpperBound())};
284
+ }
285
+
286
+ // Get steps as `OpFoldResult`.
287
+ SmallVector<OpFoldResult> getMixedStep() {
288
+ return {getAsOpFoldResult(getStep())};
289
+ }
290
+
291
+ /// Get lower bounds as values.
292
+ SmallVector<Value> getLowerBound(OpBuilder &b) {
293
+ return ValueRange{getLowerBound()};
294
+ }
295
+
296
+ /// Get upper bounds as values.
297
+ SmallVector<Value> getUpperBound(OpBuilder &b) {
298
+ return ValueRange{getUpperBound()};
299
+ }
300
+
301
+ /// Get steps as values.
302
+ SmallVector<Value> getStep(OpBuilder &b) {
303
+ return ValueRange{getStep()};
304
+ }
305
+
306
+ /// Set the lower bounds from `OpFoldResult`.
307
+ void setMixedLowerBounds(OpBuilder &b, ArrayRef<OpFoldResult> lbs) {
308
+ setLowerBound(getValueOrCreateConstantIndexOp(b, getLoc(), lbs[0]));
309
+ }
310
+
311
+ /// Set the upper bounds from `OpFoldResult`.
312
+ void setMixedUpperBounds(OpBuilder &b, ArrayRef<OpFoldResult> ubs) {
313
+ setUpperBound(getValueOrCreateConstantIndexOp(b, getLoc(), ubs[0]));
314
+ }
315
+
316
+ /// Set the steps from `OpFoldResult`.
317
+ void setMixedSteps(OpBuilder &b, ArrayRef<OpFoldResult> steps) {
318
+ setStep(getValueOrCreateConstantIndexOp(b, getLoc(), steps[0]));
319
+ }
320
+
321
+ /// Set the lower bounds from values.
322
+ void setLowerBounds(ArrayRef<Value> lbs) {
323
+ assert(lbs.size() == 1 && "expected a single lower bound");
324
+ setLowerBound(lbs[0]);
325
+ }
326
+
327
+ /// Set the upper bounds from values.
328
+ void setUpperBounds(ArrayRef<Value> ubs) {
329
+ assert(ubs.size() == 1 && "expected a single upper bound");
330
+ setUpperBound(ubs[0]);
331
+ }
332
+
333
+ /// Set the steps from values.
334
+ void setSteps(ArrayRef<Value> steps) {
335
+ assert(steps.size() == 1 && "expected a single step");
336
+ setStep(steps[0]);
337
+ }
338
+
270
339
void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
271
340
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
272
341
void setStep(Value step) { getOperation()->setOperand(2, step); }
@@ -304,6 +373,7 @@ def ForallOp : SCF_Op<"forall", [
304
373
["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar",
305
374
"getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
306
375
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
376
+ LoopLikeWithInductionVarsOpInterface,
307
377
RecursiveMemoryEffects,
308
378
SingleBlockImplicitTerminator<"scf::InParallelOp">,
309
379
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -543,6 +613,33 @@ def ForallOp : SCF_Op<"forall", [
543
613
return getValueOrCreateConstantIndexOp(b, getLoc(), getMixedStep());
544
614
}
545
615
616
+ /// Set the lower bounds from `OpFoldResult`.
617
+ void setMixedLowerBounds(OpBuilder &b, ArrayRef<OpFoldResult> lbs);
618
+
619
+ /// Set the upper bounds from `OpFoldResult`.
620
+ void setMixedUpperBounds(OpBuilder &b, ArrayRef<OpFoldResult> ubs);
621
+
622
+ /// Set the steps from `OpFoldResult`.
623
+ void setMixedSteps(OpBuilder &b, ArrayRef<OpFoldResult> steps);
624
+
625
+ /// Set the lower bounds from values.
626
+ void setLowerBounds(ArrayRef<Value> lbs) {
627
+ OpBuilder b(getOperation()->getContext());
628
+ return setMixedLowerBounds(b, getAsOpFoldResult(lbs));
629
+ }
630
+
631
+ /// Set the upper bounds from values.
632
+ void setUpperBounds(ArrayRef<Value> ubs) {
633
+ OpBuilder b(getOperation()->getContext());
634
+ return setMixedUpperBounds(b, getAsOpFoldResult(ubs));
635
+ }
636
+
637
+ /// Set the steps from values.
638
+ void setSteps(ArrayRef<Value> steps) {
639
+ OpBuilder b(getOperation()->getContext());
640
+ return setMixedSteps(b, getAsOpFoldResult(steps));
641
+ }
642
+
546
643
int64_t getRank() { return getStaticLowerBound().size(); }
547
644
548
645
/// Number of operands controlling the loop: lbs, ubs, steps
0 commit comments