Skip to content

Commit

Permalink
sat: python remove some type assert to improve model building perform…
Browse files Browse the repository at this point in the history
…ance
  • Loading branch information
Mizux committed Sep 18, 2024
1 parent ed49b32 commit 1406518
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 80 deletions.
77 changes: 29 additions & 48 deletions ortools/sat/python/cp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,6 @@ def __sub__(self, arg):
if cmh.is_zero(arg):
return self
if isinstance(arg, NumberTypes):
arg = cmh.assert_is_a_number(arg)
return _Sum(self, -arg)
else:
return _Sum(self, -arg)
Expand Down Expand Up @@ -566,7 +565,6 @@ def __eq__(self, arg: LinearExprT) -> BoundedLinearExprT: # type: ignore[overri
if arg is None:
return False
if isinstance(arg, IntegralTypes):
arg = cmh.assert_is_int64(arg)
return BoundedLinearExpression(self, [arg, arg])
elif isinstance(arg, LinearExpr):
return BoundedLinearExpression(self - arg, [0, 0])
Expand All @@ -575,31 +573,31 @@ def __eq__(self, arg: LinearExprT) -> BoundedLinearExprT: # type: ignore[overri

def __ge__(self, arg: LinearExprT) -> "BoundedLinearExpression":
if isinstance(arg, IntegralTypes):
arg = cmh.assert_is_int64(arg)
if arg >= INT_MAX:
raise ArithmeticError(">= INT_MAX is not supported")
return BoundedLinearExpression(self, [arg, INT_MAX])
else:
return BoundedLinearExpression(self - arg, [0, INT_MAX])

def __le__(self, arg: LinearExprT) -> "BoundedLinearExpression":
if isinstance(arg, IntegralTypes):
arg = cmh.assert_is_int64(arg)
if arg <= INT_MIN:
raise ArithmeticError("<= INT_MIN is not supported")
return BoundedLinearExpression(self, [INT_MIN, arg])
else:
return BoundedLinearExpression(self - arg, [INT_MIN, 0])

def __lt__(self, arg: LinearExprT) -> "BoundedLinearExpression":
if isinstance(arg, IntegralTypes):
arg = cmh.assert_is_int64(arg)
if arg == INT_MIN:
if arg <= INT_MIN:
raise ArithmeticError("< INT_MIN is not supported")
return BoundedLinearExpression(self, [INT_MIN, arg - 1])
else:
return BoundedLinearExpression(self - arg, [INT_MIN, -1])

