Skip to content

Commit 4b56345

Browse files
authored
[mlir][SCF] Unify tileUsingFor and tileReductionUsingFor implementation (#120115)
This patch unifies the tiling implementation for tileUsingFor and tileReductionUsingFor. This is done by passing an addition option to SCFTilingOptions, allowing it to set how reduction dimensions should be tiled. Currently, there are 3 different options for reduction tiling: FullReduction (old tileUsingFor), PartialReductionOuterReduction (old tileReductionUsingFor) and PartialReductionOuterParallel (linalg::tileReductionUsingForall, this isn't implemented in this patch). The patch makes tileReductionUsingFor use the tileUsingFor implementation with the new reduction tiling options. There are no test changes because the implementation was doing almost the exactly same thing. This was also tested in IREE (which uses both these APIs heavily) and there were no test changes.
1 parent e7303fe commit 4b56345

File tree

4 files changed

+305
-225
lines changed

4 files changed

+305
-225
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

+39-18
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,36 @@ struct SCFTilingOptions {
8585
return *this;
8686
}
8787

88+
/// Specify how reduction dimensions should be tiled.
89+
///
90+
/// Tiling can be thought of as splitting a dimension into 2 and materializing
91+
/// the outer dimension as a loop:
92+
///
93+
/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
94+
///
95+
/// For parallel dimensions, the split can only happen in one way, with both
96+
/// dimensions being parallel. For reduction dimensions however, there is a
97+
/// choice in how we split the reduction dimension. This enum exposes this
98+
/// choice.
99+
enum class ReductionTilingStrategy {
100+
// [reduction] -> [reduction1, reduction2]
101+
// -> loop[reduction1] { [reduction2] }
102+
FullReduction,
103+
// [reduction] -> [reduction1, parallel2]
104+
// -> loop[reduction1] { [parallel2] }; merge[reduction1]
105+
PartialReductionOuterReduction,
106+
// [reduction] -> [parallel1, reduction2]
107+
// -> loop[parallel1] { [reduction2] }; merge[parallel1]
108+
PartialReductionOuterParallel
109+
};
110+
ReductionTilingStrategy reductionStrategy =
111+
ReductionTilingStrategy::FullReduction;
112+
SCFTilingOptions &
113+
setReductionTilingStrategy(ReductionTilingStrategy strategy) {
114+
reductionStrategy = strategy;
115+
return *this;
116+
}
117+
88118
/// Specify mapping of loops to devices. This is only respected when the loop
89119
/// constructs support such a mapping (like `scf.forall`). Will be ignored
90120
/// when using loop constructs that dont support such a mapping (like
@@ -102,11 +132,16 @@ struct SCFTilingResult {
102132
/// matter except the last op. The replacements are expected to be the results
103133
/// of the last op.
104134
SmallVector<Operation *> tiledOps;
135+
/// The initial destination values passed to the tiled operations.
136+
SmallVector<Value> initialValues;
105137
/// The `scf.for` operations that iterate over the tiles.
106138
SmallVector<LoopLikeOpInterface> loops;
107-
/// Values to use as replacements for the untiled op. Is the same size as the
108-
/// number of results of the untiled op.
109-
SmallVector<Value> replacements;
139+
/// The result generated by the loop nest in tiling, may hold partial results,
140+
/// which need to be merged to match the computation of the untiled operation.
141+
/// `mergeResult` contains the operations used to perform this merge from
142+
/// partial results and the values that can be used as replacements of
143+
/// the untiled operation.
144+
MergeResult mergeResult;
110145
/// Slices generated after tiling that can be used for fusing with the tiled
111146
/// producer.
112147
SmallVector<Operation *> generatedSlices;
@@ -300,20 +335,6 @@ tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
300335
FailureOr<SmallVector<scf::ForOp>>
301336
lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
302337

303-
/// Transformation information returned after reduction tiling.
304-
struct SCFReductionTilingResult {
305-
/// The partial reduction tiled op generated.
306-
SmallVector<Operation *> parallelTiledOps;
307-
/// The final reduction operation merging all the partial reductions.
308-
SmallVector<Operation *> mergeOps;
309-
/// Initial values used for reduction.
310-
SmallVector<Value> initialValues;
311-
/// The loop operations that iterate over the tiles.
312-
SmallVector<LoopLikeOpInterface> loops;
313-
/// The replacements to use for the results of the tiled operation.
314-
SmallVector<Value> replacements;
315-
};
316-
317338
/// Method to tile a reduction and generate a parallel op within a serial loop.
318339
/// Each of the partial reductions are calculated in parallel. Then after the
319340
/// loop all the partial reduction are merged into a final reduction.
@@ -338,7 +359,7 @@ struct SCFReductionTilingResult {
338359
/// %6 = linalg.generic %1 ["parallel", "reduction"]
339360
/// : tensor<7x4xf32> -> tensor<7xf32>
340361
/// ```
341-
FailureOr<scf::SCFReductionTilingResult>
362+
FailureOr<scf::SCFTilingResult>
342363
tileReductionUsingScf(RewriterBase &b, PartialReductionOpInterface op,
343364
ArrayRef<OpFoldResult> tileSize);
344365

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

+7-6
Original file line numberDiff line numberDiff line change
@@ -2223,7 +2223,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
22232223
return emitDefaultDefiniteFailure(target);
22242224

22252225
if (target->getNumResults())
2226-
rewriter.replaceOp(target, maybeTilingResult->replacements);
2226+
rewriter.replaceOp(target, maybeTilingResult->mergeResult.replacements);
22272227
else
22282228
rewriter.eraseOp(target);
22292229

@@ -2630,17 +2630,18 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
26302630
transform::ApplyToEachResultList &results,
26312631
transform::TransformState &state) {
26322632
rewriter.setInsertionPoint(target);
2633-
FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
2633+
FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
26342634
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
26352635
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
26362636

26372637
if (failed(result))
26382638
return emitDefaultSilenceableFailure(target);
2639+
rewriter.replaceOp(target, result->mergeResult.replacements);
26392640
for (Value initValue : result->initialValues)
26402641
results.push_back(initValue.getDefiningOp());
2641-
for (auto parallelTiledOp : result->parallelTiledOps)
2642+
for (auto parallelTiledOp : result->tiledOps)
26422643
results.push_back(parallelTiledOp);
2643-
for (auto mergeOp : result->mergeOps)
2644+
for (auto mergeOp : result->mergeResult.mergeOps)
26442645
results.push_back(mergeOp);
26452646
results.push_back(result->loops.front());
26462647
return DiagnosedSilenceableFailure::success();
@@ -3064,7 +3065,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
30643065
if (failed(maybeTilingResult))
30653066
return DiagnosedSilenceableFailure::definiteFailure();
30663067

3067-
rewriter.replaceOp(op, maybeTilingResult->replacements);
3068+
rewriter.replaceOp(op, maybeTilingResult->mergeResult.replacements);
30683069

30693070
tiled.append(maybeTilingResult->tiledOps);
30703071
for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
@@ -3303,7 +3304,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
33033304
if (failed(maybeTilingResult))
33043305
return transformOp.emitDefaultSilenceableFailure(tileableOp);
33053306

3306-
rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
3307+
rewriter.replaceOp(tileableOp, maybeTilingResult->mergeResult.replacements);
33073308

33083309
tilingResult = *maybeTilingResult;
33093310

0 commit comments

Comments
 (0)