From f053be97865affaa1cfaa18dfa1c23bb265afac7 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Mon, 23 Sep 2024 15:28:18 +0200 Subject: [PATCH] [CP-SAT] fix presolve bug; fix callback bug --- ortools/sat/BUILD.bazel | 1 + ortools/sat/cp_model_mapping.h | 7 + ortools/sat/cp_model_presolve.cc | 168 +++++++++++++------- ortools/sat/feasibility_jump_test.cc | 1 + ortools/sat/feasibility_pump.cc | 7 +- ortools/sat/presolve_context.cc | 27 +++- ortools/sat/presolve_context.h | 3 +- ortools/sat/python/cp_model_test.py | 225 ++++++++++++++++++++++++++- ortools/sat/swig_helper.cc | 24 ++- ortools/sat/swig_helper.h | 13 +- 10 files changed, 385 insertions(+), 91 deletions(-) diff --git a/ortools/sat/BUILD.bazel b/ortools/sat/BUILD.bazel index 87ac5cd71e2..9991d8998d9 100644 --- a/ortools/sat/BUILD.bazel +++ b/ortools/sat/BUILD.bazel @@ -3004,6 +3004,7 @@ cc_library( ":cp_model_utils", ":model", ":sat_parameters_cc_proto", + ":util", "//ortools/util:logging", "//ortools/util:sorted_interval_list", "//ortools/util:time_limit", diff --git a/ortools/sat/cp_model_mapping.h b/ortools/sat/cp_model_mapping.h index 58a6849f74e..530a0a21b55 100644 --- a/ortools/sat/cp_model_mapping.h +++ b/ortools/sat/cp_model_mapping.h @@ -172,6 +172,13 @@ class CpModelMapping { return reverse_integer_map_[var]; } + // This one should only be used when we have a mapping. + int GetProtoLiteralFromLiteral(sat::Literal lit) const { + const int proto_var = GetProtoVariableFromBooleanVariable(lit.Variable()); + DCHECK_NE(proto_var, -1); + return lit.IsPositive() ? proto_var : NegatedRef(proto_var); + } + const std::vector& GetVariableMapping() const { return integers_; } diff --git a/ortools/sat/cp_model_presolve.cc b/ortools/sat/cp_model_presolve.cc index f7e2c33576d..4b6720e1064 100644 --- a/ortools/sat/cp_model_presolve.cc +++ b/ortools/sat/cp_model_presolve.cc @@ -2508,7 +2508,7 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { context_->UpdateRuleStats("linear1: infeasible"); return MarkConstraintAsFalse(ct); } - if (rhs == context_->DomainOf(var)) { + if (rhs == var_domain) { context_->UpdateRuleStats("linear1: always true"); return RemoveConstraint(ct); } @@ -2544,16 +2544,28 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { } // Detect encoding. + bool changed = false; if (ct->enforcement_literal().size() == 1) { // If we already have an encoding literal, this constraint is really // an implication. - const int lit = ct->enforcement_literal(0); + int lit = ct->enforcement_literal(0); + + // For correctness below, it is important lit is the canonical literal, + // otherwise we might remove the constraint even though it is the one + // defining an encoding literal. + const int representative = context_->GetLiteralRepresentative(lit); + if (lit != representative) { + lit = representative; + ct->set_enforcement_literal(0, lit); + context_->UpdateRuleStats("linear1: remapped enforcement literal"); + changed = true; + } if (rhs.IsFixed()) { const int64_t value = rhs.FixedValue(); int encoding_lit; if (context_->HasVarValueEncoding(var, value, &encoding_lit)) { - if (lit == encoding_lit) return false; + if (lit == encoding_lit) return changed; context_->AddImplication(lit, encoding_lit); context_->UpdateNewConstraintsVariableUsage(); ct->Clear(); @@ -2567,7 +2579,7 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { } context_->UpdateNewConstraintsVariableUsage(); } - return false; + return changed; } const Domain complement = rhs.Complement().IntersectionWith(var_domain); @@ -2575,7 +2587,7 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { const int64_t value = complement.FixedValue(); int encoding_lit; if (context_->HasVarValueEncoding(var, value, &encoding_lit)) { - if (NegatedRef(lit) == encoding_lit) return false; + if (NegatedRef(lit) == encoding_lit) return changed; context_->AddImplication(lit, NegatedRef(encoding_lit)); context_->UpdateNewConstraintsVariableUsage(); ct->Clear(); @@ -2589,11 +2601,11 @@ bool CpModelPresolver::PresolveLinearOfSizeOne(ConstraintProto* ct) { } context_->UpdateNewConstraintsVariableUsage(); } - return false; + return changed; } } - return false; + return changed; } bool CpModelPresolver::PresolveLinearOfSizeTwo(ConstraintProto* ct) { @@ -7110,9 +7122,6 @@ void CpModelPresolver::Probe() { } probing_timer->AddCounter("fixed_bools", num_fixed); - DetectDuplicateConstraintsWithDifferentEnforcements( - mapping, implication_graph, model.GetOrCreate()); - int num_equiv = 0; int num_changed_bounds = 0; const int num_variables = context_->working_model->variables().size(); @@ -7148,6 +7157,12 @@ void CpModelPresolver::Probe() { probing_timer->AddCounter("new_binary_clauses", prober->num_new_binary_clauses()); + // Note that we prefer to run this after we exported all equivalence to the + // context, so that our enforcement list can be presolved to the best of our + // knowledge. + DetectDuplicateConstraintsWithDifferentEnforcements( + mapping, implication_graph, model.GetOrCreate()); + // Stop probing timer now and display info. probing_timer.reset(); @@ -8888,37 +8903,20 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( for (const auto& [dup, rep] : duplicates_without_enforcement) { auto* dup_ct = context_->working_model->mutable_constraints(dup); auto* rep_ct = context_->working_model->mutable_constraints(rep); - if (rep_ct->constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET) { - continue; + + // Make sure our enforcement list are up to date: nothing fixed and that + // its uses the literal representatives. + if (PresolveEnforcementLiteral(dup_ct)) { + context_->UpdateConstraintVariableUsage(dup); + } + if (PresolveEnforcementLiteral(rep_ct)) { + context_->UpdateConstraintVariableUsage(rep); } - // If we have a trail, we can check if any variable of the enforcement is - // fixed to false. This is useful for what follows since calling - // implication_graph->DirectImplications() is invalid for fixed variables. - if (trail != nullptr) { - bool found_false_enforcement = false; - for (const int c : {dup, rep}) { - for (const int l : - context_->working_model->constraints(c).enforcement_literal()) { - if (trail->Assignment().LiteralIsFalse(mapping->Literal(l))) { - found_false_enforcement = true; - break; - } - } - if (found_false_enforcement) { - context_->UpdateRuleStats("enforcement: false literal"); - if (c == rep) { - rep_ct->Swap(dup_ct); - context_->UpdateConstraintVariableUsage(rep); - } - dup_ct->Clear(); - context_->UpdateConstraintVariableUsage(dup); - break; - } - } - if (found_false_enforcement) { - continue; - } + // Skip this pair if one of the constraint was simplified + if (rep_ct->constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET || + dup_ct->constraint_case() == ConstraintProto::CONSTRAINT_NOT_SET) { + continue; } // If one of them has no enforcement, then the other can be ignored. @@ -8936,10 +8934,7 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( // Special case. This looks specific but users might reify with a cost // a duplicate constraint. In this case, no need to have two variables, // we can make them equal by duality argument. - const int a = rep_ct->enforcement_literal(0); - const int b = dup_ct->enforcement_literal(0); - if (context_->IsFixed(a) || context_->IsFixed(b)) continue; - + // // TODO(user): Deal with more general situation? Note that we already // do something similar in dual_bound_strengthening.Strengthen() were we // are more general as we just require an unique blocking constraint rather @@ -8949,6 +8944,8 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( // we can also add the equality. Alternatively, we can just introduce a new // variable and merge all duplicate constraint into 1 + bunch of boolean // constraints liking enforcements. + const int a = rep_ct->enforcement_literal(0); + const int b = dup_ct->enforcement_literal(0); if (context_->VariableWithCostIsUniqueAndRemovable(a) && context_->VariableWithCostIsUniqueAndRemovable(b)) { // Both these case should be presolved before, but it is easy to deal with @@ -9007,19 +9004,19 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( // B, then constraint A is redundant and we can remove it. const int c_a = i == 0 ? dup : rep; const int c_b = i == 0 ? rep : dup; + const auto& ct_a = context_->working_model->constraints(c_a); + const auto& ct_b = context_->working_model->constraints(c_b); enforcement_vars.clear(); implications_used.clear(); - for (const int proto_lit : - context_->working_model->constraints(c_b).enforcement_literal()) { + for (const int proto_lit : ct_b.enforcement_literal()) { const Literal lit = mapping->Literal(proto_lit); - if (trail->Assignment().LiteralIsTrue(lit)) continue; + DCHECK(!trail->Assignment().LiteralIsAssigned(lit)); enforcement_vars.insert(lit); } - for (const int proto_lit : - context_->working_model->constraints(c_a).enforcement_literal()) { + for (const int proto_lit : ct_a.enforcement_literal()) { const Literal lit = mapping->Literal(proto_lit); - if (trail->Assignment().LiteralIsTrue(lit)) continue; + DCHECK(!trail->Assignment().LiteralIsAssigned(lit)); for (const Literal implication_lit : implication_graph->DirectImplications(lit)) { auto extracted = enforcement_vars.extract(implication_lit); @@ -9029,6 +9026,71 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( } } if (enforcement_vars.empty()) { + // Tricky: Because we keep track of literal <=> var == value, we + // cannot easily simplify linear1 here. This is because a scenario + // like this can happen: + // + // We have registered the fact that a <=> X=1 because we saw two + // constraints a => X=1 and not(a) => X!= 1 + // + // Now, we are here and we have: + // a => X=1, b => X=1, a => b + // So we rewrite this as + // a => b, b => X=1 + // + // But later, the PresolveLinearOfSizeOne() see + // b => X=1 and just rewrite this as b => a since (a <=> X=1). + // This is wrong because the constraint "b => X=1" is needed for the + // equivalence (a <=> X=1), but we lost that fact. + // + // Note(user): In the scenario above we can see that a <=> b, and if + // we know that fact, then the transformation is correctly handled. + // The bug was triggered when the Probing finished early due to time + // limit and we never detected that equivalence. + // + // TODO(user): Try to find a cleaner way to handle this. We could + // query our HasVarValueEncoding() directly here and directly detect a + // <=> b. However we also need to figure the case of + // half-implications. + { + if (ct_a.constraint_case() == ConstraintProto::kLinear && + ct_a.linear().vars().size() == 1 && + ct_a.enforcement_literal().size() == 1) { + const int var = ct_a.linear().vars(0); + const Domain var_domain = context_->DomainOf(var); + const Domain rhs = + ReadDomainFromProto(ct_a.linear()) + .InverseMultiplicationBy(ct_a.linear().coeffs(0)) + .IntersectionWith(var_domain); + + // IsFixed() do not work on empty domain. + if (rhs.IsEmpty()) { + context_->UpdateRuleStats("duplicate: linear1 infeasible"); + if (!MarkConstraintAsFalse(rep_ct)) return; + if (!MarkConstraintAsFalse(dup_ct)) return; + context_->UpdateConstraintVariableUsage(rep); + context_->UpdateConstraintVariableUsage(dup); + continue; + } + if (rhs == var_domain) { + context_->UpdateRuleStats("duplicate: linear1 always true"); + rep_ct->Clear(); + dup_ct->Clear(); + context_->UpdateConstraintVariableUsage(rep); + context_->UpdateConstraintVariableUsage(dup); + continue; + } + + // We skip if it is a var == value or var != value constraint. + if (rhs.IsFixed() || + rhs.Complement().IntersectionWith(var_domain).IsFixed()) { + context_->UpdateRuleStats( + "TODO duplicate: skipped identical encoding constraints"); + continue; + } + } + } + context_->UpdateRuleStats( "duplicate: identical constraint with implied enforcements"); if (c_a == rep) { @@ -9043,12 +9105,8 @@ void CpModelPresolver::DetectDuplicateConstraintsWithDifferentEnforcements( // graph. This is because in some case the implications are only true // in the presence of the "duplicated" constraints. for (const auto& [a, b] : implications_used) { - const int var_a = - mapping->GetProtoVariableFromBooleanVariable(a.Variable()); - const int proto_lit_a = a.IsPositive() ? var_a : NegatedRef(var_a); - const int var_b = - mapping->GetProtoVariableFromBooleanVariable(b.Variable()); - const int proto_lit_b = b.IsPositive() ? var_b : NegatedRef(var_b); + const int proto_lit_a = mapping->GetProtoLiteralFromLiteral(a); + const int proto_lit_b = mapping->GetProtoLiteralFromLiteral(b); context_->AddImplication(proto_lit_a, proto_lit_b); } context_->UpdateNewConstraintsVariableUsage(); diff --git a/ortools/sat/feasibility_jump_test.cc b/ortools/sat/feasibility_jump_test.cc index c7934af218b..0e848d03ff7 100644 --- a/ortools/sat/feasibility_jump_test.cc +++ b/ortools/sat/feasibility_jump_test.cc @@ -13,6 +13,7 @@ #include "ortools/sat/feasibility_jump.h" +#include #include #include "gtest/gtest.h" diff --git a/ortools/sat/feasibility_pump.cc b/ortools/sat/feasibility_pump.cc index e1577030490..5ff2c7dd856 100644 --- a/ortools/sat/feasibility_pump.cc +++ b/ortools/sat/feasibility_pump.cc @@ -40,6 +40,7 @@ #include "ortools/sat/sat_parameters.pb.h" #include "ortools/sat/sat_solver.h" #include "ortools/sat/synchronization.h" +#include "ortools/sat/util.h" #include "ortools/util/saturated_arithmetic.h" #include "ortools/util/sorted_interval_list.h" #include "ortools/util/strong_integers.h" @@ -610,11 +611,11 @@ bool FeasibilityPump::PropagationRounding() { } const int64_t rounded_value = - static_cast(std::round(lp_solution_[var_index])); + SafeDoubleToInt64(std::round(lp_solution_[var_index])); const int64_t floor_value = - static_cast(std::floor(lp_solution_[var_index])); + SafeDoubleToInt64(std::floor(lp_solution_[var_index])); const int64_t ceil_value = - static_cast(std::ceil(lp_solution_[var_index])); + SafeDoubleToInt64(std::ceil(lp_solution_[var_index])); const bool floor_is_in_domain = (domain.Contains(floor_value) && lb.value() <= floor_value); diff --git a/ortools/sat/presolve_context.cc b/ortools/sat/presolve_context.cc index 32de1f0cf72..a5dab8dd6b1 100644 --- a/ortools/sat/presolve_context.cc +++ b/ortools/sat/presolve_context.cc @@ -1371,8 +1371,9 @@ void PresolveContext::CanonicalizeDomainOfSizeTwo(int var) { max_literal = max_it->second.Get(this); if (min_literal != NegatedRef(max_literal)) { UpdateRuleStats("variables with 2 values: merge encoding literals"); - StoreBooleanEqualityRelation(min_literal, NegatedRef(max_literal)); - if (is_unsat_) return; + if (!StoreBooleanEqualityRelation(min_literal, NegatedRef(max_literal))) { + return; + } } min_literal = GetLiteralRepresentative(min_literal); max_literal = GetLiteralRepresentative(max_literal); @@ -1419,7 +1420,7 @@ void PresolveContext::CanonicalizeDomainOfSizeTwo(int var) { } } -void PresolveContext::InsertVarValueEncodingInternal(int literal, int var, +bool PresolveContext::InsertVarValueEncodingInternal(int literal, int var, int64_t value, bool add_constraints) { DCHECK(RefIsPositive(var)); @@ -1446,10 +1447,12 @@ void PresolveContext::InsertVarValueEncodingInternal(int literal, int var, if (literal != previous_literal) { UpdateRuleStats( "variables: merge equivalent var value encoding literals"); - StoreBooleanEqualityRelation(literal, previous_literal); + if (!StoreBooleanEqualityRelation(literal, previous_literal)) { + return false; + } } } - return; + return true; } if (DomainOf(var).Size() == 2) { @@ -1461,6 +1464,9 @@ void PresolveContext::InsertVarValueEncodingInternal(int literal, int var, AddImplyInDomain(literal, var, Domain(value)); AddImplyInDomain(NegatedRef(literal), var, Domain(value).Complement()); } + + // The canonicalization might have proven UNSAT. + return !ModelIsUnsat(); } bool PresolveContext::InsertHalfVarValueEncoding(int literal, int var, @@ -1484,8 +1490,10 @@ bool PresolveContext::InsertHalfVarValueEncoding(int literal, int var, if (other_set.contains({NegatedRef(literal), var, value})) { UpdateRuleStats("variables: detect fully reified value encoding"); const int imply_eq_literal = imply_eq ? literal : NegatedRef(literal); - InsertVarValueEncodingInternal(imply_eq_literal, var, value, - /*add_constraints=*/false); + if (!InsertVarValueEncodingInternal(imply_eq_literal, var, value, + /*add_constraints=*/false)) { + return false; + } } return true; @@ -1505,7 +1513,10 @@ bool PresolveContext::InsertVarValueEncoding(int literal, int var, return SetLiteralToFalse(literal); } literal = GetLiteralRepresentative(literal); - InsertVarValueEncodingInternal(literal, var, value, /*add_constraints=*/true); + if (!InsertVarValueEncodingInternal(literal, var, value, + /*add_constraints=*/true)) { + return false; + } eq_half_encoding_.insert({literal, var, value}); neq_half_encoding_.insert({NegatedRef(literal), var, value}); diff --git a/ortools/sat/presolve_context.h b/ortools/sat/presolve_context.h index 47cc332198d..faa7a398007 100644 --- a/ortools/sat/presolve_context.h +++ b/ortools/sat/presolve_context.h @@ -664,7 +664,8 @@ class PresolveContext { bool imply_eq); // Insert fully reified var-value encoding. - void InsertVarValueEncodingInternal(int literal, int var, int64_t value, + // Returns false if this make the problem infeasible. + bool InsertVarValueEncodingInternal(int literal, int var, int64_t value, bool add_constraints); SolverLogger* logger_; diff --git a/ortools/sat/python/cp_model_test.py b/ortools/sat/python/cp_model_test.py index d3fb0878a7a..c4f864cf174 100644 --- a/ortools/sat/python/cp_model_test.py +++ b/ortools/sat/python/cp_model_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for ortools.sat.python.cp_model.""" +import itertools from absl.testing import absltest import pandas as pd @@ -95,6 +95,20 @@ def bool_var_values(self): return self.__bool_var_values +class TimeRecorder(cp_model.CpSolverSolutionCallback): + + def __init__(self, default_time: float) -> None: + super().__init__() + self.__last_time = default_time + + def on_solution_callback(self) -> None: + self.__last_time = self.wall_time + + @property + def last_time(self): + return self.__last_time + + class LogToString: """Record log in a string.""" @@ -1649,6 +1663,215 @@ def testIntervalVarSeries(self): ) self.assertLen(model.proto.constraints, 13) + def testIssue4376SatModel(self): + print("testIssue4376SatModel") + letters: str = "BCFLMRT" + + def symbols_from_string(text: str) -> list[int]: + return [letters.index(char) for char in text] + + def rotate_symbols(symbols: list[int], turns: int) -> list[int]: + return symbols[turns:] + symbols[:turns] + + data = """FMRC +FTLB +MCBR +FRTM +FBTM +BRFM +BTRM +BCRM +RTCF +TFRC +CTRM +CBTM +TFBM +TCBM +CFTM +BLTR +RLFM +CFLM +CRML +FCLR +FBTR +TBRF +RBCF +RBCT +BCTF +TFCR +CBRT +FCBT +FRTB +RBCM +MTFC +MFTC +MBFC +RTBM +RBFM +TRFM""" + + tiles = [symbols_from_string(line) for line in data.splitlines()] + + model = cp_model.CpModel() + + # choices[i, x, y, r] is true iff we put tile i in cell (x,y) with + # rotation r. + choices = {} + for i in range(len(tiles)): + for x in range(6): + for y in range(6): + for r in range(4): + choices[(i, x, y, r)] = model.new_bool_var( + f"tile_{i}_{x}_{y}_{r}" + ) + + # corners[x, y, s] is true iff the corner at (x,y) contains symbol s. + corners = {} + for x in range(7): + for y in range(7): + for s in range(7): + corners[(x, y, s)] = model.new_bool_var(f"corner_{x}_{y}_{s}") + + # Placing a tile puts a symbol in each corner. + for (i, x, y, r), choice in choices.items(): + symbols = rotate_symbols(tiles[i], r) + model.add_implication(choice, corners[x, y, symbols[0]]) + model.add_implication(choice, corners[x, y + 1, symbols[1]]) + model.add_implication(choice, corners[x + 1, y + 1, symbols[2]]) + model.add_implication(choice, corners[x + 1, y, symbols[3]]) + + # We must make exactly one choice for each tile. + for i in range(len(tiles)): + tmp_literals = [] + for x in range(6): + for y in range(6): + for r in range(4): + tmp_literals.append(choices[(i, x, y, r)]) + model.add_exactly_one(tmp_literals) + + # We must make exactly one choice for each square. + for x, y in itertools.product(range(6), range(6)): + tmp_literals = [] + for i in range(len(tiles)): + for r in range(4): + tmp_literals.append(choices[(i, x, y, r)]) + model.add_exactly_one(tmp_literals) + + # Each corner contains exactly one symbol. + for x, y in itertools.product(range(7), range(7)): + model.add_exactly_one(corners[x, y, s] for s in range(7)) + + # Solve. + solver = cp_model.CpSolver() + solver.parameters.num_workers = 8 + solver.parameters.max_time_in_seconds = 20 + solver.parameters.log_search_progress = True + solver.parameters.cp_model_presolve = False + solver.parameters.symmetry_level = 0 + + callback = TimeRecorder(solver.parameters.max_time_in_seconds) + solver.Solve(model, callback) + self.assertLess(solver.wall_time, callback.last_time + 5.0) + + def testIssue4376MinimizeModel(self): + print("testIssue4376MinimizeModel") + + model = cp_model.CpModel() + + jobs = [ + [3, 3], # [duration, width] + [2, 5], + [1, 3], + [3, 7], + [7, 3], + [2, 2], + [2, 2], + [5, 5], + [10, 2], + [4, 3], + [2, 6], + [1, 2], + [6, 8], + [4, 5], + [3, 7], + ] + + max_width = 10 + + horizon = sum(t[0] for t in jobs) + num_jobs = len(jobs) + all_jobs = range(num_jobs) + + intervals = [] + intervals0 = [] + intervals1 = [] + performed = [] + starts = [] + ends = [] + demands = [] + + for i in all_jobs: + # Create main interval. + start = model.new_int_var(0, horizon, f"start_{i}") + duration = jobs[i][0] + end = model.new_int_var(0, horizon, f"end_{i}") + interval = model.new_interval_var(start, duration, end, f"interval_{i}") + starts.append(start) + intervals.append(interval) + ends.append(end) + demands.append(jobs[i][1]) + + # Create an optional copy of interval to be executed on machine 0. + performed_on_m0 = model.new_bool_var(f"perform_{i}_on_m0") + performed.append(performed_on_m0) + start0 = model.new_int_var(0, horizon, f"start_{i}_on_m0") + end0 = model.new_int_var(0, horizon, f"end_{i}_on_m0") + interval0 = model.new_optional_interval_var( + start0, duration, end0, performed_on_m0, f"interval_{i}_on_m0" + ) + intervals0.append(interval0) + + # Create an optional copy of interval to be executed on machine 1. + start1 = model.new_int_var(0, horizon, f"start_{i}_on_m1") + end1 = model.new_int_var(0, horizon, f"end_{i}_on_m1") + interval1 = model.new_optional_interval_var( + start1, + duration, + end1, + ~performed_on_m0, + f"interval_{i}_on_m1", + ) + intervals1.append(interval1) + + # We only propagate the constraint if the tasks is performed on the + # machine. + model.add(start0 == start).only_enforce_if(performed_on_m0) + model.add(start1 == start).only_enforce_if(~performed_on_m0) + + # Width constraint (modeled as a cumulative) + model.add_cumulative(intervals, demands, max_width) + + # Choose which machine to perform the jobs on. + model.add_no_overlap(intervals0) + model.add_no_overlap(intervals1) + + # Objective variable. + makespan = model.new_int_var(0, horizon, "makespan") + model.add_max_equality(makespan, ends) + model.minimize(makespan) + + # Symmetry breaking. + model.add(performed[0] == 0) + + # Solve. + solver = cp_model.CpSolver() + solver.parameters.num_workers = 8 + solver.parameters.max_time_in_seconds = 50 + solver.parameters.log_search_progress = True + callback = TimeRecorder(solver.parameters.max_time_in_seconds) + solver.Solve(model, callback) + self.assertLess(solver.wall_time, callback.last_time + 5.0) + if __name__ == "__main__": absltest.main() diff --git a/ortools/sat/swig_helper.cc b/ortools/sat/swig_helper.cc index b03de25b90c..0d9b045c4e6 100644 --- a/ortools/sat/swig_helper.cc +++ b/ortools/sat/swig_helper.cc @@ -15,7 +15,6 @@ #include -#include #include #include @@ -27,9 +26,9 @@ #include "ortools/sat/cp_model_utils.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_parameters.pb.h" +#include "ortools/sat/util.h" #include "ortools/util/logging.h" #include "ortools/util/sorted_interval_list.h" -#include "ortools/util/time_limit.h" namespace operations_research { namespace sat { @@ -90,18 +89,15 @@ bool SolutionCallback::SolutionBooleanValue(int index) { } void SolutionCallback::StopSearch() { - if (stopped_ptr_ != nullptr) { - (*stopped_ptr_) = true; - } + if (wrapper_ != nullptr) wrapper_->StopSearch(); } operations_research::sat::CpSolverResponse SolutionCallback::Response() const { return response_; } -void SolutionCallback::SetAtomicBooleanToStopTheSearch( - std::atomic* stopped_ptr) const { - stopped_ptr_ = stopped_ptr; +void SolutionCallback::SetWrapperClass(SolveWrapper* wrapper) const { + wrapper_ = wrapper; } bool SolutionCallback::HasResponse() const { return has_response_; } @@ -116,15 +112,13 @@ void SolveWrapper::SetStringParameters(const std::string& string_parameters) { } void SolveWrapper::AddSolutionCallback(const SolutionCallback& callback) { - // Overwrite the atomic bool. - callback.SetAtomicBooleanToStopTheSearch(&stopped_); + callback.SetWrapperClass(this); model_.Add(NewFeasibleSolutionObserver( [&callback](const CpSolverResponse& r) { return callback.Run(r); })); } void SolveWrapper::ClearSolutionCallback(const SolutionCallback& callback) { - // cleanup the atomic bool. - callback.SetAtomicBooleanToStopTheSearch(nullptr); + callback.SetWrapperClass(nullptr); // Detach the wrapper class. } void SolveWrapper::AddLogCallback( @@ -157,11 +151,13 @@ void SolveWrapper::AddBestBoundCallbackFromClass(BestBoundCallback* callback) { operations_research::sat::CpSolverResponse SolveWrapper::Solve( const operations_research::sat::CpModelProto& model_proto) { FixFlagsAndEnvironmentForSwig(); - model_.GetOrCreate()->RegisterExternalBooleanAsLimit(&stopped_); return operations_research::sat::SolveCpModel(model_proto, &model_); } -void SolveWrapper::StopSearch() { stopped_ = true; } +void SolveWrapper::StopSearch() { + model_.GetOrCreate()->Stop(); +} + std::string CpSatHelper::ModelStats( const operations_research::sat::CpModelProto& model_proto) { return CpModelStats(model_proto); diff --git a/ortools/sat/swig_helper.h b/ortools/sat/swig_helper.h index e9821b620d1..3a9cfeec691 100644 --- a/ortools/sat/swig_helper.h +++ b/ortools/sat/swig_helper.h @@ -14,24 +14,20 @@ #ifndef OR_TOOLS_SAT_SWIG_HELPER_H_ #define OR_TOOLS_SAT_SWIG_HELPER_H_ -#include #include #include #include #include "ortools/sat/cp_model.pb.h" -#include "ortools/sat/cp_model_checker.h" -#include "ortools/sat/cp_model_solver.h" -#include "ortools/sat/cp_model_utils.h" #include "ortools/sat/model.h" #include "ortools/sat/sat_parameters.pb.h" -#include "ortools/util/logging.h" #include "ortools/util/sorted_interval_list.h" -#include "ortools/util/time_limit.h" namespace operations_research { namespace sat { +class SolveWrapper; + // Base class for SWIG director based on solution callbacks. // See http://www.swig.org/Doc4.0/SWIGDocumentation.html#CSharp_directors. class SolutionCallback { @@ -72,14 +68,14 @@ class SolutionCallback { operations_research::sat::CpSolverResponse Response() const; // We use mutable and non const methods to overcome SWIG difficulties. - void SetAtomicBooleanToStopTheSearch(std::atomic* stopped_ptr) const; + void SetWrapperClass(SolveWrapper* wrapper) const; bool HasResponse() const; private: mutable CpSolverResponse response_; mutable bool has_response_ = false; - mutable std::atomic* stopped_ptr_; + mutable SolveWrapper* wrapper_ = nullptr; }; // Simple director class for C#. @@ -126,7 +122,6 @@ class SolveWrapper { private: Model model_; - std::atomic stopped_ = false; }; // Static methods are stored in a module which name can vary.