def __gt__(self, arg: LinearExprT) -> "BoundedLinearExpression":
if isinstance(arg, IntegralTypes):
arg = cmh.assert_is_int64(arg)
if arg == INT_MAX:
if arg >= INT_MAX:
raise ArithmeticError("> INT_MAX is not supported")
return BoundedLinearExpression(self, [arg + 1, INT_MAX])
else:
Expand All @@ -609,10 +607,9 @@ def __ne__(self, arg: LinearExprT) -> BoundedLinearExprT: # type: ignore[overri
if arg is None:
return True
if isinstance(arg, IntegralTypes):
arg = cmh.assert_is_int64(arg)
if arg == INT_MAX:
if arg >= INT_MAX:
return BoundedLinearExpression(self, [INT_MIN, INT_MAX - 1])
elif arg == INT_MIN:
elif arg <= INT_MIN:
return BoundedLinearExpression(self, [INT_MIN + 1, INT_MAX])
else:
return BoundedLinearExpression(
Expand Down Expand Up @@ -702,7 +699,6 @@ class _ProductCst(LinearExpr):
"""Represents the product of a LinearExpr by a constant."""

def __init__(self, expr, coeff) -> None:
coeff = cmh.assert_is_a_number(coeff)
if isinstance(expr, _ProductCst):
self.__expr = expr.expression()
self.__coef = expr.coefficient() * coeff
Expand Down Expand Up @@ -736,7 +732,6 @@ def __init__(self, expressions, constant=0) -> None:
if isinstance(x, NumberTypes):
if cmh.is_zero(x):
continue
x = cmh.assert_is_a_number(x)
self.__constant += x
elif isinstance(x, LinearExpr):
self.__expressions.append(x)
Expand Down Expand Up @@ -776,11 +771,9 @@ def __init__(self, expressions, coefficients, constant=0) -> None:
" coefficient array must have the same length."
)
for e, c in zip(expressions, coefficients):
c = cmh.assert_is_a_number(c)
if cmh.is_zero(c):
continue
if isinstance(e, NumberTypes):
e = cmh.assert_is_a_number(e)
self.__constant += e * c
elif isinstance(e, LinearExpr):
self.__expressions.append(e)
Expand Down Expand Up @@ -1509,9 +1502,8 @@ def add_linear_expression_in_domain(
for t in coeffs_map.items():
if not isinstance(t[0], IntVar):
raise TypeError("Wrong argument" + str(t))
c = cmh.assert_is_int64(t[1])
model_ct.linear.vars.append(t[0].index)
model_ct.linear.coeffs.append(c)
model_ct.linear.coeffs.append(t[1])
model_ct.linear.domain.extend(
[
cmh.capped_subtraction(x, constant)
Expand Down Expand Up @@ -1640,12 +1632,9 @@ def add_circuit(self, arcs: Sequence[ArcT]) -> Constraint:
ct = Constraint(self)
model_ct = self.__model.constraints[ct.index]
for arc in arcs:
tail = cmh.assert_is_int32(arc[0])
head = cmh.assert_is_int32(arc[1])
lit = self.get_or_make_boolean_index(arc[2])
model_ct.circuit.tails.append(tail)
model_ct.circuit.heads.append(head)
model_ct.circuit.literals.append(lit)
model_ct.circuit.tails.append(arc[0])
model_ct.circuit.heads.append(arc[1])
model_ct.circuit.literals.append(self.get_or_make_boolean_index(arc[2]))
return ct

def add_multiple_circuit(self, arcs: Sequence[ArcT]) -> Constraint:
Expand Down Expand Up @@ -1677,12 +1666,9 @@ def add_multiple_circuit(self, arcs: Sequence[ArcT]) -> Constraint:
ct = Constraint(self)
model_ct = self.__model.constraints[ct.index]
for arc in arcs:
tail = cmh.assert_is_int32(arc[0])
head = cmh.assert_is_int32(arc[1])
lit = self.get_or_make_boolean_index(arc[2])
model_ct.routes.tails.append(tail)
model_ct.routes.heads.append(head)
model_ct.routes.literals.append(lit)
model_ct.routes.tails.append(arc[0])
model_ct.routes.heads.append(arc[1])
model_ct.routes.literals.append(self.get_or_make_boolean_index(arc[2]))
return ct

def add_allowed_assignments(
Expand Down Expand Up @@ -1720,15 +1706,19 @@ def add_allowed_assignments(
model_ct = self.__model.constraints[ct.index]
model_ct.table.vars.extend([self.get_or_make_index(x) for x in variables])
arity: int = len(variables)
for t in tuples_list:
if len(t) != arity:
raise TypeError("Tuple " + str(t) + " has the wrong arity")
for one_tuple in tuples_list:
if len(one_tuple) != arity:
raise TypeError("Tuple " + str(one_tuple) + " has the wrong arity")

# duck-typing (no explicit type checks here)
try:
model_ct.table.values.extend(a for b in tuples_list for a in b)
for one_tuple in tuples_list:
model_ct.table.values.extend(one_tuple)
except ValueError as ex:
raise TypeError(f"add_xxx_assignment: Not an integer or does not fit in an int64_t: {ex.args}") from ex
raise TypeError(
"add_xxx_assignment: Not an integer or does not fit in an int64_t:"
f" {ex.args}"
) from ex

return ct

Expand Down Expand Up @@ -1762,7 +1752,7 @@ def add_forbidden_assignments(
"add_forbidden_assignments expects a non-empty variables array"
)

index = len(self.__model.constraints)
index: int = len(self.__model.constraints)
ct: Constraint = self.add_allowed_assignments(variables, tuples_list)
self.__model.constraints[index].table.negated = True
return ct
Expand Down Expand Up @@ -1829,20 +1819,15 @@ def add_automaton(
model_ct.automaton.vars.extend(
[self.get_or_make_index(x) for x in transition_variables]
)
starting_state = cmh.assert_is_int64(starting_state)
model_ct.automaton.starting_state = starting_state
for v in final_states:
v = cmh.assert_is_int64(v)
model_ct.automaton.final_states.append(v)
for t in transition_triples:
if len(t) != 3:
raise TypeError("Tuple " + str(t) + " has the wrong arity (!= 3)")
tail = cmh.assert_is_int64(t[0])
label = cmh.assert_is_int64(t[1])
head = cmh.assert_is_int64(t[2])
model_ct.automaton.transition_tail.append(tail)
model_ct.automaton.transition_label.append(label)
model_ct.automaton.transition_head.append(head)
model_ct.automaton.transition_tail.append(t[0])
model_ct.automaton.transition_label.append(t[1])
model_ct.automaton.transition_head.append(t[2])
return ct

def add_inverse(
Expand Down Expand Up @@ -2358,7 +2343,6 @@ def new_fixed_size_interval_var(
Returns:
An `IntervalVar` object.
"""
size = cmh.assert_is_int64(size)
start_expr = self.parse_linear_expression(start)
size_expr = self.parse_linear_expression(size)
end_expr = self.parse_linear_expression(start + size)
Expand Down Expand Up @@ -2545,7 +2529,6 @@ def new_optional_fixed_size_interval_var(
Returns:
An `IntervalVar` object.
"""
size = cmh.assert_is_int64(size)
start_expr = self.parse_linear_expression(start)
size_expr = self.parse_linear_expression(size)
end_expr = self.parse_linear_expression(start + size)
Expand Down Expand Up @@ -2776,7 +2759,6 @@ def get_or_make_index(self, arg: VariableT) -> int:
):
return -arg.expression().index - 1
if isinstance(arg, IntegralTypes):
arg = cmh.assert_is_int64(arg)
return self.get_or_make_index_from_constant(arg)
raise TypeError("NotSupported: model.get_or_make_index(" + str(arg) + ")")

Expand Down Expand Up @@ -2842,9 +2824,8 @@ def parse_linear_expression(
for t in coeffs_map.items():
if not isinstance(t[0], IntVar):
raise TypeError("Wrong argument" + str(t))
c = cmh.assert_is_int64(t[1])
result.vars.append(t[0].index)
result.coeffs.append(c * mult)
result.coeffs.append(t[1] * mult)
return result

def _set_objective(self, obj: ObjLinearExprT, minimize: bool):
Expand Down
22 changes: 0 additions & 22 deletions ortools/sat/python/cp_model_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,26 +60,6 @@ def is_minus_one(x: Any) -> bool:
return False


def assert_is_int64(x: Any) -> int:
"""Asserts that x is integer and x is in [min_int_64, max_int_64] and returns it casted to an int."""
if not isinstance(x, numbers.Integral):
raise TypeError(f"Not an integer: {x} of type {type(x)}")
x_as_int = int(x)
if x_as_int < INT_MIN or x_as_int > INT_MAX:
raise OverflowError(f"Does not fit in an int64_t: {x}")
return x_as_int


def assert_is_int32(x: Any) -> int:
"""Asserts that x is integer and x is in [min_int_32, max_int_32] and returns it casted to an int."""
if not isinstance(x, numbers.Integral):
raise TypeError(f"Not an integer: {x} of type {type(x)}")
x_as_int = int(x)
if x_as_int < INT32_MIN or x_as_int > INT32_MAX:
raise OverflowError(f"Does not fit in an int32_t: {x}")
return x_as_int


def assert_is_zero_or_one(x: Any) -> int:
"""Asserts that x is 0 or 1 and returns it as an int."""
if not isinstance(x, numbers.Integral):
Expand Down Expand Up @@ -110,8 +90,6 @@ def to_capped_int64(v: int) -> int:

def capped_subtraction(x: int, y: int) -> int:
"""Saturated arithmetics. Returns x - y truncated to the int64_t range."""
assert_is_int64(x)
assert_is_int64(y)
if y == 0:
return x
if x == y:
Expand Down
10 changes: 0 additions & 10 deletions ortools/sat/python/cp_model_helper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,6 @@ def test_is_boolean(self):
self.assertTrue(cp_model_helper.is_boolean(np.bool_(1)))
self.assertTrue(cp_model_helper.is_boolean(np.bool_(0)))

def testassert_is_int64(self):
print("testassert_is_int64")
self.assertRaises(TypeError, cp_model_helper.assert_is_int64, "Hello")
self.assertRaises(TypeError, cp_model_helper.assert_is_int64, 1.2)
self.assertRaises(OverflowError, cp_model_helper.assert_is_int64, 2**63)
self.assertRaises(OverflowError, cp_model_helper.assert_is_int64, -(2**63) - 1)
cp_model_helper.assert_is_int64(123)
cp_model_helper.assert_is_int64(2**63 - 1)
cp_model_helper.assert_is_int64(-(2**63))

def testto_capped_int64(self):
print("testto_capped_int64")
self.assertEqual(
Expand Down

0 comments on commit 1406518

Please sign in to comment.