Skip to content

Commit a51f9e4

Browse files
committed
[mlir][OpenMP] convert wsloop cancellation to LLVMIR
Taskloop support will follow in a later patch.
1 parent 7b70fc7 commit a51f9e4

File tree

3 files changed

+125
-18
lines changed

3 files changed

+125
-18
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

+38-2
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
161161
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162162
omp::ClauseCancellationConstructType cancelledDirective =
163163
op.getCancelDirective();
164-
if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel &&
165-
cancelledDirective != omp::ClauseCancellationConstructType::Sections)
164+
if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
166165
result = todo("cancel directive construct type not yet supported");
167166
};
168167
auto checkDepend = [&todo](auto op, LogicalResult &result) {
@@ -2358,6 +2357,30 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
23582357
? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
23592358
: llvm::omp::WorksharingLoopType::ForStaticLoop;
23602359

2360+
SmallVector<llvm::BranchInst *> cancelTerminators;
2361+
// This callback is invoked only if there is cancellation inside of the wsloop
2362+
// body.
2363+
auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2364+
llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder;
2365+
llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2366+
2367+
// ip is currently in the block branched to if cancellation occurred.
2368+
// We need to create a branch to terminate that block.
2369+
llvmBuilder.restoreIP(ip);
2370+
2371+
// We must still clean up the wsloop after cancelling it, so we need to
2372+
// branch to the block that finalizes the wsloop.
2373+
// That block has not been created yet so use this block as a dummy for now
2374+
// and fix this after creating the wsloop.
2375+
cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2376+
return llvm::Error::success();
2377+
};
2378+
// We have to add the cleanup to the OpenMPIRBuilder before the body gets
2379+
// created in case the body contains omp.cancel (which will then expect to be
2380+
// able to find this cleanup callback).
2381+
ompBuilder->pushFinalizationCB({finiCB, llvm::omp::Directive::OMPD_for,
2382+
constructIsCancellable(wsloopOp)});
2383+
23612384
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
23622385
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
23632386
wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
@@ -2379,6 +2402,19 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
23792402
if (failed(handleError(wsloopIP, opInst)))
23802403
return failure();
23812404

