Skip to content

Commit fbc8335

Browse files
authored Apr 7, 2025
[MLIR][OpenMP] Add codegen for teams reductions (#133310)
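This patch adds the lowering of teams reductions from the omp dialect to LLVM IR. It also includes a minor cleanup: the unused `HasDistribute` parameter of `OpenMPIRBuilder::createReductionsGPU` is removed, and the clang caller is updated accordingly.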
1 parent cb9afe5 commit fbc8335

File tree

12 files changed

+666
-111
lines changed

12 files changed

+666
-111
lines changed
 

‎clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1660,7 +1660,6 @@ void CGOpenMPRuntimeGPU::emitReduction(
16601660
return;
16611661

16621662
bool ParallelReduction = isOpenMPParallelDirective(Options.ReductionKind);
1663-
bool DistributeReduction = isOpenMPDistributeDirective(Options.ReductionKind);
16641663
bool TeamsReduction = isOpenMPTeamsDirective(Options.ReductionKind);
16651664

16661665
ASTContext &C = CGM.getContext();
@@ -1757,7 +1756,7 @@ void CGOpenMPRuntimeGPU::emitReduction(
17571756
llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
17581757
cantFail(OMPBuilder.createReductionsGPU(
17591758
OmpLoc, AllocaIP, CodeGenIP, ReductionInfos, false, TeamsReduction,
1760-
DistributeReduction, llvm::OpenMPIRBuilder::ReductionGenCBKind::Clang,
1759+
llvm::OpenMPIRBuilder::ReductionGenCBKind::Clang,
17611760
CGF.getTarget().getGridValue(),
17621761
C.getLangOpts().OpenMPCUDAReductionBufNum, RTLoc));
17631762
CGF.Builder.restoreIP(AfterIP);

‎llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1907,8 +1907,6 @@ class OpenMPIRBuilder {
19071907
/// nowait.
19081908
/// \param IsTeamsReduction Optional flag set if it is a teams
19091909
/// reduction.
1910-
/// \param HasDistribute Optional flag set if it is a
1911-
/// distribute reduction.
19121910
/// \param GridValue Optional GPU grid value.
19131911
/// \param ReductionBufNum Optional OpenMPCUDAReductionBufNumValue to be
19141912
/// used for teams reduction.
@@ -1917,7 +1915,6 @@ class OpenMPIRBuilder {
19171915
const LocationDescription &Loc, InsertPointTy AllocaIP,
19181916
InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
19191917
bool IsNoWait = false, bool IsTeamsReduction = false,
1920-
bool HasDistribute = false,
19211918
ReductionGenCBKind ReductionGenCBKind = ReductionGenCBKind::MLIR,
19221919
std::optional<omp::GV> GridValue = {}, unsigned ReductionBufNum = 1024,
19231920
Value *SrcLocInfo = nullptr);
@@ -1985,11 +1982,14 @@ class OpenMPIRBuilder {
19851982
/// \param IsNoWait A flag set if the reduction is marked as nowait.
19861983
/// \param IsByRef A flag set if the reduction is using reference
19871984
/// or direct value.
1985+
/// \param IsTeamsReduction Optional flag set if it is a teams
1986+
/// reduction.
19881987
InsertPointOrErrorTy createReductions(const LocationDescription &Loc,
19891988
InsertPointTy AllocaIP,
19901989
ArrayRef<ReductionInfo> ReductionInfos,
19911990
ArrayRef<bool> IsByRef,
1992-
bool IsNoWait = false);
1991+
bool IsNoWait = false,
1992+
bool IsTeamsReduction = false);
19931993

19941994
///}
19951995

@@ -2273,6 +2273,8 @@ class OpenMPIRBuilder {
22732273
int32_t MinTeams = 1;
22742274
SmallVector<int32_t, 3> MaxThreads = {-1};
22752275
int32_t MinThreads = 1;
2276+
int32_t ReductionDataSize = 0;
2277+
int32_t ReductionBufferLength = 0;
22762278
};
22772279

22782280
/// Container to pass LLVM IR runtime values or constants related to the

‎llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 119 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -3495,9 +3495,9 @@ checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
34953495
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
34963496
const LocationDescription &Loc, InsertPointTy AllocaIP,
34973497
InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
3498-
bool IsNoWait, bool IsTeamsReduction, bool HasDistribute,
3499-
ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
3500-
unsigned ReductionBufNum, Value *SrcLocInfo) {
3498+
bool IsNoWait, bool IsTeamsReduction, ReductionGenCBKind ReductionGenCBKind,
3499+
std::optional<omp::GV> GridValue, unsigned ReductionBufNum,
3500+
Value *SrcLocInfo) {
35013501
if (!updateToLocation(Loc))
35023502
return InsertPointTy();
35033503
Builder.restoreIP(CodeGenIP);
@@ -3514,6 +3514,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
35143514
if (ReductionInfos.size() == 0)
35153515
return Builder.saveIP();
35163516

3517+
BasicBlock *ContinuationBlock = nullptr;
3518+
if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
3519+
// Copied code from createReductions
3520+
BasicBlock *InsertBlock = Loc.IP.getBlock();
3521+
ContinuationBlock =
3522+
InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
3523+
InsertBlock->getTerminator()->eraseFromParent();
3524+
Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
3525+
}
3526+
35173527
Function *CurFunc = Builder.GetInsertBlock()->getParent();
35183528
AttributeList FuncAttrs;
35193529
AttrBuilder AttrBldr(Ctx);
@@ -3669,11 +3679,21 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
36693679
ReductionFunc;
36703680
});
36713681
} else {
3672-
assert(false && "Unhandled ReductionGenCBKind");
3682+
Value *LHSValue = Builder.CreateLoad(RI.ElementType, LHS, "final.lhs");
3683+
Value *RHSValue = Builder.CreateLoad(RI.ElementType, RHS, "final.rhs");
3684+
Value *Reduced;
3685+
InsertPointOrErrorTy AfterIP =
3686+
RI.ReductionGen(Builder.saveIP(), RHSValue, LHSValue, Reduced);
3687+
if (!AfterIP)
3688+
return AfterIP.takeError();
3689+
Builder.CreateStore(Reduced, LHS, false);
36733690
}
36743691
}
36753692
emitBlock(ExitBB, CurFunc);
3676-
3693+
if (ContinuationBlock) {
3694+
Builder.CreateBr(ContinuationBlock);
3695+
Builder.SetInsertPoint(ContinuationBlock);
3696+
}
36773697
Config.setEmitLLVMUsed();
36783698

36793699
return Builder.saveIP();
@@ -3688,27 +3708,95 @@ static Function *getFreshReductionFunc(Module &M) {
36883708
".omp.reduction.func", &M);
36893709
}
36903710

3691-
OpenMPIRBuilder::InsertPointOrErrorTy
3692-
OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
3693-
InsertPointTy AllocaIP,
3694-
ArrayRef<ReductionInfo> ReductionInfos,
3695-
ArrayRef<bool> IsByRef, bool IsNoWait) {
3696-
assert(ReductionInfos.size() == IsByRef.size());
3697-
for (const ReductionInfo &RI : ReductionInfos) {
3698-
(void)RI;
3699-
assert(RI.Variable && "expected non-null variable");
3700-
assert(RI.PrivateVariable && "expected non-null private variable");
3701-
assert(RI.ReductionGen && "expected non-null reduction generator callback");
3702-
assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
3703-
"expected variables and their private equivalents to have the same "
3704-
"type");
3705-
assert(RI.Variable->getType()->isPointerTy() &&
3706-
"expected variables to be pointers");
3711+
static Error populateReductionFunction(
3712+
Function *ReductionFunc,
3713+
ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3714+
IRBuilder<> &Builder, ArrayRef<bool> IsByRef, bool IsGPU) {
3715+
Module *Module = ReductionFunc->getParent();
3716+
BasicBlock *ReductionFuncBlock =
3717+
BasicBlock::Create(Module->getContext(), "", ReductionFunc);
3718+
Builder.SetInsertPoint(ReductionFuncBlock);
3719+
Value *LHSArrayPtr = nullptr;
3720+
Value *RHSArrayPtr = nullptr;
3721+
if (IsGPU) {
3722+
// Need to alloca memory here and deal with the pointers before getting
3723+
// LHS/RHS pointers out
3724+
//
3725+
Argument *Arg0 = ReductionFunc->getArg(0);
3726+
Argument *Arg1 = ReductionFunc->getArg(1);
3727+
Type *Arg0Type = Arg0->getType();
3728+
Type *Arg1Type = Arg1->getType();
3729+
3730+
Value *LHSAlloca =
3731+
Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
3732+
Value *RHSAlloca =
3733+
Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
3734+
Value *LHSAddrCast =
3735+
Builder.CreatePointerBitCastOrAddrSpaceCast(LHSAlloca, Arg0Type);
3736+
Value *RHSAddrCast =
3737+
Builder.CreatePointerBitCastOrAddrSpaceCast(RHSAlloca, Arg1Type);
3738+
Builder.CreateStore(Arg0, LHSAddrCast);
3739+
Builder.CreateStore(Arg1, RHSAddrCast);
3740+
LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
3741+
RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);
3742+
} else {
3743+
LHSArrayPtr = ReductionFunc->getArg(0);
3744+
RHSArrayPtr = ReductionFunc->getArg(1);
37073745
}
37083746

3747+
unsigned NumReductions = ReductionInfos.size();
3748+
Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
3749+
3750+
for (auto En : enumerate(ReductionInfos)) {
3751+
const OpenMPIRBuilder::ReductionInfo &RI = En.value();
3752+
Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3753+
RedArrayTy, LHSArrayPtr, 0, En.index());
3754+
Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
3755+
Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3756+
LHSI8Ptr, RI.Variable->getType());
3757+
Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
3758+
Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3759+
RedArrayTy, RHSArrayPtr, 0, En.index());
3760+
Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
3761+
Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
3762+
RHSI8Ptr, RI.PrivateVariable->getType());
3763+
Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
3764+
Value *Reduced;
3765+
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
3766+
RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
3767+
if (!AfterIP)
3768+
return AfterIP.takeError();
3769+
3770+
Builder.restoreIP(*AfterIP);
3771+
// TODO: Consider flagging an error.
3772+
if (!Builder.GetInsertBlock())
3773+
return Error::success();
3774+
3775+
// store is inside of the reduction region when using by-ref
3776+
if (!IsByRef[En.index()])
3777+
Builder.CreateStore(Reduced, LHSPtr);
3778+
}
3779+
Builder.CreateRetVoid();
3780+
return Error::success();
3781+
}
3782+
3783+
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions(
3784+
const LocationDescription &Loc, InsertPointTy AllocaIP,
3785+
ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool> IsByRef,
3786+
bool IsNoWait, bool IsTeamsReduction) {
3787+
assert(ReductionInfos.size() == IsByRef.size());
3788+
if (Config.isGPU())
3789+
return createReductionsGPU(Loc, AllocaIP, Builder.saveIP(), ReductionInfos,
3790+
IsNoWait, IsTeamsReduction);
3791+
3792+
checkReductionInfos(ReductionInfos, /*IsGPU*/ false);
3793+
37093794
if (!updateToLocation(Loc))
37103795
return InsertPointTy();
37113796

3797+
if (ReductionInfos.size() == 0)
3798+
return Builder.saveIP();
3799+
37123800
BasicBlock *InsertBlock = Loc.IP.getBlock();
37133801
BasicBlock *ContinuationBlock =
37143802
InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
@@ -3832,38 +3920,13 @@ OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
38323920
// Populate the outlined reduction function using the elementwise reduction
38333921
// function. Partial values are extracted from the type-erased array of
38343922
// pointers to private variables.
3835-
BasicBlock *ReductionFuncBlock =
3836-
BasicBlock::Create(Module->getContext(), "", ReductionFunc);
3837-
Builder.SetInsertPoint(ReductionFuncBlock);
3838-
Value *LHSArrayPtr = ReductionFunc->getArg(0);
3839-
Value *RHSArrayPtr = ReductionFunc->getArg(1);
3923+
Error Err = populateReductionFunction(ReductionFunc, ReductionInfos, Builder,
3924+
IsByRef, /*isGPU=*/false);
3925+
if (Err)
3926+
return Err;
38403927

3841-
for (auto En : enumerate(ReductionInfos)) {
3842-
const ReductionInfo &RI = En.value();
3843-
Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3844-
RedArrayTy, LHSArrayPtr, 0, En.index());
3845-
Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
3846-
Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
3847-
Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
3848-
Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
3849-
RedArrayTy, RHSArrayPtr, 0, En.index());
3850-
Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
3851-
Value *RHSPtr =
3852-
Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
3853-
Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
3854-
Value *Reduced;
3855-
InsertPointOrErrorTy AfterIP =
3856-
RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
3857-
if (!AfterIP)
3858-
return AfterIP.takeError();
3859-
Builder.restoreIP(*AfterIP);
3860-
if (!Builder.GetInsertBlock())
3861-
return InsertPointTy();
3862-
// store is inside of the reduction region when using by-ref
3863-
if (!IsByRef[En.index()])
3864-
Builder.CreateStore(Reduced, LHSPtr);
3865-
}
3866-
Builder.CreateRetVoid();
3928+
if (!Builder.GetInsertBlock())
3929+
return InsertPointTy();
38673930

38683931
Builder.SetInsertPoint(ContinuationBlock);
38693932
return Builder.saveIP();
@@ -6239,8 +6302,10 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
62396302
Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
62406303
Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
62416304
Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
6242-
Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
6243-
Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
6305+
Constant *ReductionDataSize =
6306+
ConstantInt::getSigned(Int32, Attrs.ReductionDataSize);
6307+
Constant *ReductionBufferLength =
6308+
ConstantInt::getSigned(Int32, Attrs.ReductionBufferLength);
62446309

62456310
Function *Fn = getOrCreateRuntimeFunctionPtr(
62466311
omp::RuntimeFunction::OMPRTL___kmpc_target_init);

‎llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2354,6 +2354,7 @@ TEST_F(OpenMPIRBuilderTest, StaticWorkshareLoopTarget) {
23542354
"256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8");
23552355
OpenMPIRBuilder OMPBuilder(*M);
23562356
OMPBuilder.Config.IsTargetDevice = true;
2357+
OMPBuilder.Config.setIsGPU(false);
23572358
OMPBuilder.initialize();
23582359
IRBuilder<> Builder(BB);
23592360
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});

0 commit comments

Comments
 (0)