Skip to content

Commit 8ae1cb2

Browse files
ImanHosseiniImanHosseinikuhar
authored
add power function to APInt (#122788)
I am trying to calculate power function for APFloat, APInt to constant fold vector reductions: #122450 I need this utility to fold N `mul`s into power. --------- Co-authored-by: ImanHosseini <imanhosseini.17@gmail.com> Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
1 parent a18f4bd commit 8ae1cb2

File tree

3 files changed

+89
-0
lines changed

3 files changed

+89
-0
lines changed

llvm/include/llvm/ADT/APInt.h

+4
Original file line numberDiff line numberDiff line change
@@ -2263,6 +2263,10 @@ APInt mulhs(const APInt &C1, const APInt &C2);
22632263
/// Returns the high N bits of the multiplication result.
22642264
APInt mulhu(const APInt &C1, const APInt &C2);
22652265

2266+
/// Compute X^N for N>=0.
2267+
/// 0^0 is supported and returns 1.
2268+
APInt pow(const APInt &X, int64_t N);
2269+
22662270
/// Compute GCD of two unsigned APInt values.
22672271
///
22682272
/// This function returns the greatest common divisor of the two APInt values

llvm/lib/Support/APInt.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -3108,3 +3108,21 @@ APInt APIntOps::mulhu(const APInt &C1, const APInt &C2) {
31083108
APInt C2Ext = C2.zext(FullWidth);
31093109
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
31103110
}
3111+
3112+
APInt APIntOps::pow(const APInt &X, int64_t N) {
3113+
assert(N >= 0 && "negative exponents not supported.");
3114+
APInt Acc = APInt(X.getBitWidth(), 1);
3115+
if (N == 0)
3116+
return Acc;
3117+
APInt Base = X;
3118+
int64_t RemainingExponent = N;
3119+
while (RemainingExponent > 0) {
3120+
while (RemainingExponent % 2 == 0) {
3121+
Base *= Base;
3122+
RemainingExponent /= 2;
3123+
}
3124+
--RemainingExponent;
3125+
Acc *= Base;
3126+
}
3127+
return Acc;
3128+
};

llvm/unittests/ADT/APIntTest.cpp

+67
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,73 @@ TEST(APIntTest, ValueInit) {
2929
EXPECT_TRUE(!Zero.sext(64));
3030
}
3131

32+
// Test that 0^5 == 0
33+
TEST(APIntTest, PowZeroTo5) {
34+
APInt Zero = APInt::getZero(32);
35+
EXPECT_TRUE(!Zero);
36+
APInt ZeroTo5 = APIntOps::pow(Zero, 5);
37+
EXPECT_TRUE(!ZeroTo5);
38+
}
39+
40+
// Test that 1^16 == 1
41+
TEST(APIntTest, PowOneTo16) {
42+
APInt One(32, 1);
43+
APInt OneTo16 = APIntOps::pow(One, 16);
44+
EXPECT_EQ(One, OneTo16);
45+
}
46+
47+
// Test that 2^10 == 1024
48+
TEST(APIntTest, PowerTwoTo10) {
49+
APInt Two(32, 2);
50+
APInt TwoTo20 = APIntOps::pow(Two, 10);
51+
APInt V_1024(32, 1024);
52+
EXPECT_EQ(TwoTo20, V_1024);
53+
}
54+
55+
// Test that 3^3 == 27
56+
TEST(APIntTest, PowerThreeTo3) {
57+
APInt Three(32, 3);
58+
APInt ThreeTo3 = APIntOps::pow(Three, 3);
59+
APInt V_27(32, 27);
60+
EXPECT_EQ(ThreeTo3, V_27);
61+
}
62+
63+
// Test that SignedMaxValue^3 == SignedMaxValue
64+
TEST(APIntTest, PowerSignedMaxValue) {
65+
APInt SignedMaxValue = APInt::getSignedMaxValue(32);
66+
APInt MaxTo3 = APIntOps::pow(SignedMaxValue, 3);
67+
EXPECT_EQ(MaxTo3, SignedMaxValue);
68+
}
69+
70+
// Test that MaxValue^3 == MaxValue
71+
TEST(APIntTest, PowerMaxValue) {
72+
APInt MaxValue = APInt::getMaxValue(32);
73+
APInt MaxTo3 = APIntOps::pow(MaxValue, 3);
74+
EXPECT_EQ(MaxValue, MaxTo3);
75+
}
76+
77+
// Test that SignedMinValue^3 == 0
78+
TEST(APIntTest, PowerSignedMinValueTo3) {
79+
APInt SignedMinValue = APInt::getSignedMinValue(32);
80+
APInt MinTo3 = APIntOps::pow(SignedMinValue, 3);
81+
EXPECT_TRUE(MinTo3.isZero());
82+
}
83+
84+
// Test that SignedMinValue^1 == SignedMinValue
85+
TEST(APIntTest, PowerSignedMinValueTo1) {
86+
APInt SignedMinValue = APInt::getSignedMinValue(32);
87+
APInt MinTo1 = APIntOps::pow(SignedMinValue, 1);
88+
EXPECT_EQ(SignedMinValue, MinTo1);
89+
}
90+
91+
// Test that MaxValue^3 == MaxValue
92+
TEST(APIntTest, ZeroToZero) {
93+
APInt Zero = APInt::getZero(32);
94+
APInt One(32, 1);
95+
APInt ZeroToZero = APIntOps::pow(Zero, 0);
96+
EXPECT_EQ(ZeroToZero, One);
97+
}
98+
3299
// Test that APInt shift left works when bitwidth > 64 and shiftamt == 0
33100
TEST(APIntTest, ShiftLeftByZero) {
34101
APInt One = APInt::getZero(65) + 1;

0 commit comments

Comments
 (0)