2405+
ompBuilder->popFinalizationCB();
2406+
if (!cancelTerminators.empty()) {
2407+
// If we cancelled the loop, we should branch to the finalization block of
2408+
// the wsloop (which is always immediately before the loop continuation
2409+
// block). Now the finalization has been created, we can fix the branch.
2410+
llvm::BasicBlock *wsloopFini = wsloopIP->getBlock()->getSinglePredecessor();
2411+
for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2412+
assert(cancelBranch->getNumSuccessors() == 1 &&
2413+
"cancel branch should have one target");
2414+
cancelBranch->setSuccessor(0, wsloopFini);
2415+
}
2416+
}
2417+
23822418
// Process the reductions if required.
23832419
if (failed(createReductionsAndCleanup(
23842420
wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,

mlir/test/Target/LLVMIR/openmp-cancel.mlir

+87
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,90 @@ llvm.func @cancel_sections_if(%cond : i1) {
156156
// CHECK: ret void
157157
// CHECK: .cncl: ; preds = %[[VAL_27]]
158158
// CHECK: br label %[[VAL_19]]
159+
160+
llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
161+
omp.wsloop {
162+
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
163+
omp.cancel cancellation_construct_type(loop) if(%cond)
164+
omp.yield
165+
}
166+
}
167+
llvm.return
168+
}
169+
// CHECK-LABEL: define void @cancel_wsloop_if
170+
// CHECK: %[[VAL_0:.*]] = alloca i32, align 4
171+
// CHECK: %[[VAL_1:.*]] = alloca i32, align 4
172+
// CHECK: %[[VAL_2:.*]] = alloca i32, align 4
173+
// CHECK: %[[VAL_3:.*]] = alloca i32, align 4
174+
// CHECK: br label %[[VAL_4:.*]]
175+
// CHECK: omp.region.after_alloca: ; preds = %[[VAL_5:.*]]
176+
// CHECK: br label %[[VAL_6:.*]]
177+
// CHECK: entry: ; preds = %[[VAL_4]]
178+
// CHECK: br label %[[VAL_7:.*]]
179+
// CHECK: omp.wsloop.region: ; preds = %[[VAL_6]]
180+
// CHECK: %[[VAL_8:.*]] = icmp slt i32 %[[VAL_9:.*]], 0
181+
// CHECK: %[[VAL_10:.*]] = sub i32 0, %[[VAL_9]]
182+
// CHECK: %[[VAL_11:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_10]], i32 %[[VAL_9]]
183+
// CHECK: %[[VAL_12:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_13:.*]], i32 %[[VAL_14:.*]]
184+
// CHECK: %[[VAL_15:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_14]], i32 %[[VAL_13]]
185+
// CHECK: %[[VAL_16:.*]] = sub nsw i32 %[[VAL_15]], %[[VAL_12]]
186+
// CHECK: %[[VAL_17:.*]] = icmp sle i32 %[[VAL_15]], %[[VAL_12]]
187+
// CHECK: %[[VAL_18:.*]] = sub i32 %[[VAL_16]], 1
188+
// CHECK: %[[VAL_19:.*]] = udiv i32 %[[VAL_18]], %[[VAL_11]]
189+
// CHECK: %[[VAL_20:.*]] = add i32 %[[VAL_19]], 1
190+
// CHECK: %[[VAL_21:.*]] = icmp ule i32 %[[VAL_16]], %[[VAL_11]]
191+
// CHECK: %[[VAL_22:.*]] = select i1 %[[VAL_21]], i32 1, i32 %[[VAL_20]]
192+
// CHECK: %[[VAL_23:.*]] = select i1 %[[VAL_17]], i32 0, i32 %[[VAL_22]]
193+
// CHECK: br label %[[VAL_24:.*]]
194+
// CHECK: omp_loop.preheader: ; preds = %[[VAL_7]]
195+
// CHECK: store i32 0, ptr %[[VAL_1]], align 4
196+
// CHECK: %[[VAL_25:.*]] = sub i32 %[[VAL_23]], 1
197+
// CHECK: store i32 %[[VAL_25]], ptr %[[VAL_2]], align 4
198+
// CHECK: store i32 1, ptr %[[VAL_3]], align 4
199+
// CHECK: %[[VAL_26:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
200+
// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_26]], i32 34, ptr %[[VAL_0]], ptr %[[VAL_1]], ptr %[[VAL_2]], ptr %[[VAL_3]], i32 1, i32 0)
201+
// CHECK: %[[VAL_27:.*]] = load i32, ptr %[[VAL_1]], align 4
202+
// CHECK: %[[VAL_28:.*]] = load i32, ptr %[[VAL_2]], align 4
203+
// CHECK: %[[VAL_29:.*]] = sub i32 %[[VAL_28]], %[[VAL_27]]
204+
// CHECK: %[[VAL_30:.*]] = add i32 %[[VAL_29]], 1
205+
// CHECK: br label %[[VAL_31:.*]]
206+
// CHECK: omp_loop.header: ; preds = %[[VAL_32:.*]], %[[VAL_24]]
207+
// CHECK: %[[VAL_33:.*]] = phi i32 [ 0, %[[VAL_24]] ], [ %[[VAL_34:.*]], %[[VAL_32]] ]
208+
// CHECK: br label %[[VAL_35:.*]]
209+
// CHECK: omp_loop.cond: ; preds = %[[VAL_31]]
210+
// CHECK: %[[VAL_36:.*]] = icmp ult i32 %[[VAL_33]], %[[VAL_30]]
211+
// CHECK: br i1 %[[VAL_36]], label %[[VAL_37:.*]], label %[[VAL_38:.*]]
212+
// CHECK: omp_loop.body: ; preds = %[[VAL_35]]
213+
// CHECK: %[[VAL_39:.*]] = add i32 %[[VAL_33]], %[[VAL_27]]
214+
// CHECK: %[[VAL_40:.*]] = mul i32 %[[VAL_39]], %[[VAL_9]]
215+
// CHECK: %[[VAL_41:.*]] = add i32 %[[VAL_40]], %[[VAL_14]]
216+
// CHECK: br label %[[VAL_42:.*]]
217+
// CHECK: omp.loop_nest.region: ; preds = %[[VAL_37]]
218+
// CHECK: br i1 %[[VAL_43:.*]], label %[[VAL_44:.*]], label %[[VAL_45:.*]]
219+
// CHECK: 25: ; preds = %[[VAL_42]]
220+
// CHECK: %[[VAL_46:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
221+
// CHECK: %[[VAL_47:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_46]], i32 2)
222+
// CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_47]], 0
223+
// CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_50:.*]]
224+
// CHECK: .split: ; preds = %[[VAL_44]]
225+
// CHECK: br label %[[VAL_51:.*]]
226+
// CHECK: 28: ; preds = %[[VAL_42]]
227+
// CHECK: br label %[[VAL_51]]
228+
// CHECK: 29: ; preds = %[[VAL_45]], %[[VAL_49]]
229+
// CHECK: br label %[[VAL_52:.*]]
230+
// CHECK: omp.region.cont1: ; preds = %[[VAL_51]]
231+
// CHECK: br label %[[VAL_32]]
232+
// CHECK: omp_loop.inc: ; preds = %[[VAL_52]]
233+
// CHECK: %[[VAL_34]] = add nuw i32 %[[VAL_33]], 1
234+
// CHECK: br label %[[VAL_31]]
235+
// CHECK: omp_loop.exit: ; preds = %[[VAL_50]], %[[VAL_35]]
236+
// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_26]])
237+
// CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
238+
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_53]])
239+
// CHECK: br label %[[VAL_54:.*]]
240+
// CHECK: omp_loop.after: ; preds = %[[VAL_38]]
241+
// CHECK: br label %[[VAL_55:.*]]
242+
// CHECK: omp.region.cont: ; preds = %[[VAL_54]]
243+
// CHECK: ret void
244+
// CHECK: .cncl: ; preds = %[[VAL_44]]
245+
// CHECK: br label %[[VAL_38]]

mlir/test/Target/LLVMIR/openmp-todo.mlir

-16
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,6 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {
2626

2727
// -----
2828

29-
llvm.func @cancel_wsloop(%lb : i32, %ub : i32, %step: i32) {
30-
// expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
31-
omp.wsloop {
32-
// expected-error@below {{LLVM Translation failed for operation: omp.loop_nest}}
33-
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
34-
// expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancel operation}}
35-
// expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
36-
omp.cancel cancellation_construct_type(loop)
37-
omp.yield
38-
}
39-
}
40-
llvm.return
41-
}
42-
43-
// -----
44-
4529
llvm.func @cancel_taskgroup() {
4630
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
4731
omp.taskgroup {

0 commit comments

Comments
 (0)