diff --git a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp index 395311e430fbb..474232e6870a3 100644 --- a/llvm/lib/Target/DirectX/DXILLegalizePass.cpp +++ b/llvm/lib/Target/DirectX/DXILLegalizePass.cpp @@ -12,6 +12,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicsDirectX.h" #include "llvm/Pass.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include @@ -20,6 +21,31 @@ using namespace llvm; +static void replaceFrem(Instruction &I, + SmallVectorImpl &ToRemove, + DenseMap &) { + auto *BO = dyn_cast(&I); + if (BO == nullptr || BO->getOpcode() != Instruction::FRem) + return; + + IRBuilder<> Builder(&I); + Value *P0 = BO->getOperand(0); + Value *P1 = BO->getOperand(1); + + Value *Div1 = Builder.CreateFDiv(P0, P1); + Value *Zero = ConstantFP::get(P0->getType(), 0.0); + Value *Cmp = Builder.CreateFCmpOGE(Div1, Zero, "cmp.i"); + Value *AbsVal = + Builder.CreateIntrinsic(Div1->getType(), Intrinsic::fabs, {Div1}); + Value *FracVal = + Builder.CreateIntrinsic(AbsVal->getType(), Intrinsic::dx_frac, {AbsVal}); + Value *NegFrac = Builder.CreateFNeg(FracVal); + Value *SelectVal = Builder.CreateSelect(Cmp, FracVal, NegFrac); + Value *MulVal = Builder.CreateFMul(SelectVal, P1); + BO->replaceAllUsesWith(MulVal); + ToRemove.push_back(BO); +} + static void fixI8TruncUseChain(Instruction &I, SmallVectorImpl &ToRemove, DenseMap &ReplacedValues) { @@ -169,6 +195,7 @@ class DXILLegalizationPipeline { void initializeLegalizationPipeline() { LegalizationPipeline.push_back(fixI8TruncUseChain); LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements); + LegalizationPipeline.push_back(replaceFrem); } }; diff --git a/llvm/test/CodeGen/DirectX/frem.ll b/llvm/test/CodeGen/DirectX/frem.ll new file mode 100644 index 0000000000000..e75cab90bec07 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/frem.ll @@ -0,0 +1,59 @@ + +; RUN: opt -S -dxil-legalize -mtriple=dxil-pc-shadermodel6.3-library %s -o - | FileCheck %s + +define noundef half @frem_half(half noundef %a, half noundef %b) { +; CHECK-LABEL: define noundef half @frem_half( +; CHECK-SAME: half noundef [[A:%.*]], half noundef [[B:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[FDIV:%.*]] = fdiv half [[A]], [[B]] +; CHECK-NEXT: [[FCMP:%.*]] = fcmp oge half [[FDIV]], 0xH0000 +; CHECK-NEXT: [[FABS:%.*]] = call half @llvm.fabs.f16(half [[FDIV]]) +; CHECK-NEXT: [[FRAC:%.*]] = call half @llvm.dx.frac.f16(half [[FABS]]) +; CHECK-NEXT: [[FNEG:%.*]] = fneg half [[FRAC]] +; CHECK-NEXT: [[SELC:%.*]] = select i1 [[FCMP]], half [[FRAC]], half [[FNEG]] +; CHECK-NEXT: [[FMUL:%.*]] = fmul half [[SELC]], [[B]] +; CHECK-NEXT: ret half [[FMUL]] +; +entry: + %fmod.i = frem reassoc nnan ninf nsz arcp afn half %a, %b + ret half %fmod.i +} + +; Note by the time the legalizer sees frem with vec type frem will be scalarized +; This test is for completeness not for expected input of DXL SMs <= 6.8. + +define noundef <2 x half> @frem_half2(<2 x half> noundef %a, <2 x half> noundef %b) { +; CHECK-LABEL: define noundef <2 x half> @frem_half2( +; CHECK-SAME: <2 x half> noundef [[A:%.*]], <2 x half> noundef [[B:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[FDIV:%.*]] = fdiv <2 x half> [[A]], [[B]] +; CHECK-NEXT: [[FCMP:%.*]] = fcmp oge <2 x half> [[FDIV]], zeroinitializer +; CHECK-NEXT: [[FABS:%.*]] = call <2 x half> @llvm.fabs.v2f16(<2 x half> [[FDIV]]) +; CHECK-NEXT: [[FRAC:%.*]] = call <2 x half> @llvm.dx.frac.v2f16(<2 x half> [[FABS]]) +; CHECK-NEXT: [[FNEG:%.*]] = fneg <2 x half> [[FRAC]] +; CHECK-NEXT: [[SELC:%.*]] = select <2 x i1> [[FCMP]], <2 x half> [[FRAC]], <2 x half> [[FNEG]] +; CHECK-NEXT: [[FMUL:%.*]] = fmul <2 x half> [[SELC]], [[B]] +; CHECK-NEXT: ret <2 x half> [[FMUL]] +; +entry: + %fmod.i = frem reassoc nnan ninf nsz arcp afn <2 x half> %a, %b + ret <2 x half> %fmod.i +} + +define noundef float @frem_float(float noundef %a, float noundef %b) { +; CHECK-LABEL: define noundef float @frem_float( +; CHECK-SAME: float noundef [[A:%.*]], float noundef [[B:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[FDIV:%.*]] = fdiv float [[A]], [[B]] +; CHECK-NEXT: [[FCMP:%.*]] = fcmp oge float [[FDIV]], 0.000000e+00 +; CHECK-NEXT: [[FABS:%.*]] = call float @llvm.fabs.f32(float [[FDIV]]) +; CHECK-NEXT: [[FRAC:%.*]] = call float @llvm.dx.frac.f32(float [[FABS]]) +; CHECK-NEXT: [[FNEG:%.*]] = fneg float [[FRAC]] +; CHECK-NEXT: [[SELC:%.*]] = select i1 [[FCMP]], float [[FRAC]], float [[FNEG]] +; CHECK-NEXT: [[FMUL:%.*]] = fmul float [[SELC]], [[B]] +; CHECK-NEXT: ret float [[FMUL]] +; +entry: + %fmod.i = frem reassoc nnan ninf nsz arcp afn float %a, %b + ret float %fmod.i +}