Skip to content

Commit 80182a7

Browse files
committed
[OpenACC][CIR] Implement 'wait' directive lowering
This construct has a couple of 'intexprs' which are lowered the same way as clauses, plus has a pair of simple clauses that needed lowering. This patch does all of that.
1 parent 7d71164 commit 80182a7

File tree

3 files changed

+135
-19
lines changed

3 files changed

+135
-19
lines changed

clang/lib/CIR/CodeGen/CIRGenFunction.h

+3-4
Original file line numberDiff line numberDiff line change
@@ -626,10 +626,9 @@ class CIRGenFunction : public CIRGenTypeCache {
626626
//===--------------------------------------------------------------------===//
627627
private:
628628
template <typename Op>
629-
mlir::LogicalResult
630-
emitOpenACCOp(mlir::Location start, OpenACCDirectiveKind dirKind,
631-
SourceLocation dirLoc,
632-
llvm::ArrayRef<const OpenACCClause *> clauses);
629+
Op emitOpenACCOp(mlir::Location start, OpenACCDirectiveKind dirKind,
630+
SourceLocation dirLoc,
631+
llvm::ArrayRef<const OpenACCClause *> clauses);
633632
// Function to do the basic implementation of an operation with an Associated
634633
// Statement. Models AssociatedStmtConstruct.
635634
template <typename Op, typename TermOp>

clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp

+55-15
Original file line numberDiff line numberDiff line change
@@ -317,10 +317,18 @@ class OpenACCClauseCIREmitter final
317317
operation.getAsyncOperandsDeviceTypeAttr(),
318318
createIntExpr(clause.getIntExpr()), range));
319319
}
320+
} else if constexpr (isOneOfTypes<OpTy, WaitOp>) {
321+
// Wait doesn't have a device_type, so its handling here is slightly
322+
// different.
323+
if (!clause.hasIntExpr())
324+
operation.setAsync(true);
325+
else
326+
operation.getAsyncOperandMutable().append(
327+
createIntExpr(clause.getIntExpr()));
320328
} else {
321329
// TODO: When we've implemented this for everything, switch this to an
322330
// unreachable. Combined constructs remain. Data, enter data, exit data,
323-
// update, wait, combined constructs remain.
331+
// update, combined constructs remain.
324332
return clauseNotImplemented(clause);
325333
}
326334
}
@@ -345,15 +353,15 @@ class OpenACCClauseCIREmitter final
345353

346354
void VisitIfClause(const OpenACCIfClause &clause) {
347355
if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, InitOp,
348-
ShutdownOp, SetOp, DataOp>) {
356+
ShutdownOp, SetOp, DataOp, WaitOp>) {
349357
operation.getIfCondMutable().append(
350358
createCondition(clause.getConditionExpr()));
351359
} else {
352360
// 'if' applies to most of the constructs, but hold off on lowering them
353361
// until we can write tests/know what we're doing with codegen to make
354362
// sure we get it right.
355363
// TODO: When we've implemented this for everything, switch this to an
356-
// unreachable. Enter data, exit data, host_data, update, wait, combined
364+
// unreachable. Enter data, exit data, host_data, update, combined
357365
// constructs remain.
358366
return clauseNotImplemented(clause);
359367
}
@@ -444,11 +452,9 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
444452
}
445453

446454
template <typename Op>
447-
mlir::LogicalResult CIRGenFunction::emitOpenACCOp(
455+
Op CIRGenFunction::emitOpenACCOp(
448456
mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
449457
llvm::ArrayRef<const OpenACCClause *> clauses) {
450-
mlir::LogicalResult res = mlir::success();
451-
452458
llvm::SmallVector<mlir::Type> retTy;
453459
llvm::SmallVector<mlir::Value> operands;
454460
auto op = builder.create<Op>(start, retTy, operands);
@@ -461,7 +467,7 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOp(
461467
makeClauseEmitter(op, *this, builder, dirKind, dirLoc)
462468
.VisitClauseList(clauses);
463469
}
464-
return res;
470+
return op;
465471
}
466472

467473
mlir::LogicalResult
@@ -500,22 +506,61 @@ CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
500506
mlir::LogicalResult
501507
CIRGenFunction::emitOpenACCInitConstruct(const OpenACCInitConstruct &s) {
502508
mlir::Location start = getLoc(s.getSourceRange().getBegin());
503-
return emitOpenACCOp<InitOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
509+
emitOpenACCOp<InitOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
504510
s.clauses());
511+
return mlir::success();
505512
}
506513

