@@ -3495,9 +3495,9 @@ checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3495
3495
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
3496
3496
const LocationDescription &Loc, InsertPointTy AllocaIP,
3497
3497
InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
3498
- bool IsNoWait, bool IsTeamsReduction, bool HasDistribute ,
3499
- ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
3500
- unsigned ReductionBufNum, Value *SrcLocInfo) {
3498
+ bool IsNoWait, bool IsTeamsReduction, ReductionGenCBKind ReductionGenCBKind ,
3499
+ std::optional<omp::GV> GridValue, unsigned ReductionBufNum ,
3500
+ Value *SrcLocInfo) {
3501
3501
if (!updateToLocation(Loc))
3502
3502
return InsertPointTy();
3503
3503
Builder.restoreIP(CodeGenIP);
@@ -3514,6 +3514,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
3514
3514
if (ReductionInfos.size() == 0)
3515
3515
return Builder.saveIP();
3516
3516
3517
+ BasicBlock *ContinuationBlock = nullptr;
3518
+ if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
3519
+ // Copied code from createReductions
3520
+ BasicBlock *InsertBlock = Loc.IP.getBlock();
3521
+ ContinuationBlock =
3522
+ InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
3523
+ InsertBlock->getTerminator()->eraseFromParent();
3524
+ Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
3525
+ }
3526
+
3517
3527
Function *CurFunc = Builder.GetInsertBlock()->getParent();
3518
3528
AttributeList FuncAttrs;
3519
3529
AttrBuilder AttrBldr(Ctx);
@@ -3669,11 +3679,21 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
3669
3679
ReductionFunc;
3670
3680
});
3671
3681
} else {
3672
- assert(false && "Unhandled ReductionGenCBKind");
3682
+ Value *LHSValue = Builder.CreateLoad(RI.ElementType, LHS, "final.lhs");
3683
+ Value *RHSValue = Builder.CreateLoad(RI.ElementType, RHS, "final.rhs");
3684
+ Value *Reduced;
3685
+ InsertPointOrErrorTy AfterIP =
3686
+ RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced);
3687
+ if (!AfterIP)
3688
+ return AfterIP.takeError();
3689
+ Builder.CreateStore(Reduced, LHS, false);
3673
3690
}
3674
3691
}
3675
3692
emitBlock(ExitBB, CurFunc);
3676
-
3693
+ if (ContinuationBlock) {
3694
+ Builder.CreateBr(ContinuationBlock);
3695
+ Builder.SetInsertPoint(ContinuationBlock);
3696
+ }
3677
3697
Config.setEmitLLVMUsed();
3678
3698
3679
3699
return Builder.saveIP();
@@ -3688,27 +3708,95 @@ static Function *getFreshReductionFunc(Module &M) {
3688
3708
".omp.reduction.func", &M);
3689
3709
}
3690
3710
3691
- OpenMPIRBuilder::InsertPointOrErrorTy
3692
- OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
3693
- InsertPointTy AllocaIP,
3694
- ArrayRef<ReductionInfo> ReductionInfos,
3695
- ArrayRef<bool> IsByRef, bool IsNoWait) {
3696
- assert(ReductionInfos.size() == IsByRef.size());
3697
- for (const ReductionInfo &RI : ReductionInfos) {
3698
- (void)RI;
3699
- assert(RI.Variable && "expected non-null variable");
3700
- assert(RI.PrivateVariable && "expected non-null private variable");
3701
- assert(RI.ReductionGen && "expected non-null reduction generator callback");
3702
- assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
3703
- "expected variables and their private equivalents to have the same "
3704
- "type");
3705
- assert(RI.Variable->getType()->isPointerTy() &&
3706
- "expected variables to be pointers");
3711
/// Emit the body of \p ReductionFunc.
///
/// A new, unnamed entry block is created in \p ReductionFunc and \p Builder is
/// pointed at it. For each entry of \p ReductionInfos, the function loads the
/// matching slot from the two pointer arrays passed as arguments 0 (LHS) and
/// 1 (RHS), loads the element values through those pointers, and calls the
/// entry's ReductionGen callback to combine them. Unless the entry is marked
/// by-ref, the combined value is stored back through the LHS pointer. A
/// `ret void` terminates the function.
///
/// \param ReductionFunc  Function to populate; its first two arguments are
///                       read as the LHS and RHS pointer arrays.
/// \param ReductionInfos One descriptor per reduction.
/// \param Builder        IR builder used for all emission. Its insertion
///                       point is changed and is not restored to its value on
///                       entry.
/// \param IsByRef        Indexed in parallel with \p ReductionInfos; a true
///                       entry suppresses the final store for that reduction.
/// \param IsGPU          When true, the two arguments are first written to
///                       stack slots and reloaded before being used.
/// \returns The error produced by a ReductionGen callback, if any; otherwise
///          Error::success().
+ static Error populateReductionFunction(
3712
+ Function *ReductionFunc,
3713
+ ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3714
+ IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
3715
+ Module *Module = ReductionFunc->getParent();
3716
+ BasicBlock *ReductionFuncBlock =
3717
+ BasicBlock::Create(Module->getContext(), "", ReductionFunc);
3718
+ Builder.SetInsertPoint(ReductionFuncBlock);
3719
+ Value *LHSArrayPtr = nullptr;
3720
+ Value *RHSArrayPtr = nullptr;
3721
+ if (IsGPU) {
3722
+ // GPU path: store each incoming argument into a fresh stack slot (cast to
3723
+ // the argument's own type) and read it back; the reloaded values are then
3724
+ // used as the LHS/RHS pointer arrays below.
3725
+ Argument *Arg0 = ReductionFunc->getArg(0);
3726
+ Argument *Arg1 = ReductionFunc->getArg(1);
3727
+ Type *Arg0Type = Arg0->getType();
3728
+ Type *Arg1Type = Arg1->getType();
3729
+
3730
+ Value *LHSAlloca =
3731
+ Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
3732
+ Value *RHSAlloca =
3733
+ Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
// NOTE(review): presumably the cast is needed because the alloca may live in
// a different address space than the argument type on GPU targets - confirm.
3734
+ Value *LHSAddrCast =
3735
+ Builder.CreatePointerBitCastOrAddrSpaceCast(LHSAlloca, Arg0Type);
3736
+ Value *RHSAddrCast =
3737
+ Builder.CreatePointerBitCastOrAddrSpaceCast(RHSAlloca, Arg1Type);
3738
+ Builder.CreateStore(Arg0, LHSAddrCast);
3739
+ Builder.CreateStore(Arg1, RHSAddrCast);
3740
+ LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
3741
+ RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
3742
+ } else {
3743
+ LHSArrayPtr = ReductionFunc->getArg(0);
3744
+ RHSArrayPtr = ReductionFunc->getArg(1);
3707
3745
}
3708
3746
3747
// Both arguments are indexed as an array holding one pointer per reduction.
+ unsigned NumReductions = ReductionInfos.size();
3748
+ Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
3749
+
3750
+ for (auto En : enumerate(ReductionInfos)) {
3751
+ const OpenMPIRBuilder::ReductionInfo &RI = En.value();
// Address of this reduction's slot in the LHS array, the pointer stored
// there, and finally the element value it points to.
3752
+ Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3753
+ RedArrayTy, LHSArrayPtr, 0, En.index());
3754
+ Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
3755
+ Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3756
+ LHSI8Ptr, RI.Variable->getType());
3757
+ Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
// Same three steps for the RHS array, typed after the private variable.
3758
+ Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3759
+ RedArrayTy, RHSArrayPtr, 0, En.index());
3760
+ Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
3761
+ Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3762
+ RHSI8Ptr, RI.PrivateVariable->getType());
3763
+ Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
// NOTE(review): Reduced is left uninitialized here and is passed to the
// callback, which is presumably expected to assign it - confirm.
3764
+ Value *Reduced;
3765
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
3766
+ RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
3767
+ if (!AfterIP)
3768
+ return AfterIP.takeError();
3769
+
// Continue emitting wherever the callback left off. If it returned an
// insertion point with no block, stop early; note that in this case no
// `ret void` is emitted and success is still reported.
3770
+ Builder.restoreIP(*AfterIP);
3771
+ // TODO: Consider flagging an error instead of returning success here.
3772
+ if (!Builder.GetInsertBlock())
3773
+ return Error::success();
3774
+
3775
+ // By-ref reductions emit their store inside the reduction region itself.
3776
+ if (!IsByRef[En.index()])
3777
+ Builder.CreateStore(Reduced, LHSPtr);
3778
+ }
3779
+ Builder.CreateRetVoid();
3780
+ return Error::success();
3781
+ }
3782
+
3783
+ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions(
3784
+ const LocationDescription &Loc, InsertPointTy AllocaIP,
3785
+ ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
3786
+ bool IsNoWait, bool IsTeamsReduction) {
3787
+ assert(ReductionInfos.size() == IsByRef.size());
3788
+ if (Config.isGPU())
3789
+ return createReductionsGPU(Loc, AllocaIP, Builder.saveIP(), ReductionInfos,
3790
+ IsNoWait, IsTeamsReduction);
3791
+
3792
+ checkReductionInfos(ReductionInfos, /*IsGPU*/ false);
3793
+
3709
3794
if (!updateToLocation(Loc))
3710
3795
return InsertPointTy();
3711
3796
3797
+ if (ReductionInfos.size() == 0)
3798
+ return Builder.saveIP();
3799
+
3712
3800
BasicBlock *InsertBlock = Loc.IP.getBlock();
3713
3801
BasicBlock *ContinuationBlock =
3714
3802
InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
@@ -3832,38 +3920,13 @@ OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
3832
3920
// Populate the outlined reduction function using the elementwise reduction
3833
3921
// function. Partial values are extracted from the type-erased array of
3834
3922
// pointers to private variables.
3835
- BasicBlock *ReductionFuncBlock =
3836
- BasicBlock::Create(Module->getContext(), "", ReductionFunc);
3837
- Builder.SetInsertPoint(ReductionFuncBlock);
3838
- Value *LHSArrayPtr = ReductionFunc->getArg(0);
3839
- Value *RHSArrayPtr = ReductionFunc->getArg(1);
3923
+ Error Err = populateReductionFunction(ReductionFunc, ReductionInfos, Builder,
3924
+ IsByRef, /*isGPU=*/false);
3925
+ if (Err)
3926
+ return Err;
3840
3927
3841
- for (auto En : enumerate(ReductionInfos)) {
3842
- const ReductionInfo &RI = En.value();
3843
- Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3844
- RedArrayTy, LHSArrayPtr, 0, En.index());
3845
- Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
3846
- Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
3847
- Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
3848
- Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3849
- RedArrayTy, RHSArrayPtr, 0, En.index());
3850
- Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
3851
- Value *RHSPtr =
3852
- Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
3853
- Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
3854
- Value *Reduced;
3855
- InsertPointOrErrorTy AfterIP =
3856
- RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
3857
- if (!AfterIP)
3858
- return AfterIP.takeError();
3859
- Builder.restoreIP(*AfterIP);
3860
- if (!Builder.GetInsertBlock())
3861
- return InsertPointTy();
3862
- // store is inside of the reduction region when using by-ref
3863
- if (!IsByRef[En.index()])
3864
- Builder.CreateStore(Reduced, LHSPtr);
3865
- }
3866
- Builder.CreateRetVoid();
3928
+ if (!Builder.GetInsertBlock())
3929
+ return InsertPointTy();
3867
3930
3868
3931
Builder.SetInsertPoint(ContinuationBlock);
3869
3932
return Builder.saveIP();
@@ -6239,8 +6302,10 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
6239
6302
Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
6240
6303
Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
6241
6304
Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
6242
- Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
6243
- Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
6305
+ Constant *ReductionDataSize =
6306
+ ConstantInt::getSigned(Int32, Attrs.ReductionDataSize);
6307
+ Constant *ReductionBufferLength =
6308
+ ConstantInt::getSigned(Int32, Attrs.ReductionBufferLength);
6244
6309
6245
6310
Function *Fn = getOrCreateRuntimeFunctionPtr(
6246
6311
omp::RuntimeFunction::OMPRTL___kmpc_target_init);
0 commit comments