@@ -317,10 +317,18 @@ class OpenACCClauseCIREmitter final
317
317
operation.getAsyncOperandsDeviceTypeAttr (),
318
318
createIntExpr (clause.getIntExpr ()), range));
319
319
}
320
+ } else if constexpr (isOneOfTypes<OpTy, WaitOp>) {
321
+ // Wait doesn't have a device_type, so its handling here is slightly
322
+ // different.
323
+ if (!clause.hasIntExpr ())
324
+ operation.setAsync (true );
325
+ else
326
+ operation.getAsyncOperandMutable ().append (
327
+ createIntExpr (clause.getIntExpr ()));
320
328
} else {
321
329
// TODO: When we've implemented this for everything, switch this to an
322
330
// unreachable. Combined constructs remain. Data, enter data, exit data,
323
- // update, wait, combined constructs remain.
331
+ // update, combined constructs remain.
324
332
return clauseNotImplemented (clause);
325
333
}
326
334
}
@@ -345,15 +353,15 @@ class OpenACCClauseCIREmitter final
345
353
346
354
void VisitIfClause (const OpenACCIfClause &clause) {
347
355
if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, InitOp,
348
- ShutdownOp, SetOp, DataOp>) {
356
+ ShutdownOp, SetOp, DataOp, WaitOp >) {
349
357
operation.getIfCondMutable ().append (
350
358
createCondition (clause.getConditionExpr ()));
351
359
} else {
352
360
// 'if' applies to most of the constructs, but hold off on lowering them
353
361
// until we can write tests/know what we're doing with codegen to make
354
362
// sure we get it right.
355
363
// TODO: When we've implemented this for everything, switch this to an
356
- // unreachable. Enter data, exit data, host_data, update, wait, combined
364
+ // unreachable. Enter data, exit data, host_data, update, combined
357
365
// constructs remain.
358
366
return clauseNotImplemented (clause);
359
367
}
@@ -444,11 +452,9 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
444
452
}
445
453
446
454
template <typename Op>
447
- mlir::LogicalResult CIRGenFunction::emitOpenACCOp (
455
+ Op CIRGenFunction::emitOpenACCOp (
448
456
mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
449
457
llvm::ArrayRef<const OpenACCClause *> clauses) {
450
- mlir::LogicalResult res = mlir::success ();
451
-
452
458
llvm::SmallVector<mlir::Type> retTy;
453
459
llvm::SmallVector<mlir::Value> operands;
454
460
auto op = builder.create <Op>(start, retTy, operands);
@@ -461,7 +467,7 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCOp(
461
467
makeClauseEmitter (op, *this , builder, dirKind, dirLoc)
462
468
.VisitClauseList (clauses);
463
469
}
464
- return res ;
470
+ return op ;
465
471
}
466
472
467
473
mlir::LogicalResult
@@ -500,22 +506,61 @@ CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
500
506
mlir::LogicalResult
501
507
CIRGenFunction::emitOpenACCInitConstruct (const OpenACCInitConstruct &s) {
502
508
mlir::Location start = getLoc (s.getSourceRange ().getBegin ());
503
- return emitOpenACCOp<InitOp>(start, s.getDirectiveKind (), s.getDirectiveLoc (),
509
+ emitOpenACCOp<InitOp>(start, s.getDirectiveKind (), s.getDirectiveLoc (),
504
510
s.clauses ());
511
+ return mlir::success ();
505
512
}
506
513
507
514
mlir::LogicalResult
508
515
CIRGenFunction::emitOpenACCSetConstruct (const OpenACCSetConstruct &s) {
509
516
mlir::Location start = getLoc (s.getSourceRange ().getBegin ());
510
- return emitOpenACCOp<SetOp>(start, s.getDirectiveKind (), s.getDirectiveLoc (),
517
+ emitOpenACCOp<SetOp>(start, s.getDirectiveKind (), s.getDirectiveLoc (),
511
518
s.clauses ());
519
+ return mlir::success ();
512
520
}
513
521
514
522
mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct (
515
523
const OpenACCShutdownConstruct &s) {
516
524
mlir::Location start = getLoc (s.getSourceRange ().getBegin ());
517
- return emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind (),
525
+ emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind (),
518
526
s.getDirectiveLoc (), s.clauses ());
527
+ return mlir::success ();
528
+ }
529
+
530
+ mlir::LogicalResult
531
+ CIRGenFunction::emitOpenACCWaitConstruct (const OpenACCWaitConstruct &s) {
532
+ mlir::Location start = getLoc (s.getSourceRange ().getBegin ());
533
+ auto waitOp = emitOpenACCOp<WaitOp>(start, s.getDirectiveKind (),
534
+ s.getDirectiveLoc (), s.clauses ());
535
+
536
+ auto createIntExpr = [this ](const Expr *intExpr) {
537
+ mlir::Value expr = emitScalarExpr (intExpr);
538
+ mlir::Location exprLoc = cgm.getLoc (intExpr->getBeginLoc ());
539
+
540
+ mlir::IntegerType targetType = mlir::IntegerType::get (
541
+ &getMLIRContext (), getContext ().getIntWidth (intExpr->getType ()),
542
+ intExpr->getType ()->isSignedIntegerOrEnumerationType ()
543
+ ? mlir::IntegerType::SignednessSemantics::Signed
544
+ : mlir::IntegerType::SignednessSemantics::Unsigned);
545
+
546
+ auto conversionOp = builder.create <mlir::UnrealizedConversionCastOp>(
547
+ exprLoc, targetType, expr);
548
+ return conversionOp.getResult (0 );
549
+ };
550
+
551
+ // Emit the correct 'wait' clauses.
552
+ {
553
+ mlir::OpBuilder::InsertionGuard guardCase (builder);
554
+ builder.setInsertionPoint (waitOp);
555
+
556
+ if (s.hasDevNumExpr ())
557
+ waitOp.getWaitDevnumMutable ().append (createIntExpr (s.getDevNumExpr ()));
558
+
559
+ for (Expr *QueueExpr : s.getQueueIdExprs ())
560
+ waitOp.getWaitOperandsMutable ().append (createIntExpr (QueueExpr));
561
+ }
562
+
563
+ return mlir::success ();
519
564
}
520
565
521
566
mlir::LogicalResult
@@ -544,11 +589,6 @@ mlir::LogicalResult CIRGenFunction::emitOpenACCHostDataConstruct(
544
589
return mlir::failure ();
545
590
}
546
591
mlir::LogicalResult
547
- CIRGenFunction::emitOpenACCWaitConstruct (const OpenACCWaitConstruct &s) {
548
- cgm.errorNYI (s.getSourceRange (), " OpenACC Wait Construct" );
549
- return mlir::failure ();
550
- }
551
- mlir::LogicalResult
552
592
CIRGenFunction::emitOpenACCUpdateConstruct (const OpenACCUpdateConstruct &s) {
553
593
cgm.errorNYI (s.getSourceRange (), " OpenACC Update Construct" );
554
594
return mlir::failure ();
0 commit comments