Skip to content

Commit 4339407

Browse files
authored
Merge pull request #2176 from Shaikh-Ubaid/unsigned_ints_no_wrap
Unsigned ints no wrap
2 parents ff0985b + 4195caa commit 4339407

File tree

6 files changed

+86
-205
lines changed

6 files changed

+86
-205
lines changed

integration_tests/CMakeLists.txt

+4-2
Original file line numberDiff line numberDiff line change
@@ -570,9 +570,11 @@ RUN(NAME test_unary_op_01 LABELS cpython llvm c) # unary minus
570570
RUN(NAME test_unary_op_02 LABELS cpython llvm c) # unary plus
571571
RUN(NAME test_unary_op_03 LABELS cpython llvm c wasm) # unary bitinvert
572572
RUN(NAME test_unary_op_04 LABELS cpython llvm c) # unary bitinvert
573-
RUN(NAME test_unary_op_05 LABELS cpython llvm c) # unsigned unary minus, plus
573+
# Unsigned unary minus is not supported in CPython
574+
# RUN(NAME test_unary_op_05 LABELS cpython llvm c) # unsigned unary minus, plus
574575
RUN(NAME test_unary_op_06 LABELS cpython llvm c) # unsigned unary bitnot
575-
RUN(NAME test_unsigned_01 LABELS cpython llvm c) # unsigned bitshift left, right
576+
# The value after shift overflows in CPython
577+
# RUN(NAME test_unsigned_01 LABELS cpython llvm c) # unsigned bitshift left, right
576578
RUN(NAME test_unsigned_02 LABELS cpython llvm c)
577579
RUN(NAME test_unsigned_03 LABELS cpython llvm c)
578580
RUN(NAME test_bool_binop LABELS cpython llvm c)

integration_tests/cast_02.py

+33-29
Original file line numberDiff line numberDiff line change
@@ -34,46 +34,50 @@ def test_02():
3434
print(w)
3535
assert w == u32(11)
3636

37-
def test_03():
38-
x : u32 = u32(-10)
39-
print(x)
40-
assert x == u32(4294967286)
37+
# Disable following tests
38+
# Negative numbers in unsigned should throw errors
39+
# TODO: Add these tests as error reference tests
4140

42-
y: u16 = u16(x)
43-
print(y)
44-
assert y == u16(65526)
41+
# def test_03():
42+
# x : u32 = u32(-10)
43+
# print(x)
44+
# assert x == u32(4294967286)
4545

46-
z: u64 = u64(y)
47-
print(z)
48-
assert z == u64(65526)
46+
# y: u16 = u16(x)
47+
# print(y)
48+
# assert y == u16(65526)
4949

50-
w: u8 = u8(z)
51-
print(w)
52-
assert w == u8(246)
50+
# z: u64 = u64(y)
51+
# print(z)
52+
# assert z == u64(65526)
5353

54-
def test_04():
55-
x : u64 = u64(-11)
56-
print(x)
57-
# TODO: We are unable to store the following u64 in AST/R
58-
# assert x == u64(18446744073709551605)
54+
# w: u8 = u8(z)
55+
# print(w)
56+
# assert w == u8(246)
5957

60-
y: u8 = u8(x)
61-
print(y)
62-
assert y == u8(245)
58+
# def test_04():
59+
# x : u64 = u64(-11)
60+
# print(x)
61+
# # TODO: We are unable to store the following u64 in AST/R
62+
# # assert x == u64(18446744073709551605)
6363

64-
z: u16 = u16(y)
65-
print(z)
66-
assert z == u16(245)
64+
# y: u8 = u8(x)
65+
# print(y)
66+
# assert y == u8(245)
6767

68-
w: u32 = u32(z)
69-
print(w)
70-
assert w == u32(245)
68+
# z: u16 = u16(y)
69+
# print(z)
70+
# assert z == u16(245)
71+
72+
# w: u32 = u32(z)
73+
# print(w)
74+
# assert w == u32(245)
7175

7276

