From 514554965bd91246a963a8d7da311226c8854490 Mon Sep 17 00:00:00 2001 From: Jeffrey Byrnes Date: Mon, 28 Apr 2025 10:10:59 -0700 Subject: [PATCH 1/2] [InstCombine] Combine and->cmp->sel->or-disjoint into and->mul Change-Id: I08cfc8c494bce343bf09e3110186e1a8553e2473 --- .../InstCombine/InstCombineAndOrXor.cpp | 48 +++ .../test/Transforms/InstCombine/or-bitmask.ll | 328 ++++++++++++++++++ 2 files changed, 376 insertions(+) create mode 100644 llvm/test/Transforms/InstCombine/or-bitmask.ll diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index b74bce391aa56..18d3f551adef4 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -3641,6 +3641,54 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { foldAddLikeCommutative(I.getOperand(1), I.getOperand(0), /*NSW=*/true, /*NUW=*/true)) return R; + + Value *Cond0 = nullptr, *Cond1 = nullptr; + const APInt *Op0Eq = nullptr, *Op0Ne = nullptr; + const APInt *Op1Eq = nullptr, *Op1Ne = nullptr; + + // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C + if (match(I.getOperand(0), + m_Select(m_Value(Cond0), m_APInt(Op0Eq), m_APInt(Op0Ne))) && + match(I.getOperand(1), + m_Select(m_Value(Cond1), m_APInt(Op1Eq), m_APInt(Op1Ne)))) { + CmpPredicate Pred0, Pred1; + + auto LHSDecompose = + decomposeBitTest(Cond0, /*LookThruTrunc=*/true, + /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true); + auto RHSDecompose = + decomposeBitTest(Cond1, /*LookThruTrunc=*/true, + /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true); + + if (LHSDecompose && RHSDecompose && LHSDecompose->X == RHSDecompose->X && + (ICmpInst::isEquality(LHSDecompose->Pred)) && + !RHSDecompose->Mask.isNegative() && + !LHSDecompose->Mask.isNegative() && RHSDecompose->Mask.isPowerOf2() && + LHSDecompose->Mask.isPowerOf2() && + LHSDecompose->Mask != RHSDecompose->Mask && + LHSDecompose->C.isZero() && RHSDecompose->C.isZero()) { + if (LHSDecompose->Pred == ICmpInst::ICMP_NE) + std::swap(Op0Eq, Op0Ne); + if (RHSDecompose->Pred == ICmpInst::ICMP_NE) + std::swap(Op1Eq, Op1Ne); + + if (Op0Ne->isStrictlyPositive() && Op1Ne->isStrictlyPositive() && + Op0Eq->isZero() && Op1Eq->isZero() && + Op0Ne->urem(LHSDecompose->Mask).isZero() && + Op1Ne->urem(RHSDecompose->Mask).isZero() && + Op0Ne->udiv(LHSDecompose->Mask) == + Op1Ne->udiv(RHSDecompose->Mask)) { + auto NewAnd = Builder.CreateAnd( + LHSDecompose->X, + ConstantInt::get(LHSDecompose->X->getType(), + (LHSDecompose->Mask + RHSDecompose->Mask))); + + return BinaryOperator::CreateMul( + NewAnd, ConstantInt::get(NewAnd->getType(), + Op0Ne->udiv(LHSDecompose->Mask))); + } + } + } } Value *X, *Y; diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll new file mode 100644 index 0000000000000..18acc42f2dc40 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll @@ -0,0 +1,328 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=instcombine -S | FileCheck %s --check-prefixes=CHECK,CONSTVEC +; RUN: opt < %s -passes=instcombine -S -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=CHECK,CONSTSPLAT + +define i32 @add_select_cmp_and1(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and1( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp eq i32 %bitop0, 0 + %bitop1 = and i32 %in, 2 + %cmp1 = icmp eq i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 0, i32 72 + %sel1 = select i1 %cmp1, i32 0, i32 144 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_and2(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and2( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 5 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp eq i32 %bitop0, 0 + %bitop1 = and i32 %in, 4 + %cmp1 = icmp eq i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 0, i32 72 + %sel1 = select i1 %cmp1, i32 0, i32 288 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_and3(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and3( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: [[BITOP2:%.*]] = and i32 [[IN]], 4 +; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i32 [[BITOP2]], 0 +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i32 0, i32 288 +; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[SEL2]] +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp eq i32 %bitop0, 0 + %bitop1 = and i32 %in, 2 + %cmp1 = icmp eq i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 0, i32 72 + %sel1 = select i1 %cmp1, i32 0, i32 144 + %temp = or disjoint i32 %sel0, %sel1 + %bitop2 = and i32 %in, 4 + %cmp2 = icmp eq i32 %bitop2, 0 + %sel2 = select i1 %cmp2, i32 0, i32 288 + %out = or disjoint i32 %temp, %sel2 + ret i32 %out +} + +define i32 @add_select_cmp_and4(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and4( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN]], 12 +; CHECK-NEXT: [[TEMP3:%.*]] = mul nuw nsw i32 [[TMP2]], 72 +; CHECK-NEXT: [[OUT1:%.*]] = or disjoint i32 [[OUT]], [[TEMP3]] +; CHECK-NEXT: ret i32 [[OUT1]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp eq i32 %bitop0, 0 + %bitop1 = and i32 %in, 2 + %cmp1 = icmp eq i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 0, i32 72 + %sel1 = select i1 %cmp1, i32 0, i32 144 + %temp = or disjoint i32 %sel0, %sel1 + %bitop2 = and i32 %in, 4 + %cmp2 = icmp eq i32 %bitop2, 0 + %bitop3 = and i32 %in, 8 + %cmp3 = icmp eq i32 %bitop3, 0 + %sel2 = select i1 %cmp2, i32 0, i32 288 + %sel3 = select i1 %cmp3, i32 0, i32 576 + %temp2 = or disjoint i32 %sel2, %sel3 + %out = or disjoint i32 %temp, %temp2 + ret i32 %out +} + +define i32 @add_select_cmp_and_pred1(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and_pred1( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp ne i32 %bitop0, 0 + %bitop1 = and i32 %in, 2 + %cmp1 = icmp eq i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 72, i32 0 + %sel1 = select i1 %cmp1, i32 0, i32 144 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_and_pred2(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and_pred2( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp eq i32 %bitop0, 0 + %bitop1 = and i32 %in, 2 + %cmp1 = icmp ne i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 0, i32 72 + %sel1 = select i1 %cmp1, i32 144, i32 0 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_and_pred3(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and_pred3( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp ne i32 %bitop0, 0 + %bitop1 = and i32 %in, 2 + %cmp1 = icmp ne i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 72, i32 0 + %sel1 = select i1 %cmp1, i32 144, i32 0 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_trunc(i32 %in) { +; CHECK-LABEL: @add_select_cmp_trunc( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[OUT]] +; + %cmp0 = trunc i32 %in to i1 + %bitop1 = and i32 %in, 2 + %cmp1 = icmp eq i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 72, i32 0 + %sel1 = select i1 %cmp1, i32 0, i32 144 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_trunc1(i32 %in) { +; CHECK-LABEL: @add_select_cmp_trunc1( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[OUT]] +; + %cmp0 = trunc i32 %in to i1 + %bitop1 = and i32 %in, 2 + %cmp1 = icmp ne i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 72, i32 0 + %sel1 = select i1 %cmp1, i32 144, i32 0 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + + +define i32 @add_select_cmp_and_const_mismatch(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and_const_mismatch( +; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1 +; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0 +; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN]], 2 +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[BITOP1]], 0 +; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72 +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 288 +; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]] +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp eq i32 %bitop0, 0 + %bitop1 = and i32 %in, 2 + %cmp1 = icmp eq i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 0, i32 72 + %sel1 = select i1 %cmp1, i32 0, i32 288 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_and_value_mismatch(i32 %in, i32 %in1) { +; CHECK-LABEL: @add_select_cmp_and_value_mismatch( +; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1 +; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0 +; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN1:%.*]], 2 +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[BITOP1]], 0 +; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72 +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 144 +; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]] +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp eq i32 %bitop0, 0 + %bitop1 = and i32 %in1, 2 + %cmp1 = icmp eq i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 0, i32 72 + %sel1 = select i1 %cmp1, i32 0, i32 144 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_and_negative(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and_negative( +; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1 +; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0 +; CHECK-NEXT: [[CMP1:%.*]] = icmp ult i32 [[IN]], 2 +; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72 +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 -144 +; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]] +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp eq i32 %bitop0, 0 + %bitop1 = and i32 %in, -2 + %cmp1 = icmp eq i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 0, i32 72 + %sel1 = select i1 %cmp1, i32 0, i32 -144 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_and_bitsel_overlap(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and_bitsel_overlap( +; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 2 +; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0 +; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 144 +; CHECK-NEXT: ret i32 [[SEL0]] +; + %bitop0 = and i32 %in, 2 + %cmp0 = icmp eq i32 %bitop0, 0 + %bitop1 = and i32 %in, 2 + %cmp1 = icmp eq i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 0, i32 144 + %sel1 = select i1 %cmp1, i32 0, i32 144 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +; We cannot combine into and-mul, as %bitop1 may not be exactly 6 + +define i32 @add_select_cmp_and_multbit_mask(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and_multbit_mask( +; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1 +; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0 +; CHECK-NEXT: [[BITOP1:%.*]] = and i32 [[IN]], 6 +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq i32 [[BITOP1]], 0 +; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 72 +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i32 0, i32 432 +; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]] +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp eq i32 %bitop0, 0 + %bitop1 = and i32 %in, 6 + %cmp1 = icmp eq i32 %bitop1, 0 + %sel0 = select i1 %cmp0, i32 0, i32 72 + %sel1 = select i1 %cmp1, i32 0, i32 432 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + + +define <2 x i32> @add_select_cmp_vec(<2 x i32> %in) { +; CHECK-LABEL: @add_select_cmp_vec( +; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i32> [[IN:%.*]], splat (i32 3) +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw <2 x i32> [[TMP1]], splat (i32 72) +; CHECK-NEXT: ret <2 x i32> [[OUT]] +; + %bitop0 = and <2 x i32> %in, + %cmp0 = icmp eq <2 x i32> %bitop0, + %bitop1 = and <2 x i32> %in, + %cmp1 = icmp eq <2 x i32> %bitop1, + %sel0 = select <2 x i1> %cmp0, <2 x i32> , <2 x i32> + %sel1 = select <2 x i1> %cmp1, <2 x i32> , <2 x i32> + %out = or disjoint <2 x i32> %sel0, %sel1 + ret <2 x i32> %out +} + +define <2 x i32> @add_select_cmp_vec_poison(<2 x i32> %in) { +; CHECK-LABEL: @add_select_cmp_vec_poison( +; CHECK-NEXT: [[BITOP0:%.*]] = and <2 x i32> [[IN:%.*]], splat (i32 1) +; CHECK-NEXT: [[CMP0:%.*]] = icmp eq <2 x i32> [[BITOP0]], zeroinitializer +; CHECK-NEXT: [[BITOP1:%.*]] = and <2 x i32> [[IN]], splat (i32 2) +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq <2 x i32> [[BITOP1]], zeroinitializer +; CHECK-NEXT: [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i32> zeroinitializer, <2 x i32> +; CHECK-NEXT: [[OUT:%.*]] = select <2 x i1> [[CMP0]], <2 x i32> [[SEL1]], <2 x i32> +; CHECK-NEXT: ret <2 x i32> [[OUT]] +; + %bitop0 = and <2 x i32> %in, + %cmp0 = icmp eq <2 x i32> %bitop0, + %bitop1 = and <2 x i32> %in, + %cmp1 = icmp eq <2 x i32> %bitop1, + %sel0 = select <2 x i1> %cmp0, <2 x i32> , <2 x i32> + %sel1 = select <2 x i1> %cmp1, <2 x i32> , <2 x i32> + %out = or disjoint <2 x i32> %sel0, %sel1 + ret <2 x i32> %out +} + +define <2 x i32> @add_select_cmp_vec_nonunique(<2 x i32> %in) { +; CHECK-LABEL: @add_select_cmp_vec_nonunique( +; CHECK-NEXT: [[BITOP0:%.*]] = and <2 x i32> [[IN:%.*]], +; CHECK-NEXT: [[CMP0:%.*]] = icmp eq <2 x i32> [[BITOP0]], zeroinitializer +; CHECK-NEXT: [[BITOP1:%.*]] = and <2 x i32> [[IN]], +; CHECK-NEXT: [[CMP1:%.*]] = icmp eq <2 x i32> [[BITOP1]], zeroinitializer +; CHECK-NEXT: [[SEL0:%.*]] = select <2 x i1> [[CMP0]], <2 x i32> zeroinitializer, <2 x i32> +; CHECK-NEXT: [[SEL1:%.*]] = select <2 x i1> [[CMP1]], <2 x i32> zeroinitializer, <2 x i32> +; CHECK-NEXT: [[OUT:%.*]] = or disjoint <2 x i32> [[SEL0]], [[SEL1]] +; CHECK-NEXT: ret <2 x i32> [[OUT]] +; + %bitop0 = and <2 x i32> %in, + %cmp0 = icmp eq <2 x i32> %bitop0, + %bitop1 = and <2 x i32> %in, + %cmp1 = icmp eq <2 x i32> %bitop1, + %sel0 = select <2 x i1> %cmp0, <2 x i32> , <2 x i32> + %sel1 = select <2 x i1> %cmp1, <2 x i32> , <2 x i32> + %out = or disjoint <2 x i32> %sel0, %sel1 + ret <2 x i32> %out +} +;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line: +; CONSTSPLAT: {{.*}} +; CONSTVEC: {{.*}} \ No newline at end of file From e2f011e4de86fa8ef84f6f304f70cfa85de90f26 Mon Sep 17 00:00:00 2001 From: Jeffrey Byrnes Date: Thu, 17 Apr 2025 10:11:18 -0700 Subject: [PATCH 2/2] [InstCombine] Extend bitmask->select combine to match and->mul Change-Id: I1cc2acd3804dde50636518f3ef2c9581848ae9f6 --- .../InstCombine/InstCombineAndOrXor.cpp | 124 +++++++++++------- .../test/Transforms/InstCombine/or-bitmask.ll | 96 ++++++++++++-- 2 files changed, 164 insertions(+), 56 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 18d3f551adef4..b3ec2bb2d5f3f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -3560,6 +3560,72 @@ static Value *foldOrOfInversions(BinaryOperator &I, return nullptr; } +struct DecomposedBitMaskMul { + Value *X; + APInt Factor; + APInt Mask; +}; + +static std::optional matchBitmaskMul(Value *V) { + Instruction *Op = dyn_cast(V); + if (!Op) + return std::nullopt; + + Value *MulOp = nullptr; + const APInt *MulConst = nullptr; + if (match(Op, m_Mul(m_Value(MulOp), m_APInt(MulConst)))) { + Value *Original = nullptr; + const APInt *Mask = nullptr; + if (!MulConst->isStrictlyPositive()) + return std::nullopt; + + if (match(MulOp, m_And(m_Value(Original), m_APInt(Mask)))) { + if (!Mask->isStrictlyPositive()) + return std::nullopt; + DecomposedBitMaskMul Ret; + Ret.X = Original; + Ret.Mask = *Mask; + Ret.Factor = *MulConst; + return Ret; + } + return std::nullopt; + } + + Value *Cond = nullptr; + const APInt *EqZero = nullptr, *NeZero = nullptr; + + // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C + if (match(Op, m_Select(m_Value(Cond), m_APInt(EqZero), m_APInt(NeZero)))) { + auto ICmpDecompose = + decomposeBitTest(Cond, /*LookThruTrunc=*/true, + /*AllowNonZeroC=*/false, /*DecomposeBitMask=*/true); + if (!ICmpDecompose.has_value()) + return std::nullopt; + + if (ICmpDecompose->Pred == ICmpInst::ICMP_NE) + std::swap(EqZero, NeZero); + + if (!EqZero->isZero() || !NeZero->isStrictlyPositive()) + return std::nullopt; + + if (!ICmpInst::isEquality(ICmpDecompose->Pred) || + !ICmpDecompose->C.isZero() || !ICmpDecompose->Mask.isPowerOf2() || + ICmpDecompose->Mask.isNegative()) + return std::nullopt; + + if (!NeZero->urem(ICmpDecompose->Mask).isZero()) + return std::nullopt; + + DecomposedBitMaskMul Ret; + Ret.X = ICmpDecompose->X; + Ret.Mask = ICmpDecompose->Mask; + Ret.Factor = NeZero->udiv(ICmpDecompose->Mask); + return Ret; + } + + return std::nullopt; +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. @@ -3642,51 +3708,19 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { /*NSW=*/true, /*NUW=*/true)) return R; - Value *Cond0 = nullptr, *Cond1 = nullptr; - const APInt *Op0Eq = nullptr, *Op0Ne = nullptr; - const APInt *Op1Eq = nullptr, *Op1Ne = nullptr; - - // (!(A & N) ? 0 : N * C) + (!(A & M) ? 0 : M * C) -> A & (N + M) * C - if (match(I.getOperand(0), - m_Select(m_Value(Cond0), m_APInt(Op0Eq), m_APInt(Op0Ne))) && - match(I.getOperand(1), - m_Select(m_Value(Cond1), m_APInt(Op1Eq), m_APInt(Op1Ne)))) { - CmpPredicate Pred0, Pred1; - - auto LHSDecompose = - decomposeBitTest(Cond0, /*LookThruTrunc=*/true, - /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true); - auto RHSDecompose = - decomposeBitTest(Cond1, /*LookThruTrunc=*/true, - /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true); - - if (LHSDecompose && RHSDecompose && LHSDecompose->X == RHSDecompose->X && - (ICmpInst::isEquality(LHSDecompose->Pred)) && - !RHSDecompose->Mask.isNegative() && - !LHSDecompose->Mask.isNegative() && RHSDecompose->Mask.isPowerOf2() && - LHSDecompose->Mask.isPowerOf2() && - LHSDecompose->Mask != RHSDecompose->Mask && - LHSDecompose->C.isZero() && RHSDecompose->C.isZero()) { - if (LHSDecompose->Pred == ICmpInst::ICMP_NE) - std::swap(Op0Eq, Op0Ne); - if (RHSDecompose->Pred == ICmpInst::ICMP_NE) - std::swap(Op1Eq, Op1Ne); - - if (Op0Ne->isStrictlyPositive() && Op1Ne->isStrictlyPositive() && - Op0Eq->isZero() && Op1Eq->isZero() && - Op0Ne->urem(LHSDecompose->Mask).isZero() && - Op1Ne->urem(RHSDecompose->Mask).isZero() && - Op0Ne->udiv(LHSDecompose->Mask) == - Op1Ne->udiv(RHSDecompose->Mask)) { - auto NewAnd = Builder.CreateAnd( - LHSDecompose->X, - ConstantInt::get(LHSDecompose->X->getType(), - (LHSDecompose->Mask + RHSDecompose->Mask))); - - return BinaryOperator::CreateMul( - NewAnd, ConstantInt::get(NewAnd->getType(), - Op0Ne->udiv(LHSDecompose->Mask))); - } + auto Decomp0 = matchBitmaskMul(I.getOperand(0)); + auto Decomp1 = matchBitmaskMul(I.getOperand(1)); + + if (Decomp0 && Decomp1) { + if (Decomp0->X == Decomp1->X && + (Decomp0->Mask & Decomp1->Mask).isZero() && + Decomp0->Factor == Decomp1->Factor) { + auto NewAnd = Builder.CreateAnd( + Decomp0->X, ConstantInt::get(Decomp0->X->getType(), + (Decomp0->Mask + Decomp1->Mask))); + + return BinaryOperator::CreateMul( + NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor)); } } } diff --git a/llvm/test/Transforms/InstCombine/or-bitmask.ll b/llvm/test/Transforms/InstCombine/or-bitmask.ll index 18acc42f2dc40..4b989733afbb9 100644 --- a/llvm/test/Transforms/InstCombine/or-bitmask.ll +++ b/llvm/test/Transforms/InstCombine/or-bitmask.ll @@ -36,13 +36,9 @@ define i32 @add_select_cmp_and2(i32 %in) { define i32 @add_select_cmp_and3(i32 %in) { ; CHECK-LABEL: @add_select_cmp_and3( -; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 7 ; CHECK-NEXT: [[TEMP:%.*]] = mul nuw nsw i32 [[TMP1]], 72 -; CHECK-NEXT: [[BITOP2:%.*]] = and i32 [[IN]], 4 -; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i32 [[BITOP2]], 0 -; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i32 0, i32 288 -; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[TEMP]], [[SEL2]] -; CHECK-NEXT: ret i32 [[OUT]] +; CHECK-NEXT: ret i32 [[TEMP]] ; %bitop0 = and i32 %in, 1 %cmp0 = icmp eq i32 %bitop0, 0 @@ -60,12 +56,9 @@ define i32 @add_select_cmp_and3(i32 %in) { define i32 @add_select_cmp_and4(i32 %in) { ; CHECK-LABEL: @add_select_cmp_and4( -; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 -; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 -; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN]], 12 +; CHECK-NEXT: [[TMP2:%.*]] = and i32 [[IN:%.*]], 15 ; CHECK-NEXT: [[TEMP3:%.*]] = mul nuw nsw i32 [[TMP2]], 72 -; CHECK-NEXT: [[OUT1:%.*]] = or disjoint i32 [[OUT]], [[TEMP3]] -; CHECK-NEXT: ret i32 [[OUT1]] +; CHECK-NEXT: ret i32 [[TEMP3]] ; %bitop0 = and i32 %in, 1 %cmp0 = icmp eq i32 %bitop0, 0 @@ -323,6 +316,87 @@ define <2 x i32> @add_select_cmp_vec_nonunique(<2 x i32> %in) { %out = or disjoint <2 x i32> %sel0, %sel1 ret <2 x i32> %out } + +define i32 @add_select_cmp_mixed1(i32 %in) { +; CHECK-LABEL: @add_select_cmp_mixed1( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[OUT]] +; + %mask = and i32 %in, 1 + %sel0 = mul i32 %mask, 72 + %bitop1 = and i32 %in, 2 + %cmp1 = icmp eq i32 %bitop1, 0 + %sel1 = select i1 %cmp1, i32 0, i32 144 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_mixed2(i32 %in) { +; CHECK-LABEL: @add_select_cmp_mixed2( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp eq i32 %bitop0, 0 + %mask = and i32 %in, 2 + %sel0 = select i1 %cmp0, i32 0, i32 72 + %sel1 = mul i32 %mask, 72 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_and_mul(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and_mul( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[IN:%.*]], 3 +; CHECK-NEXT: [[OUT:%.*]] = mul nuw nsw i32 [[TMP1]], 72 +; CHECK-NEXT: ret i32 [[OUT]] +; + %mask0 = and i32 %in, 1 + %sel0 = mul i32 %mask0, 72 + %mask1 = and i32 %in, 2 + %sel1 = mul i32 %mask1, 72 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_mixed2_mismatch(i32 %in) { +; CHECK-LABEL: @add_select_cmp_mixed2_mismatch( +; CHECK-NEXT: [[BITOP0:%.*]] = and i32 [[IN:%.*]], 1 +; CHECK-NEXT: [[CMP0:%.*]] = icmp eq i32 [[BITOP0]], 0 +; CHECK-NEXT: [[MASK:%.*]] = and i32 [[IN]], 2 +; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 0, i32 73 +; CHECK-NEXT: [[SEL1:%.*]] = mul nuw nsw i32 [[MASK]], 72 +; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]] +; CHECK-NEXT: ret i32 [[OUT]] +; + %bitop0 = and i32 %in, 1 + %cmp0 = icmp eq i32 %bitop0, 0 + %mask = and i32 %in, 2 + %sel0 = select i1 %cmp0, i32 0, i32 73 + %sel1 = mul i32 %mask, 72 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + +define i32 @add_select_cmp_and_mul_mismatch(i32 %in) { +; CHECK-LABEL: @add_select_cmp_and_mul_mismatch( +; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[IN:%.*]] to i1 +; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[TMP1]], i32 73, i32 0 +; CHECK-NEXT: [[MASK1:%.*]] = and i32 [[IN]], 2 +; CHECK-NEXT: [[SEL1:%.*]] = mul nuw nsw i32 [[MASK1]], 72 +; CHECK-NEXT: [[OUT:%.*]] = or disjoint i32 [[SEL0]], [[SEL1]] +; CHECK-NEXT: ret i32 [[OUT]] +; + %mask0 = and i32 %in, 1 + %sel0 = mul i32 %mask0, 73 + %mask1 = and i32 %in, 2 + %sel1 = mul i32 %mask1, 72 + %out = or disjoint i32 %sel0, %sel1 + ret i32 %out +} + ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line: ; CONSTSPLAT: {{.*}} ; CONSTVEC: {{.*}} \ No newline at end of file