Skip to content

Commit c47042c

Browse files
[mlir][SPRIV][NFC] Avoid rollback in TypeCastingOpPattern (#136284)
This pattern used to create an op and then attached the converted rounding mode attribute. When the latter failed, the pattern aborted and a rollback was triggered. This commit inverses the logic: the converted rounding mode is computed first, so that no changes have to be rolled back. Note: This is in preparation of the One-Shot Dialect Conversion refactoring.
1 parent 90de46c commit c47042c

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

Diff for: mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

+12-8
Original file line numberDiff line numberDiff line change
@@ -847,24 +847,28 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
847847
// Then we can just erase this operation by forwarding its operand.
848848
rewriter.replaceOp(op, adaptor.getOperands().front());
849849
} else {
850-
auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
851-
op, dstType, adaptor.getOperands());
850+
// Compute new rounding mode (if any).
851+
std::optional<spirv::FPRoundingMode> rm = std::nullopt;
852852
if (auto roundingModeOp =
853853
dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
854854
if (arith::RoundingModeAttr roundingMode =
855855
roundingModeOp.getRoundingModeAttr()) {
856-
if (auto rm =
857-
convertArithRoundingModeToSPIRV(roundingMode.getValue())) {
858-
newOp->setAttr(
859-
getDecorationString(spirv::Decoration::FPRoundingMode),
860-
spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
861-
} else {
856+
if (!(rm =
857+
convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
862858
return rewriter.notifyMatchFailure(
863859
op->getLoc(),
864860
llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
865861
}
866862
}
867863
}
864+
// Create replacement op and attach rounding mode attribute (if any).
865+
auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
866+
op, dstType, adaptor.getOperands());
867+
if (rm) {
868+
newOp->setAttr(
869+
getDecorationString(spirv::Decoration::FPRoundingMode),
870+
spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
871+
}
868872
}
869873
return success();
870874
}

0 commit comments

Comments
 (0)