7377
def main0():
7478
test_01()
7579
test_02()
76-
test_03()
77-
test_04()
80+
# test_03()
81+
# test_04()
7882

7983
main0()

integration_tests/test_unary_op_04.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
1-
from lpython import u16
1+
from lpython import u16, bitnot_u16
22

33
def foo(grp: u16) -> u16:
4-
i: u16 = ~(u16(grp))
5-
4+
i: u16 = bitnot_u16(grp)
65
return i
76

87

98
def foo2() -> u16:
10-
i: u16 = ~(u16(0xffff))
11-
9+
i: u16 = bitnot_u16(u16(0xffff))
1210
return i
1311

1412
def foo3() -> u16:
15-
i: u16 = ~(u16(0xffff))
16-
17-
return ~i
13+
i: u16 = bitnot_u16(u16(0xffff))
14+
return bitnot_u16(i)
1815

1916
assert foo(u16(0)) == u16(0xffff)
2017
assert foo(u16(0xffff)) == u16(0)

src/libasr/codegen/asr_to_llvm.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -6314,7 +6314,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
63146314
arg_kind != dest_kind )
63156315
{
63166316
if (dest_kind > arg_kind) {
6317-
tmp = builder->CreateSExt(tmp, llvm_utils->getIntType(dest_kind));
6317+
tmp = builder->CreateZExt(tmp, llvm_utils->getIntType(dest_kind));
63186318
} else {
63196319
tmp = builder->CreateTrunc(tmp, llvm_utils->getIntType(dest_kind));
63206320
}

src/lpython/semantics/python_ast_to_asr.cpp

+31-9
Original file line numberDiff line numberDiff line change
@@ -3482,15 +3482,17 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
34823482
tmp = ASR::make_IntegerBitNot_t(al, x.base.base.loc, operand, dest_type, value);
34833483
return;
34843484
} else if (ASRUtils::is_unsigned_integer(*operand_type)) {
3485-
if (ASRUtils::expr_value(operand) != nullptr) {
3486-
int64_t op_value = ASR::down_cast<ASR::UnsignedIntegerConstant_t>(
3487-
ASRUtils::expr_value(operand))->m_n;
3488-
uint64_t val = ~uint64_t(op_value);
3489-
value = ASR::down_cast<ASR::expr_t>(ASR::make_UnsignedIntegerConstant_t(
3490-
al, x.base.base.loc, val, operand_type));
3491-
}
3492-
tmp = ASR::make_UnsignedIntegerBitNot_t(al, x.base.base.loc, operand, dest_type, value);
3493-
return;
3485+
int kind = ASRUtils::extract_kind_from_ttype_t(operand_type);
3486+
int signed_promote_kind = (kind < 8) ? kind * 2 : kind;
3487+
diag.add(diag::Diagnostic(
3488+
"The result of the bitnot ~ operation is negative, thus out of range for u" + std::to_string(kind * 8),
3489+
diag::Level::Error, diag::Stage::Semantic, {
3490+
diag::Label("use ~i" + std::to_string(signed_promote_kind * 8)
3491+
+ "(u) for signed result or bitnot_u" + std::to_string(kind * 8) + "(u) for unsigned result",
3492+
{x.base.base.loc})
3493+
})
3494+
);
3495+
throw SemanticAbort();
34943496
} else if (ASRUtils::is_real(*operand_type)) {
34953497
throw SemanticError("Unary operator '~' not supported for floats",
34963498
x.base.base.loc);
@@ -7471,6 +7473,26 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
74717473
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_Pointer_t(al, x.base.base.loc, type_));
74727474
tmp = ASR::make_GetPointer_t(al, x.base.base.loc, args[0].m_value, type, nullptr);
74737475
return ;
7476+
} else if( call_name.substr(0, 6) == "bitnot" ) {
7477+
parse_args(x, args);
7478+
if (args.size() != 1) {
7479+
throw SemanticError(call_name + "() expects one argument, provided " + std::to_string(args.size()), x.base.base.loc);
7480+
}
7481+
ASR::expr_t* operand = args[0].m_value;
7482+
ASR::ttype_t *operand_type = ASRUtils::expr_type(operand);
7483+
ASR::expr_t* value = nullptr;
7484+
if (!ASR::is_a<ASR::UnsignedInteger_t>(*operand_type)) {
7485+
throw SemanticError(call_name + "() expects unsigned integer, provided" + ASRUtils::type_to_str_python(operand_type), x.base.base.loc);
7486+
}
7487+
if (ASRUtils::expr_value(operand) != nullptr) {
7488+
int64_t op_value = ASR::down_cast<ASR::UnsignedIntegerConstant_t>(
7489+
ASRUtils::expr_value(operand))->m_n;
7490+
uint64_t val = ~uint64_t(op_value);
7491+
value = ASR::down_cast<ASR::expr_t>(ASR::make_UnsignedIntegerConstant_t(
7492+
al, x.base.base.loc, val, operand_type));
7493+
}
7494+
tmp = ASR::make_UnsignedIntegerBitNot_t(al, x.base.base.loc, operand, operand_type, value);
7495+
return;
74747496
} else if( call_name == "array" ) {
74757497
parse_args(x, args);
74767498
if( args.size() != 1 ) {

src/runtime/lpython/lpython.py

+12-156
Original file line numberDiff line numberDiff line change
@@ -14,168 +14,16 @@
1414

1515
# data-types
1616

17-
class UnsignedInteger:
18-
def __init__(self, bit_width, value):
19-
if isinstance(value, UnsignedInteger):
20-
value = value.value
21-
self.bit_width = bit_width
22-
self.value = value % (2**bit_width)
23-
24-
def __bool__(self):
25-
return self.value != 0
26-
27-
def __add__(self, other):
28-
if isinstance(other, self.__class__):
29-
return UnsignedInteger(self.bit_width, (self.value + other.value) % (2**self.bit_width))
30-
else:
31-
raise TypeError("Unsupported operand type")
32-
33-
def __sub__(self, other):
34-
if isinstance(other, self.__class__):
35-
# if self.value < other.value:
36-
# raise ValueError("Result of subtraction cannot be negative")
37-
return UnsignedInteger(self.bit_width, (self.value - other.value) % (2**self.bit_width))
38-
else:
39-
raise TypeError("Unsupported operand type")
40-
41-
def __mul__(self, other):
42-
if isinstance(other, self.__class__):
43-
return UnsignedInteger(self.bit_width, (self.value * other.value) % (2**self.bit_width))
44-
else:
45-
raise TypeError("Unsupported operand type")
46-
47-
def __div__(self, other):
48-
if isinstance(other, self.__class__):
49-
if other.value == 0:
50-
raise ValueError("Division by zero")
51-
return UnsignedInteger(self.bit_width, self.value / other.value)
52-
else:
53-
raise TypeError("Unsupported operand type")
54-
55-
def __floordiv__(self, other):
56-
if isinstance(other, self.__class__):
57-
if other.value == 0:
58-
raise ValueError("Division by zero")
59-
return UnsignedInteger(self.bit_width, self.value // other.value)
60-
else:
61-
raise TypeError("Unsupported operand type")
62-
63-
def __mod__(self, other):
64-
if isinstance(other, self.__class__):
65-
if other.value == 0:
66-
raise ValueError("Modulo by zero")
67-
return UnsignedInteger(self.bit_width, self.value % other.value)
68-
else:
69-
raise TypeError("Unsupported operand type")
70-
71-
def __pow__(self, other):
72-
if isinstance(other, self.__class__):
73-
return UnsignedInteger(self.bit_width, (self.value ** other.value) % (2**self.bit_width))
74-
else:
75-
raise TypeError("Unsupported operand type")
76-
77-
def __and__(self, other):
78-
if isinstance(other, self.__class__):
79-
return UnsignedInteger(self.bit_width, self.value & other.value)
80-
else:
81-
raise TypeError("Unsupported operand type")
82-
83-
def __or__(self, other):
84-
if isinstance(other, self.__class__):
85-
return UnsignedInteger(self.bit_width, self.value | other.value)
86-
else:
87-
raise TypeError("Unsupported operand type")
88-
89-
# unary operators
90-
def __neg__(self):
91-
return UnsignedInteger(self.bit_width, -self.value % (2**self.bit_width))
92-
93-
def __pos__(self):
94-
return UnsignedInteger(self.bit_width, self.value)
95-
96-
def __abs__(self):
97-
return UnsignedInteger(self.bit_width, abs(self.value))
98-
99-
def __invert__(self):
100-
return UnsignedInteger(self.bit_width, ~self.value % (2**self.bit_width))
101-
102-
# comparator operators
103-
def __eq__(self, other):
104-
if isinstance(other, self.__class__):
105-
return self.value == other.value
106-
else:
107-
try:
108-
return self.value == other
109-
except:
110-
raise TypeError("Unsupported operand type")
111-
112-
def __ne__(self, other):
113-
if isinstance(other, self.__class__):
114-
return self.value != other.value
115-
else:
116-
raise TypeError("Unsupported operand type")
117-
118-
def __lt__(self, other):
119-
if isinstance(other, self.__class__):
120-
return self.value < other.value
121-
else:
122-
raise TypeError("Unsupported operand type")
123-
124-
def __le__(self, other):
125-
if isinstance(other, self.__class__):
126-
return self.value <= other.value
127-
else:
128-
raise TypeError("Unsupported operand type")
129-
130-
def __gt__(self, other):
131-
if isinstance(other, self.__class__):
132-
return self.value > other.value
133-
else:
134-
raise TypeError("Unsupported operand type")
135-
136-
def __ge__(self, other):
137-
if isinstance(other, self.__class__):
138-
return self.value >= other.value
139-
else:
140-
raise TypeError("Unsupported operand type")
141-
142-
def __lshift__(self, other):
143-
if isinstance(other, self.__class__):
144-
return UnsignedInteger(self.bit_width, self.value << other.value)
145-
else:
146-
raise TypeError("Unsupported operand type")
147-
148-
def __rshift__(self, other):
149-
if isinstance(other, self.__class__):
150-
return UnsignedInteger(self.bit_width, self.value >> other.value)
151-
else:
152-
raise TypeError("Unsupported operand type")
153-
154-
# conversion to integer
155-
def __int__(self):
156-
return self.value
157-
158-
def __str__(self):
159-
return str(self.value)
160-
161-
def __repr__(self):
162-
return f'UnsignedInteger({self.bit_width}, {str(self)})'
163-
164-
def __index__(self):
165-
return self.value
166-
167-
168-
16917
type_to_convert_func = {
17018
"i1": bool,
17119
"i8": int,
17220
"i16": int,
17321
"i32": int,
17422
"i64": int,
175-
"u8": lambda x: UnsignedInteger(8, x),
176-
"u16": lambda x: UnsignedInteger(16, x),
177-
"u32": lambda x: UnsignedInteger(32, x),
178-
"u64": lambda x: UnsignedInteger(64, x),
23+
"u8": int,
24+
"u16": int,
25+
"u32": int,
26+
"u64": int,
17927
"f32": float,
18028
"f64": float,
18129
"c32": complex,
@@ -859,3 +707,11 @@ def __call__(self, *args, **kwargs):
859707
function = getattr(__import__("lpython_module_" + self.fn_name),
860708
self.fn_name)
861709
return function(*args, **kwargs)
710+
711+
def bitnot(x, bitsize):
712+
return (~x) % (2 ** bitsize)
713+
714+
bitnot_u8 = lambda x: bitnot(x, 8)
715+
bitnot_u16 = lambda x: bitnot(x, 16)
716+
bitnot_u32 = lambda x: bitnot(x, 32)
717+
bitnot_u64 = lambda x: bitnot(x, 64)

0 commit comments

Comments
 (0)