Skip to content

Commit e34cc7c

Browse files
committed
[OpenACC] Implement 'wait' construct
The arguments to this are the same as for the 'wait' clause, so this reuses all of that infrastructure. So all this has to do is support a pair of clauses that are already implemented (if and async), plus create an AST node. This patch does so, and adds proper testing.
1 parent e0526b0 commit e34cc7c

31 files changed

+847
-178
lines changed

clang/include/clang-c/Index.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -2186,7 +2186,11 @@ enum CXCursorKind {
21862186
*/
21872187
CXCursor_OpenACCHostDataConstruct = 326,
21882188

2189-
CXCursor_LastStmt = CXCursor_OpenACCHostDataConstruct,
2189+
/** OpenACC wait Construct.
2190+
*/
2191+
CXCursor_OpenACCWaitConstruct = 327,
2192+
2193+
CXCursor_LastStmt = CXCursor_OpenACCWaitConstruct,
21902194

21912195
/**
21922196
* Cursor that represents the translation unit itself.

clang/include/clang/AST/ASTNodeTraverser.h

+11-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ class ASTNodeTraverser
159159

160160
// Some statements have custom mechanisms for dumping their children.
161161
if (isa<DeclStmt>(S) || isa<GenericSelectionExpr>(S) ||
162-
isa<RequiresExpr>(S))
162+
isa<RequiresExpr>(S) || isa<OpenACCWaitConstruct>(S))
163163
return;
164164

165165
if (Traversal == TK_IgnoreUnlessSpelledInSource &&
@@ -825,6 +825,16 @@ class ASTNodeTraverser
825825
Visit(C);
826826
}
827827

828+
void VisitOpenACCWaitConstruct(const OpenACCWaitConstruct *Node) {
829+
// Needs custom child checking to put clauses AFTER the children, which are
830+
// the expressions in the 'wait' construct. Others likely need this as well,
831+
// and might need to do the associated statement after it.
832+
for (const Stmt *S : Node->children())
833+
Visit(S);
834+
for (const auto *C : Node->clauses())
835+
Visit(C);
836+
}
837+
828838
void VisitInitListExpr(const InitListExpr *ILE) {
829839
if (auto *Filler = ILE->getArrayFiller()) {
830840
Visit(Filler, "array_filler");

clang/include/clang/AST/RecursiveASTVisitor.h

+11-2
Original file line numberDiff line numberDiff line change
@@ -4063,10 +4063,19 @@ DEF_TRAVERSE_STMT(OpenACCCombinedConstruct,
40634063
{ TRY_TO(TraverseOpenACCAssociatedStmtConstruct(S)); })
40644064
DEF_TRAVERSE_STMT(OpenACCDataConstruct,
40654065
{ TRY_TO(TraverseOpenACCAssociatedStmtConstruct(S)); })
4066-
DEF_TRAVERSE_STMT(OpenACCEnterDataConstruct, {})
4067-
DEF_TRAVERSE_STMT(OpenACCExitDataConstruct, {})
4066+
DEF_TRAVERSE_STMT(OpenACCEnterDataConstruct,
4067+
{ TRY_TO(VisitOpenACCClauseList(S->clauses())); })
4068+
DEF_TRAVERSE_STMT(OpenACCExitDataConstruct,
4069+
{ TRY_TO(VisitOpenACCClauseList(S->clauses())); })
40684070
DEF_TRAVERSE_STMT(OpenACCHostDataConstruct,
40694071
{ TRY_TO(TraverseOpenACCAssociatedStmtConstruct(S)); })
4072+
DEF_TRAVERSE_STMT(OpenACCWaitConstruct, {
4073+
if (S->hasDevNumExpr())
4074+
TRY_TO(TraverseStmt(S->getDevNumExpr()));
4075+
for (auto *E : S->getQueueIdExprs())
4076+
TRY_TO(TraverseStmt(E));
4077+
TRY_TO(VisitOpenACCClauseList(S->clauses()));
4078+
})
40704079

40714080
// Traverse HLSL: Out argument expression
40724081
DEF_TRAVERSE_STMT(HLSLOutArgExpr, {})

clang/include/clang/AST/StmtOpenACC.h

+123
Original file line numberDiff line numberDiff line change
@@ -469,5 +469,128 @@ class OpenACCHostDataConstruct final
469469
return const_cast<OpenACCHostDataConstruct *>(this)->getStructuredBlock();
470470
}
471471
};
472+
473+
// This class represents a 'wait' construct, which has some expressions plus a
474+
// clause list.
475+
class OpenACCWaitConstruct final
476+
: public OpenACCConstructStmt,
477+
private llvm::TrailingObjects<OpenACCWaitConstruct, Expr *,
478+
OpenACCClause *> {
479+
// FIXME: We should be storing a `const OpenACCClause *` to be consistent with
480+
// the rest of the constructs, but TrailingObjects doesn't allow for mixing
481+
// constness in its implementation of `getTrailingObjects`.
482+
483+
friend TrailingObjects;
484+
friend class ASTStmtWriter;
485+
friend class ASTStmtReader;
486+
// Locations of the left and right parens of the 'wait-argument'
487+
// expression-list.
488+
SourceLocation LParenLoc, RParenLoc;
489+
// Location of the 'queues' keyword, if present.
490+
SourceLocation QueuesLoc;
491+
492+
// Number of the expressions being represented. Index '0' is always the
493+
// 'devnum' expression, even if it not present.
494+
unsigned NumExprs = 0;
495+
496+
OpenACCWaitConstruct(unsigned NumExprs, unsigned NumClauses)
497+
: OpenACCConstructStmt(OpenACCWaitConstructClass,
498+
OpenACCDirectiveKind::Wait, SourceLocation{},
499+
SourceLocation{}, SourceLocation{}),
500+
NumExprs(NumExprs) {
501+
assert(NumExprs >= 1 &&
502+
"NumExprs should always be >= 1 because the 'devnum' "
503+
"expr is represented by a null if necessary");
504+
std::uninitialized_value_construct(getExprPtr(),
505+
getExprPtr() + NumExprs);
506+
std::uninitialized_value_construct(getTrailingObjects<OpenACCClause *>(),
507+
getTrailingObjects<OpenACCClause *>() +
508+
NumClauses);
509+
setClauseList(MutableArrayRef(const_cast<const OpenACCClause **>(
510+
getTrailingObjects<OpenACCClause *>()),
511+
NumClauses));
512+
}
513+
514+
OpenACCWaitConstruct(SourceLocation Start, SourceLocation DirectiveLoc,
515+
SourceLocation LParenLoc, Expr *DevNumExpr,
516+
SourceLocation QueuesLoc, ArrayRef<Expr *> QueueIdExprs,
517+
SourceLocation RParenLoc, SourceLocation End,
518+
ArrayRef<const OpenACCClause *> Clauses)
519+
: OpenACCConstructStmt(OpenACCWaitConstructClass,
520+
OpenACCDirectiveKind::Wait, Start, DirectiveLoc,
521+
End),
522+
LParenLoc(LParenLoc), RParenLoc(RParenLoc), QueuesLoc(QueuesLoc),
523+
NumExprs(QueueIdExprs.size() + 1) {
524+
assert(NumExprs >= 1 &&
525+
"NumExprs should always be >= 1 because the 'devnum' "
526+
"expr is represented by a null if necessary");
527+
528+
std::uninitialized_copy(&DevNumExpr, &DevNumExpr + 1,
529+
getExprPtr());
530+
std::uninitialized_copy(QueueIdExprs.begin(), QueueIdExprs.end(),
531+
getExprPtr() + 1);
532+
533+
std::uninitialized_copy(const_cast<OpenACCClause **>(Clauses.begin()),
534+
const_cast<OpenACCClause **>(Clauses.end()),
535+
getTrailingObjects<OpenACCClause *>());
536+
setClauseList(MutableArrayRef(const_cast<const OpenACCClause **>(
537+
getTrailingObjects<OpenACCClause *>()),
538+
Clauses.size()));
539+
}
540+
541+
size_t numTrailingObjects(OverloadToken<Expr *>) const { return NumExprs; }
542+
size_t numTrailingObjects(OverloadToken<const OpenACCClause *>) const {
543+
return clauses().size();
544+
}
545+
546+
Expr **getExprPtr() const {
547+
return const_cast<Expr**>(getTrailingObjects<Expr *>());
548+
}
549+
550+
llvm::ArrayRef<Expr *> getExprs() const {
551+
return llvm::ArrayRef<Expr *>(getExprPtr(), NumExprs);
552+
}
553+
554+
llvm::ArrayRef<Expr *> getExprs() {
555+
return llvm::ArrayRef<Expr *>(getExprPtr(), NumExprs);
556+
}
557+
558+
public:
559+
static bool classof(const Stmt *T) {
560+
return T->getStmtClass() == OpenACCWaitConstructClass;
561+
}
562+
563+
static OpenACCWaitConstruct *
564+
CreateEmpty(const ASTContext &C, unsigned NumExprs, unsigned NumClauses);
565+
566+
static OpenACCWaitConstruct *
567+
Create(const ASTContext &C, SourceLocation Start, SourceLocation DirectiveLoc,
568+
SourceLocation LParenLoc, Expr *DevNumExpr, SourceLocation QueuesLoc,
569+
ArrayRef<Expr *> QueueIdExprs, SourceLocation RParenLoc,
570+
SourceLocation End, ArrayRef<const OpenACCClause *> Clauses);
571+
572+
SourceLocation getLParenLoc() const { return LParenLoc; }
573+
SourceLocation getRParenLoc() const { return RParenLoc; }
574+
bool hasQueuesTag() const { return !QueuesLoc.isInvalid(); }
575+
SourceLocation getQueuesLoc() const { return QueuesLoc; }
576+
577+
bool hasDevNumExpr() const { return getExprs()[0]; }
578+
Expr *getDevNumExpr() const { return getExprs()[0]; }
579+
llvm::ArrayRef<Expr *> getQueueIdExprs() { return getExprs().drop_front(); }
580+
llvm::ArrayRef<Expr *> getQueueIdExprs() const {
581+
return getExprs().drop_front();
582+
}
583+
584+
child_range children() {
585+
Stmt **Begin = reinterpret_cast<Stmt **>(getExprPtr());
586+
return child_range(Begin, Begin + NumExprs);
587+
}
588+
589+
const_child_range children() const {
590+
Stmt *const *Begin =
591+
reinterpret_cast<Stmt *const *>(getExprPtr());
592+
return const_child_range(Begin, Begin + NumExprs);
593+
}
594+
};
472595
} // namespace clang
473596
#endif // LLVM_CLANG_AST_STMTOPENACC_H

clang/include/clang/AST/TextNodeDumper.h

+1
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,7 @@ class TextNodeDumper
415415
void VisitOpenACCEnterDataConstruct(const OpenACCEnterDataConstruct *S);
416416
void VisitOpenACCExitDataConstruct(const OpenACCExitDataConstruct *S);
417417
void VisitOpenACCHostDataConstruct(const OpenACCHostDataConstruct *S);
418+
void VisitOpenACCWaitConstruct(const OpenACCWaitConstruct *S);
418419
void VisitOpenACCAsteriskSizeExpr(const OpenACCAsteriskSizeExpr *S);
419420
void VisitEmbedExpr(const EmbedExpr *S);
420421
void VisitAtomicExpr(const AtomicExpr *AE);

clang/include/clang/Basic/StmtNodes.td

+1
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def OpenACCDataConstruct : StmtNode<OpenACCAssociatedStmtConstruct>;
312312
def OpenACCEnterDataConstruct : StmtNode<OpenACCConstructStmt>;
313313
def OpenACCExitDataConstruct : StmtNode<OpenACCConstructStmt>;
314314
def OpenACCHostDataConstruct : StmtNode<OpenACCAssociatedStmtConstruct>;
315+
def OpenACCWaitConstruct : StmtNode<OpenACCConstructStmt>;
315316

316317
// OpenACC Additional Expressions.
317318
def OpenACCAsteriskSizeExpr : StmtNode<Expr>;

clang/include/clang/Parse/Parser.h

+13-2
Original file line numberDiff line numberDiff line change
@@ -3706,17 +3706,28 @@ class Parser : public CodeCompletionHandler {
37063706
OpenACCDirectiveKind DirKind;
37073707
SourceLocation StartLoc;
37083708
SourceLocation DirLoc;
3709+
SourceLocation LParenLoc;
3710+
SourceLocation RParenLoc;
37093711
SourceLocation EndLoc;
3712+
SourceLocation MiscLoc;
3713+
SmallVector<Expr *> Exprs;
37103714
SmallVector<OpenACCClause *> Clauses;
3711-
// TODO OpenACC: As we implement support for the Atomic, Routine, Cache, and
3712-
// Wait constructs, we likely want to put that information in here as well.
3715+
// TODO OpenACC: As we implement support for the Atomic, Routine, and Cache
3716+
// constructs, we likely want to put that information in here as well.
37133717
};
37143718

37153719
struct OpenACCWaitParseInfo {
37163720
bool Failed = false;
37173721
Expr *DevNumExpr = nullptr;
37183722
SourceLocation QueuesLoc;
37193723
SmallVector<Expr *> QueueIdExprs;
3724+
3725+
SmallVector<Expr *> getAllExprs() {
3726+
SmallVector<Expr *> Out;
3727+
Out.push_back(DevNumExpr);
3728+
Out.insert(Out.end(), QueueIdExprs.begin(), QueueIdExprs.end());
3729+
return Out;
3730+
}
37203731
};
37213732

37223733
/// Represents the 'error' state of parsing an OpenACC Clause, and stores

clang/include/clang/Sema/SemaOpenACC.h

+12-6
Original file line numberDiff line numberDiff line change
@@ -679,12 +679,18 @@ class SemaOpenACC : public SemaBase {
679679

680680
/// Called after the directive has been completely parsed, including the
681681
/// declaration group or associated statement.
682-
StmtResult ActOnEndStmtDirective(OpenACCDirectiveKind K,
683-
SourceLocation StartLoc,
684-
SourceLocation DirLoc,
685-
SourceLocation EndLoc,
686-
ArrayRef<OpenACCClause *> Clauses,
687-
StmtResult AssocStmt);
682+
/// LParenLoc: Location of the left paren, if it exists (not on all
683+
/// constructs).
684+
/// MiscLoc: First misc location, if necessary (not all constructs).
685+
/// Exprs: List of expressions on the construct itself, if necessary (not all
686+
/// constructs).
687+
/// RParenLoc: Location of the right paren, if it exists (not on all
688+
/// constructs).
689+
StmtResult ActOnEndStmtDirective(
690+
OpenACCDirectiveKind K, SourceLocation StartLoc, SourceLocation DirLoc,
691+
SourceLocation LParenLoc, SourceLocation MiscLoc, ArrayRef<Expr *> Exprs,
692+
SourceLocation RParenLoc, SourceLocation EndLoc,
693+
ArrayRef<OpenACCClause *> Clauses, StmtResult AssocStmt);
688694

689695
/// Called after the directive has been completely parsed, including the
690696
/// declaration group or associated statement.

clang/include/clang/Serialization/ASTBitCodes.h

+1
Original file line numberDiff line numberDiff line change
@@ -2021,6 +2021,7 @@ enum StmtCode {
20212021
STMT_OPENACC_ENTER_DATA_CONSTRUCT,
20222022
STMT_OPENACC_EXIT_DATA_CONSTRUCT,
20232023
STMT_OPENACC_HOST_DATA_CONSTRUCT,
2024+
STMT_OPENACC_WAIT_CONSTRUCT,
20242025

20252026
// HLSL Constructs
20262027
EXPR_HLSL_OUT_ARG,

clang/lib/AST/StmtOpenACC.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,32 @@ OpenACCHostDataConstruct *OpenACCHostDataConstruct::Create(
196196
Clauses, StructuredBlock);
197197
return Inst;
198198
}
199+
200+
OpenACCWaitConstruct *OpenACCWaitConstruct::CreateEmpty(const ASTContext &C,
201+
unsigned NumExprs,
202+
unsigned NumClauses) {
203+
void *Mem = C.Allocate(
204+
OpenACCWaitConstruct::totalSizeToAlloc<Expr *, OpenACCClause *>(
205+
NumExprs, NumClauses));
206+
207+
auto *Inst = new (Mem) OpenACCWaitConstruct(NumExprs, NumClauses);
208+
return Inst;
209+
}
210+
211+
OpenACCWaitConstruct *OpenACCWaitConstruct::Create(
212+
const ASTContext &C, SourceLocation Start, SourceLocation DirectiveLoc,
213+
SourceLocation LParenLoc, Expr *DevNumExpr, SourceLocation QueuesLoc,
214+
ArrayRef<Expr *> QueueIdExprs, SourceLocation RParenLoc, SourceLocation End,
215+
ArrayRef<const OpenACCClause *> Clauses) {
216+
217+
assert(llvm::all_of(QueueIdExprs, [](Expr *E) { return E != nullptr; }));
218+
219+
void *Mem = C.Allocate(
220+
OpenACCWaitConstruct::totalSizeToAlloc<Expr *, OpenACCClause *>(
221+
QueueIdExprs.size() + 1, Clauses.size()));
222+
223+
auto *Inst = new (Mem)
224+
OpenACCWaitConstruct(Start, DirectiveLoc, LParenLoc, DevNumExpr,
225+
QueuesLoc, QueueIdExprs, RParenLoc, End, Clauses);
226+
return Inst;
227+
}

clang/lib/AST/StmtPrinter.cpp

+28
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,34 @@ void StmtPrinter::VisitOpenACCHostDataConstruct(OpenACCHostDataConstruct *S) {
12381238
PrintStmt(S->getStructuredBlock());
12391239
}
12401240

1241+
void StmtPrinter::VisitOpenACCWaitConstruct(OpenACCWaitConstruct *S) {
1242+
Indent() << "#pragma acc wait";
1243+
if (!S->getLParenLoc().isInvalid()) {
1244+
OS << "(";
1245+
if (S->hasDevNumExpr()) {
1246+
OS << "devnum: ";
1247+
S->getDevNumExpr()->printPretty(OS, nullptr, Policy);
1248+
OS << " : ";
1249+
}
1250+
1251+
if (S->hasQueuesTag())
1252+
OS << "queues: ";
1253+
1254+
llvm::interleaveComma(S->getQueueIdExprs(), OS, [&](const Expr *E) {
1255+
E->printPretty(OS, nullptr, Policy);
1256+
});
1257+
1258+
OS << ")";
1259+
}
1260+
1261+
if (!S->clauses().empty()) {
1262+
OS << ' ';
1263+
OpenACCClausePrinter Printer(OS, Policy);
1264+
Printer.VisitClauseList(S->clauses());
1265+
}
1266+
OS << '\n';
1267+
}
1268+
12411269
//===----------------------------------------------------------------------===//
12421270
// Expr printing methods.
12431271
//===----------------------------------------------------------------------===//

clang/lib/AST/StmtProfile.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -2743,6 +2743,14 @@ void StmtProfiler::VisitOpenACCHostDataConstruct(
27432743
P.VisitOpenACCClauseList(S->clauses());
27442744
}
27452745

2746+
void StmtProfiler::VisitOpenACCWaitConstruct(const OpenACCWaitConstruct *S) {
2747+
// VisitStmt covers 'children', so the exprs inside of it are covered.
2748+
VisitStmt(S);
2749+
2750+
OpenACCClauseProfiler P{*this};
2751+
P.VisitOpenACCClauseList(S->clauses());
2752+
}
2753+
27462754
void StmtProfiler::VisitHLSLOutArgExpr(const HLSLOutArgExpr *S) {
27472755
VisitStmt(S);
27482756
}

clang/lib/AST/TextNodeDumper.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -2960,6 +2960,10 @@ void TextNodeDumper::VisitOpenACCHostDataConstruct(
29602960
OS << " " << S->getDirectiveKind();
29612961
}
29622962

2963+
void TextNodeDumper::VisitOpenACCWaitConstruct(const OpenACCWaitConstruct *S) {
2964+
OS << " " << S->getDirectiveKind();
2965+
}
2966+
29632967
void TextNodeDumper::VisitEmbedExpr(const EmbedExpr *S) {
29642968
AddChild("begin", [=] { OS << S->getStartingElementPos(); });
29652969
AddChild("number of elements", [=] { OS << S->getDataElementCount(); });

clang/lib/CodeGen/CGStmt.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,9 @@ void CodeGenFunction::EmitStmt(const Stmt *S, ArrayRef<const Attr *> Attrs) {
470470
case Stmt::OpenACCHostDataConstructClass:
471471
EmitOpenACCHostDataConstruct(cast<OpenACCHostDataConstruct>(*S));
472472
break;
473+
case Stmt::OpenACCWaitConstructClass:
474+
EmitOpenACCWaitConstruct(cast<OpenACCWaitConstruct>(*S));
475+
break;
473476
}
474477
}
475478

clang/lib/CodeGen/CodeGenFunction.h

+5
Original file line numberDiff line numberDiff line change
@@ -4118,6 +4118,11 @@ class CodeGenFunction : public CodeGenTypeCache {
41184118
EmitStmt(S.getStructuredBlock());
41194119
}
41204120

4121+
void EmitOpenACCWaitConstruct(const OpenACCWaitConstruct &S) {
4122+
// TODO OpenACC: Implement this. It is currently implemented as a 'no-op',
4123+
// but in the future we will implement some sort of IR.
4124+
}
4125+
41214126
//===--------------------------------------------------------------------===//
41224127
// LValue Expression Emission
41234128
//===--------------------------------------------------------------------===//

0 commit comments

Comments
 (0)