Skip to content

Commit

Permalink
Merge pull request #2176 from Shaikh-Ubaid/unsigned_ints_no_wrap
Browse files Browse the repository at this point in the history
Unsigned ints no wrap
  • Loading branch information
certik authored Jul 19, 2023
2 parents ff0985b + 4195caa commit 4339407
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 205 deletions.
6 changes: 4 additions & 2 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -570,9 +570,11 @@ RUN(NAME test_unary_op_01 LABELS cpython llvm c) # unary minus
RUN(NAME test_unary_op_02 LABELS cpython llvm c) # unary plus
RUN(NAME test_unary_op_03 LABELS cpython llvm c wasm) # unary bitinvert
RUN(NAME test_unary_op_04 LABELS cpython llvm c) # unary bitinvert
RUN(NAME test_unary_op_05 LABELS cpython llvm c) # unsigned unary minus, plus
# Unsigned unary minus is not supported in CPython
# RUN(NAME test_unary_op_05 LABELS cpython llvm c) # unsigned unary minus, plus
RUN(NAME test_unary_op_06 LABELS cpython llvm c) # unsigned unary bitnot
RUN(NAME test_unsigned_01 LABELS cpython llvm c) # unsigned bitshift left, right
# The value after shift overflows in CPython
# RUN(NAME test_unsigned_01 LABELS cpython llvm c) # unsigned bitshift left, right
RUN(NAME test_unsigned_02 LABELS cpython llvm c)
RUN(NAME test_unsigned_03 LABELS cpython llvm c)
RUN(NAME test_bool_binop LABELS cpython llvm c)
Expand Down
62 changes: 33 additions & 29 deletions integration_tests/cast_02.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,46 +34,50 @@ def test_02():
print(w)
assert w == u32(11)

def test_03():
x : u32 = u32(-10)
print(x)
assert x == u32(4294967286)
# Disable following tests
# Negative numbers in unsigned should throw errors
# TODO: Add these tests as error reference tests

y: u16 = u16(x)
print(y)
assert y == u16(65526)
# def test_03():
# x : u32 = u32(-10)
# print(x)
# assert x == u32(4294967286)

z: u64 = u64(y)
print(z)
assert z == u64(65526)
# y: u16 = u16(x)
# print(y)
# assert y == u16(65526)

w: u8 = u8(z)
print(w)
assert w == u8(246)
# z: u64 = u64(y)
# print(z)
# assert z == u64(65526)

def test_04():
x : u64 = u64(-11)
print(x)
# TODO: We are unable to store the following u64 in AST/R
# assert x == u64(18446744073709551605)
# w: u8 = u8(z)
# print(w)
# assert w == u8(246)

y: u8 = u8(x)
print(y)
assert y == u8(245)
# def test_04():
# x : u64 = u64(-11)
# print(x)
# # TODO: We are unable to store the following u64 in AST/R
# # assert x == u64(18446744073709551605)

z: u16 = u16(y)
print(z)
assert z == u16(245)
# y: u8 = u8(x)
# print(y)
# assert y == u8(245)

w: u32 = u32(z)
print(w)
assert w == u32(245)
# z: u16 = u16(y)
# print(z)
# assert z == u16(245)

# w: u32 = u32(z)
# print(w)
# assert w == u32(245)


def main0():
test_01()
test_02()
test_03()
test_04()
# test_03()
# test_04()

main0()
13 changes: 5 additions & 8 deletions integration_tests/test_unary_op_04.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
from lpython import u16
from lpython import u16, bitnot_u16

def foo(grp: u16) -> u16:
i: u16 = ~(u16(grp))

i: u16 = bitnot_u16(grp)
return i


def foo2() -> u16:
i: u16 = ~(u16(0xffff))

i: u16 = bitnot_u16(u16(0xffff))
return i

def foo3() -> u16:
i: u16 = ~(u16(0xffff))

return ~i
i: u16 = bitnot_u16(u16(0xffff))
return bitnot_u16(i)