507514
mlir::LogicalResult
508515
CIRGenFunction::emitOpenACCSetConstruct(const OpenACCSetConstruct &s) {
509516
mlir::Location start = getLoc(s.getSourceRange().getBegin());
510-
return emitOpenACCOp<SetOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
517+
emitOpenACCOp<SetOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
511518
s.clauses());
519+
return mlir::success();
512520
}
513521

514522
mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct(
515523
const OpenACCShutdownConstruct &s) {
516524
mlir::Location start = getLoc(s.getSourceRange().getBegin());
517-
return emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind(),
525+
emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind(),
518526
s.getDirectiveLoc(), s.clauses());
527+
return mlir::success();
528+
}
529+
530+
mlir::LogicalResult
531+
CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) {
532+
mlir::Location start = getLoc(s.getSourceRange().getBegin());
533+
auto waitOp = emitOpenACCOp<WaitOp>(start, s.getDirectiveKind(),
534+
s.getDirectiveLoc(), s.clauses());
535+
536+
auto createIntExpr = [this](const Expr *intExpr) {
537+
mlir::Value expr = emitScalarExpr(intExpr);
538+
mlir::Location exprLoc = cgm.getLoc(intExpr->getBeginLoc());
539+
540+
mlir::IntegerType targetType = mlir::IntegerType::get(
541+
&getMLIRContext(), getContext().getIntWidth(intExpr->getType()),
542+
intExpr->getType()->isSignedIntegerOrEnumerationType()
543+
? mlir::IntegerType::SignednessSemantics::Signed
544+
: mlir::IntegerType::SignednessSemantics::Unsigned);
545+
546+
auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
547+
exprLoc, targetType, expr);
548+
return conversionOp.getResult(0);
549+
};
550+
551+
// Emit the correct 'wait' clauses.
552+
{
553+
mlir::OpBuilder::InsertionGuard guardCase(builder);
554+
builder.setInsertionPoint(waitOp);
555+
556+
if (s.hasDevNumExpr())
557+
waitOp.getWaitDevnumMutable().append(createIntExpr(s.getDevNumExpr()));
558+
559+
for (Expr *QueueExpr : s.getQueueIdExprs())
560+
waitOp.getWaitOperandsMutable().append(createIntExpr(QueueExpr));
561+
}
562+
563+
return mlir::success();
519564
}
520565

