@@ -469,5 +469,128 @@ class OpenACCHostDataConstruct final
469
469
return const_cast <OpenACCHostDataConstruct *>(this )->getStructuredBlock ();
470
470
}
471
471
};
472
+
473
+ // This class represents a 'wait' construct, which has some expressions plus a
474
+ // clause list.
475
+ class OpenACCWaitConstruct final
476
+ : public OpenACCConstructStmt,
477
+ private llvm::TrailingObjects<OpenACCWaitConstruct, Expr *,
478
+ OpenACCClause *> {
479
+ // FIXME: We should be storing a `const OpenACCClause *` to be consistent with
480
+ // the rest of the constructs, but TrailingObjects doesn't allow for mixing
481
+ // constness in its implementation of `getTrailingObjects`.
482
+
483
+ friend TrailingObjects;
484
+ friend class ASTStmtWriter ;
485
+ friend class ASTStmtReader ;
486
+ // Locations of the left and right parens of the 'wait-argument'
487
+ // expression-list.
488
+ SourceLocation LParenLoc, RParenLoc;
489
+ // Location of the 'queues' keyword, if present.
490
+ SourceLocation QueuesLoc;
491
+
492
+ // Number of the expressions being represented. Index '0' is always the
493
+ // 'devnum' expression, even if it not present.
494
+ unsigned NumExprs = 0 ;
495
+
496
+ OpenACCWaitConstruct (unsigned NumExprs, unsigned NumClauses)
497
+ : OpenACCConstructStmt(OpenACCWaitConstructClass,
498
+ OpenACCDirectiveKind::Wait, SourceLocation{},
499
+ SourceLocation{}, SourceLocation{}),
500
+ NumExprs (NumExprs) {
501
+ assert (NumExprs >= 1 &&
502
+ " NumExprs should always be >= 1 because the 'devnum' "
503
+ " expr is represented by a null if necessary" );
504
+ std::uninitialized_value_construct (getExprPtr (),
505
+ getExprPtr () + NumExprs);
506
+ std::uninitialized_value_construct (getTrailingObjects<OpenACCClause *>(),
507
+ getTrailingObjects<OpenACCClause *>() +
508
+ NumClauses);
509
+ setClauseList (MutableArrayRef (const_cast <const OpenACCClause **>(
510
+ getTrailingObjects<OpenACCClause *>()),
511
+ NumClauses));
512
+ }
513
+
514
+ OpenACCWaitConstruct (SourceLocation Start, SourceLocation DirectiveLoc,
515
+ SourceLocation LParenLoc, Expr *DevNumExpr,
516
+ SourceLocation QueuesLoc, ArrayRef<Expr *> QueueIdExprs,
517
+ SourceLocation RParenLoc, SourceLocation End,
518
+ ArrayRef<const OpenACCClause *> Clauses)
519
+ : OpenACCConstructStmt(OpenACCWaitConstructClass,
520
+ OpenACCDirectiveKind::Wait, Start, DirectiveLoc,
521
+ End),
522
+ LParenLoc (LParenLoc), RParenLoc(RParenLoc), QueuesLoc(QueuesLoc),
523
+ NumExprs(QueueIdExprs.size() + 1) {
524
+ assert (NumExprs >= 1 &&
525
+ " NumExprs should always be >= 1 because the 'devnum' "
526
+ " expr is represented by a null if necessary" );
527
+
528
+ std::uninitialized_copy (&DevNumExpr, &DevNumExpr + 1 ,
529
+ getExprPtr ());
530
+ std::uninitialized_copy (QueueIdExprs.begin (), QueueIdExprs.end (),
531
+ getExprPtr () + 1 );
532
+
533
+ std::uninitialized_copy (const_cast <OpenACCClause **>(Clauses.begin ()),
534
+ const_cast <OpenACCClause **>(Clauses.end ()),
535
+ getTrailingObjects<OpenACCClause *>());
536
+ setClauseList (MutableArrayRef (const_cast <const OpenACCClause **>(
537
+ getTrailingObjects<OpenACCClause *>()),
538
+ Clauses.size ()));
539
+ }
540
+
541
+ size_t numTrailingObjects (OverloadToken<Expr *>) const { return NumExprs; }
542
+ size_t numTrailingObjects (OverloadToken<const OpenACCClause *>) const {
543
+ return clauses ().size ();
544
+ }
545
+
546
+ Expr **getExprPtr () const {
547
+ return const_cast <Expr**>(getTrailingObjects<Expr *>());
548
+ }
549
+
550
+ llvm::ArrayRef<Expr *> getExprs () const {
551
+ return llvm::ArrayRef<Expr *>(getExprPtr (), NumExprs);
552
+ }
553
+
554
+ llvm::ArrayRef<Expr *> getExprs () {
555
+ return llvm::ArrayRef<Expr *>(getExprPtr (), NumExprs);
556
+ }
557
+
558
+ public:
559
+ static bool classof (const Stmt *T) {
560
+ return T->getStmtClass () == OpenACCWaitConstructClass;
561
+ }
562
+
563
+ static OpenACCWaitConstruct *
564
+ CreateEmpty (const ASTContext &C, unsigned NumExprs, unsigned NumClauses);
565
+
566
+ static OpenACCWaitConstruct *
567
+ Create (const ASTContext &C, SourceLocation Start, SourceLocation DirectiveLoc,
568
+ SourceLocation LParenLoc, Expr *DevNumExpr, SourceLocation QueuesLoc,
569
+ ArrayRef<Expr *> QueueIdExprs, SourceLocation RParenLoc,
570
+ SourceLocation End, ArrayRef<const OpenACCClause *> Clauses);
571
+
572
+ SourceLocation getLParenLoc () const { return LParenLoc; }
573
+ SourceLocation getRParenLoc () const { return RParenLoc; }
574
+ bool hasQueuesTag () const { return !QueuesLoc.isInvalid (); }
575
+ SourceLocation getQueuesLoc () const { return QueuesLoc; }
576
+
577
+ bool hasDevNumExpr () const { return getExprs ()[0 ]; }
578
+ Expr *getDevNumExpr () const { return getExprs ()[0 ]; }
579
+ llvm::ArrayRef<Expr *> getQueueIdExprs () { return getExprs ().drop_front (); }
580
+ llvm::ArrayRef<Expr *> getQueueIdExprs () const {
581
+ return getExprs ().drop_front ();
582
+ }
583
+
584
+ child_range children () {
585
+ Stmt **Begin = reinterpret_cast <Stmt **>(getExprPtr ());
586
+ return child_range (Begin, Begin + NumExprs);
587
+ }
588
+
589
+ const_child_range children () const {
590
+ Stmt *const *Begin =
591
+ reinterpret_cast <Stmt *const *>(getExprPtr ());
592
+ return const_child_range (Begin, Begin + NumExprs);
593
+ }
594
+ };
472
595
} // namespace clang
473
596
#endif // LLVM_CLANG_AST_STMTOPENACC_H
0 commit comments