assert foo(u16(0)) == u16(0xffff)
assert foo(u16(0xffff)) == u16(0)
Expand Down
2 changes: 1 addition & 1 deletion src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6314,7 +6314,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
arg_kind != dest_kind )
{
if (dest_kind > arg_kind) {
tmp = builder->CreateSExt(tmp, llvm_utils->getIntType(dest_kind));
tmp = builder->CreateZExt(tmp, llvm_utils->getIntType(dest_kind));
} else {
tmp = builder->CreateTrunc(tmp, llvm_utils->getIntType(dest_kind));
}
Expand Down
40 changes: 31 additions & 9 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3482,15 +3482,17 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
tmp = ASR::make_IntegerBitNot_t(al, x.base.base.loc, operand, dest_type, value);
return;
} else if (ASRUtils::is_unsigned_integer(*operand_type)) {
if (ASRUtils::expr_value(operand) != nullptr) {
int64_t op_value = ASR::down_cast<ASR::UnsignedIntegerConstant_t>(
ASRUtils::expr_value(operand))->m_n;
uint64_t val = ~uint64_t(op_value);
value = ASR::down_cast<ASR::expr_t>(ASR::make_UnsignedIntegerConstant_t(
al, x.base.base.loc, val, operand_type));
}
tmp = ASR::make_UnsignedIntegerBitNot_t(al, x.base.base.loc, operand, dest_type, value);
return;
int kind = ASRUtils::extract_kind_from_ttype_t(operand_type);
int signed_promote_kind = (kind < 8) ? kind * 2 : kind;
diag.add(diag::Diagnostic(
"The result of the bitnot ~ operation is negative, thus out of range for u" + std::to_string(kind * 8),
diag::Level::Error, diag::Stage::Semantic, {
diag::Label("use ~i" + std::to_string(signed_promote_kind * 8)
+ "(u) for signed result or bitnot_u" + std::to_string(kind * 8) + "(u) for unsigned result",
{x.base.base.loc})
})
);
throw SemanticAbort();
} else if (ASRUtils::is_real(*operand_type)) {
throw SemanticError("Unary operator '~' not supported for floats",
x.base.base.loc);
Expand Down Expand Up @@ -7471,6 +7473,26 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_Pointer_t(al, x.base.base.loc, type_));
tmp = ASR::make_GetPointer_t(al, x.base.base.loc, args[0].m_value, type, nullptr);
return ;
} else if( call_name.substr(0, 6) == "bitnot" ) {
parse_args(x, args);
if (args.size() != 1) {
throw SemanticError(call_name + "() expects one argument, provided " + std::to_string(args.size()), x.base.base.loc);
}
ASR::expr_t* operand = args[0].m_value;
ASR::ttype_t *operand_type = ASRUtils::expr_type(operand);
ASR::expr_t* value = nullptr;
if (!ASR::is_a<ASR::UnsignedInteger_t>(*operand_type)) {
throw SemanticError(call_name + "() expects unsigned integer, provided" + ASRUtils::type_to_str_python(operand_type), x.base.base.loc);
}
if (ASRUtils::expr_value(operand) != nullptr) {
int64_t op_value = ASR::down_cast<ASR::UnsignedIntegerConstant_t>(
ASRUtils::expr_value(operand))->m_n;
uint64_t val = ~uint64_t(op_value);
value = ASR::down_cast<ASR::expr_t>(ASR::make_UnsignedIntegerConstant_t(
al, x.base.base.loc, val, operand_type));
}
tmp = ASR::make_UnsignedIntegerBitNot_t(al, x.base.base.loc, operand, operand_type, value);
return;
} else if( call_name == "array" ) {
parse_args(x, args);
if( args.size() != 1 ) {
Expand Down
168 changes: 12 additions & 156 deletions src/runtime/lpython/lpython.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,168 +14,16 @@

# data-types

class UnsignedInteger:
def __init__(self, bit_width, value):
if isinstance(value, UnsignedInteger):
value = value.value
self.bit_width = bit_width
self.value = value % (2**bit_width)

def __bool__(self):
return self.value != 0

def __add__(self, other):
if isinstance(other, self.__class__):
return UnsignedInteger(self.bit_width, (self.value + other.value) % (2**self.bit_width))
else:
raise TypeError("Unsupported operand type")

def __sub__(self, other):
if isinstance(other, self.__class__):
# if self.value < other.value:
# raise ValueError("Result of subtraction cannot be negative")
return UnsignedInteger(self.bit_width, (self.value - other.value) % (2**self.bit_width))
else:
raise TypeError("Unsupported operand type")

def __mul__(self, other):
if isinstance(other, self.__class__):
return UnsignedInteger(self.bit_width, (self.value * other.value) % (2**self.bit_width))
else:
raise TypeError("Unsupported operand type")

def __div__(self, other):
if isinstance(other, self.__class__):
if other.value == 0:
raise ValueError("Division by zero")
return UnsignedInteger(self.bit_width, self.value / other.value)
else:
raise TypeError("Unsupported operand type")

def __floordiv__(self, other):
if isinstance(other, self.__class__):
if other.value == 0:
raise ValueError("Division by zero")
return UnsignedInteger(self.bit_width, self.value // other.value)
else:
raise TypeError("Unsupported operand type")

def __mod__(self, other):
if isinstance(other, self.__class__):
if other.value == 0:
raise ValueError("Modulo by zero")
return UnsignedInteger(self.bit_width, self.value % other.value)
else:
raise TypeError("Unsupported operand type")

def __pow__(self, other):
if isinstance(other, self.__class__):
return UnsignedInteger(self.bit_width, (self.value ** other.value) % (2**self.bit_width))
else:
raise TypeError("Unsupported operand type")

def __and__(self, other):
if isinstance(other, self.__class__):
return UnsignedInteger(self.bit_width, self.value & other.value)
else:
raise TypeError("Unsupported operand type")

def __or__(self, other):
if isinstance(other, self.__class__):
return UnsignedInteger(self.bit_width, self.value | other.value)
else:
raise TypeError("Unsupported operand type")

# unary operators
def __neg__(self):
return UnsignedInteger(self.bit_width, -self.value % (2**self.bit_width))

def __pos__(self):
return UnsignedInteger(self.bit_width, self.value)

def __abs__(self):
return UnsignedInteger(self.bit_width, abs(self.value))

def __invert__(self):
return UnsignedInteger(self.bit_width, ~self.value % (2**self.bit_width))

# comparator operators
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.value == other.value
else:
try:
return self.value == other
except:
raise TypeError("Unsupported operand type")

def __ne__(self, other):
if isinstance(other, self.__class__):
return self.value != other.value
else:
raise TypeError("Unsupported operand type")

def __lt__(self, other):
if isinstance(other, self.__class__):
return self.value < other.value
else:
raise TypeError("Unsupported operand type")

def __le__(self, other):
if isinstance(other, self.__class__):
return self.value <= other.value
else:
raise TypeError("Unsupported operand type")

def __gt__(self, other):
if isinstance(other, self.__class__):
return self.value > other.value
else:
raise TypeError("Unsupported operand type")

def __ge__(self, other):
if isinstance(other, self.__class__):
return self.value >= other.value
else:
raise TypeError("Unsupported operand type")

def __lshift__(self, other):
if isinstance(other, self.__class__):
return UnsignedInteger(self.bit_width, self.value << other.value)
else:
raise TypeError("Unsupported operand type")

def __rshift__(self, other):
if isinstance(other, self.__class__):
return UnsignedInteger(self.bit_width, self.value >> other.value)
else:
raise TypeError("Unsupported operand type")

# conversion to integer
def __int__(self):
return self.value

def __str__(self):
return str(self.value)

def __repr__(self):
return f'UnsignedInteger({self.bit_width}, {str(self)})'

def __index__(self):
return self.value



type_to_convert_func = {
"i1": bool,
"i8": int,
"i16": int,
"i32": int,
"i64": int,
"u8": lambda x: UnsignedInteger(8, x),
"u16": lambda x: UnsignedInteger(16, x),
"u32": lambda x: UnsignedInteger(32, x),
"u64": lambda x: UnsignedInteger(64, x),
"u8": int,
"u16": int,
"u32": int,
"u64": int,
"f32": float,
"f64": float,
"c32": complex,
Expand Down Expand Up @@ -859,3 +707,11 @@ def __call__(self, *args, **kwargs):
function = getattr(__import__("lpython_module_" + self.fn_name),
self.fn_name)
return function(*args, **kwargs)

def bitnot(x, bitsize):
return (~x) % (2 ** bitsize)

bitnot_u8 = lambda x: bitnot(x, 8)
bitnot_u16 = lambda x: bitnot(x, 16)
bitnot_u32 = lambda x: bitnot(x, 32)
bitnot_u64 = lambda x: bitnot(x, 64)

0 comments on commit 4339407

Please sign in to comment.