521566
mlir::LogicalResult
@@ -544,11 +589,6 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCHostDataConstruct(
544589
return mlir::failure();
545590
}
546591
mlir::LogicalResult
547-
CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) {
548-
cgm.errorNYI(s.getSourceRange(), "OpenACC Wait Construct");
549-
return mlir::failure();
550-
}
551-
mlir::LogicalResult
552592
CIRGenFunction::emitOpenACCUpdateConstruct(const OpenACCUpdateConstruct &s) {
553593
cgm.errorNYI(s.getSourceRange(), "OpenACC Update Construct");
554594
return mlir::failure();

clang/test/CIR/CodeGenOpenACC/wait.c

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// RUN: %clang_cc1 -fopenacc -emit-cir -fclangir %s -o - | FileCheck %s
2+
3+
void acc_wait(int cond) {
4+
// CHECK: cir.func @acc_wait(%[[ARG:.*]]: !s32i{{.*}}) {
5+
// CHECK-NEXT: %[[COND:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["cond", init]
6+
// CHECK-NEXT: cir.store %[[ARG]], %[[COND]] : !s32i, !cir.ptr<!s32i>
7+
8+
#pragma acc wait
9+
// CHECK-NEXT: acc.wait
10+
11+
#pragma acc wait if (cond)
12+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
13+
// CHECK-NEXT: %[[BOOL_CAST:.*]] = cir.cast(int_to_bool, %[[COND_LOAD]] : !s32i), !cir.bool
14+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOOL_CAST]] : !cir.bool to i1
15+
// CHECK-NEXT: acc.wait if(%[[CONV_CAST]])
16+
17+
#pragma acc wait async
18+
// CHECK-NEXT: acc.wait attributes {async}
19+
20+
#pragma acc wait async(cond)
21+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
22+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
23+
// CHECK-NEXT: acc.wait async(%[[CONV_CAST]] : si32) loc
24+
25+
#pragma acc wait(1)
26+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
27+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
28+
// CHECK-NEXT: acc.wait(%[[ONE_CAST]] : si32) loc
29+
30+
#pragma acc wait(1, 2) async
31+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
32+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
33+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
34+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
35+
// CHECK-NEXT: acc.wait(%[[ONE_CAST]], %[[TWO_CAST]] : si32, si32) attributes {async}
36+
37+
38+
#pragma acc wait(queues:1) if (cond)
39+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
40+
// CHECK-NEXT: %[[BOOL_CAST:.*]] = cir.cast(int_to_bool, %[[COND_LOAD]] : !s32i), !cir.bool
41+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOOL_CAST]] : !cir.bool to i1
42+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
43+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
44+
// CHECK-NEXT: acc.wait(%[[ONE_CAST]] : si32) if(%[[CONV_CAST]])
45+
46+
#pragma acc wait(queues:1, 2) async(cond)
47+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
48+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[COND_LOAD]] : !s32i to si32
49+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
50+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
51+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
52+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
53+
// CHECK-NEXT: acc.wait(%[[ONE_CAST]], %[[TWO_CAST]] : si32, si32) async(%[[CONV_CAST]] : si32) loc
54+
55+
#pragma acc wait(devnum:1: 2, 3) if (cond)
56+
// CHECK-NEXT: %[[COND_LOAD:.*]] = cir.load %[[COND]] : !cir.ptr<!s32i>, !s32i
57+
// CHECK-NEXT: %[[BOOL_CAST:.*]] = cir.cast(int_to_bool, %[[COND_LOAD]] : !s32i), !cir.bool
58+
// CHECK-NEXT: %[[CONV_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOOL_CAST]] : !cir.bool to i1
59+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
60+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
61+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
62+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
63+
// CHECK-NEXT: %[[THREE_LITERAL:.*]] = cir.const #cir.int<3> : !s32i
64+
// CHECK-NEXT: %[[THREE_CAST:.*]] = builtin.unrealized_conversion_cast %[[THREE_LITERAL]] : !s32i to si32
65+
// CHECK-NEXT: acc.wait(%[[TWO_CAST]], %[[THREE_CAST]] : si32, si32) wait_devnum(%[[ONE_CAST]] : si32) if(%[[CONV_CAST]]) loc
66+
67+
#pragma acc wait(devnum:1: queues: 2, 3) async
68+
// CHECK-NEXT: %[[ONE_LITERAL:.*]] = cir.const #cir.int<1> : !s32i
69+
// CHECK-NEXT: %[[ONE_CAST:.*]] = builtin.unrealized_conversion_cast %[[ONE_LITERAL]] : !s32i to si32
70+
// CHECK-NEXT: %[[TWO_LITERAL:.*]] = cir.const #cir.int<2> : !s32i
71+
// CHECK-NEXT: %[[TWO_CAST:.*]] = builtin.unrealized_conversion_cast %[[TWO_LITERAL]] : !s32i to si32
72+
// CHECK-NEXT: %[[THREE_LITERAL:.*]] = cir.const #cir.int<3> : !s32i
73+
// CHECK-NEXT: %[[THREE_CAST:.*]] = builtin.unrealized_conversion_cast %[[THREE_LITERAL]] : !s32i to si32
74+
// CHECK-NEXT: acc.wait(%[[TWO_CAST]], %[[THREE_CAST]] : si32, si32) wait_devnum(%[[ONE_CAST]] : si32) attributes {async}
75+
76+
// CHECK-NEXT: cir.return
77+
}

0 commit comments

Comments
 (0)