From f993b4146066718813b1e09b457e52b0c174a9f2 Mon Sep 17 00:00:00 2001 From: Vasily Nesterov Date: Thu, 9 May 2024 22:39:39 +0000 Subject: [PATCH] feat(Tactic/Linarith): Simplex Algorithm oracle (#12014) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Reduce the `linarith` certificate search problem to some Linear Programming problem and implement the Simplex Algorithm to solve it. - Set the default oracle for `linarith` to this. - Remove unnecessary hypotheses in `Mathlib/Analysis/Calculus/BumpFunction/FiniteDimension.lean` and `Mathlib/Analysis/Distribution/SchwartzSpace.lean` which were needed with the Fourier-Motzkin oracle. - Adjust the definition of `CerficateOracle` to enable dot notation to choose between oracles. This addresses #2717 and #8875 (except when the user overrides the oracle) Also, this oracle is far more efficient: The example below takes lots of time with the Fourier-Motzkin oracle: I waited 5 minutes and still didn't get it. But with the just implemented `Linarith.SimplexAlgo.produceCertificate` oracle the proof succeeds in less than a second. ```lean import Mathlib.Tactic.Linarith example (x0 x1 x2 x3 x4 : ℚ) : 3 * x0 - 15 * x1 - 30 * x2 + 20 * x3 + 12 * x4 ≤ 0 → 35 * x0 - 30 * x1 + 12 * x2 - 15 * x3 + 18 * x4 ≤ 0 → 5 * x0 + 20 * x1 + 24 * x2 + 20 * x3 + 9 * x4 ≤ 0 → -2 * x0 - 30 * x1 + 30 * x2 - 10 * x3 - 12 * x4 ≤ 0 → -4 * x0 - 25 * x1 + 6 * x2 - 20 * x3 ≤ 0 → 25 * x1 + 30 * x2 - 25 * x3 + 12 * x4 ≤ 0 → 10 * x1 - 18 * x2 - 30 * x3 + 18 * x4 ≤ 0 → 4 * x0 + 10 * x1 - 18 * x2 - 15 * x3 + 15 * x4 ≤ 0 → -4 * x0 + 15 * x1 - 30 * x2 + 15 * x3 + 6 * x4 ≤ 0 → -2 * x0 - 5 * x1 + 18 * x2 - 25 * x3 - 161 * x4 ≤ 0 → -6 * x0 + 30 * x1 + 6 * x2 - 15 * x3 ≤ 0 → 3 * x0 + 10 * x1 - 30 * x2 + 25 * x3 + 12 * x4 ≤ 0 → 2 * x0 + 10 * x1 - 24 * x2 - 15 * x3 + 3 * x4 ≤ 0 → 82 * x1 + 36 * x2 + 20 * x3 + 9 * x4 ≤ 0 → 2 * x0 - 30 * x1 - 30 * x2 - 15 * x3 + 6 * x4 ≤ 0 → 4 * x0 - 15 * x1 + 25 * x2 < 0 → -4 * x0 - 10 * x1 + 30 * x2 - 15 * x3 ≤ 0 → 2 * x0 + 6 * x2 + 133 * x3 + 12 * x4 ≤ 0 → 3 * x0 + 15 * x1 - 6 * x2 - 15 * x3 - 15 * x4 ≤ 0 → 10 * x1 + 6 * x2 - 25 * x3 + 3 * x4 ≤ 0 → -2 * x0 + 5 * x1 + 12 * x2 - 20 * x3 + 12 * x4 ≤ 0 → -5 * x0 - 25 * x1 + 30 * x3 - 12 * x4 ≤ 0 → -6 * x0 - 30 * x1 - 36 * x2 + 20 * x3 + 12 * x4 ≤ 0 → 5 * x0 - 5 * x1 + 6 * x2 - 25 * x3 ≤ 0 → -3 * x0 - 20 * x1 - 30 * x2 + 5 * x3 + 3 * x4 ≤ 0 → False := by intros; linarith (config := {oracle := Linarith.FourierMotzkin.produceCertificate}) ``` I am planning to prove the "completeness" of the oracle in the next PRs, but so far I have run a stress test on randomly generated examples of various sizes, and it seems that everything is OK. Co-authored-by: Eric Wieser --- Mathlib.lean | 7 +- .../BumpFunction/FiniteDimension.lean | 1 - .../Analysis/Distribution/SchwartzSpace.lean | 1 - Mathlib/Tactic.lean | 7 +- Mathlib/Tactic/Linarith/Datatypes.lean | 19 +-- .../FourierMotzkin.lean} | 11 +- .../Linarith/Oracle/SimplexAlgorithm.lean | 42 ++++++ .../Linarith/SimplexAlgorithm/Datatypes.lean | 47 +++++++ .../Linarith/SimplexAlgorithm/Gauss.lean | 97 ++++++++++++++ .../SimplexAlgorithm/PositiveVector.lean | 79 ++++++++++++ .../SimplexAlgorithm/SimplexAlgorithm.lean | 120 ++++++++++++++++++ Mathlib/Tactic/Linarith/Verification.lean | 7 +- test/linarith.lean | 108 +++++++++++++++- 13 files changed, 519 insertions(+), 27 deletions(-) rename Mathlib/Tactic/Linarith/{Elimination.lean => Oracle/FourierMotzkin.lean} (97%) create mode 100644 Mathlib/Tactic/Linarith/Oracle/SimplexAlgorithm.lean create mode 100644 Mathlib/Tactic/Linarith/SimplexAlgorithm/Datatypes.lean create mode 100644 Mathlib/Tactic/Linarith/SimplexAlgorithm/Gauss.lean create mode 100644 Mathlib/Tactic/Linarith/SimplexAlgorithm/PositiveVector.lean create mode 100644 Mathlib/Tactic/Linarith/SimplexAlgorithm/SimplexAlgorithm.lean diff --git a/Mathlib.lean b/Mathlib.lean index b4f5569dc57c5..2b7e022237f1d 100644 --- a/Mathlib.lean +++ b/Mathlib.lean @@ -3713,11 +3713,16 @@ import Mathlib.Tactic.Lift import Mathlib.Tactic.LiftLets import Mathlib.Tactic.Linarith import Mathlib.Tactic.Linarith.Datatypes -import Mathlib.Tactic.Linarith.Elimination import Mathlib.Tactic.Linarith.Frontend import Mathlib.Tactic.Linarith.Lemmas +import Mathlib.Tactic.Linarith.Oracle.FourierMotzkin +import Mathlib.Tactic.Linarith.Oracle.SimplexAlgorithm import Mathlib.Tactic.Linarith.Parsing import Mathlib.Tactic.Linarith.Preprocessing +import Mathlib.Tactic.Linarith.SimplexAlgorithm.Datatypes +import Mathlib.Tactic.Linarith.SimplexAlgorithm.Gauss +import Mathlib.Tactic.Linarith.SimplexAlgorithm.PositiveVector +import Mathlib.Tactic.Linarith.SimplexAlgorithm.SimplexAlgorithm import Mathlib.Tactic.Linarith.Verification import Mathlib.Tactic.LinearCombination import Mathlib.Tactic.Lint diff --git a/Mathlib/Analysis/Calculus/BumpFunction/FiniteDimension.lean b/Mathlib/Analysis/Calculus/BumpFunction/FiniteDimension.lean index 7c3adf2ad021b..292514ef65ebc 100644 --- a/Mathlib/Analysis/Calculus/BumpFunction/FiniteDimension.lean +++ b/Mathlib/Analysis/Calculus/BumpFunction/FiniteDimension.lean @@ -556,7 +556,6 @@ instance (priority := 100) {E : Type*} [NormedAddCommGroup E] [NormedSpace ℝ E _ = 1 - (R - 1) / (R + 1) := by field_simp; ring support := fun R hR => by have A : 0 < (R + 1) / 2 := by linarith - have A' : 0 < R + 1 := by linarith have C : (R - 1) / (R + 1) < 1 := by apply (div_lt_one _).2 <;> linarith simp only [hR, if_true, support_comp_inv_smul₀ A.ne', y_support _ (IR R hR) C, _root_.smul_ball A.ne', Real.norm_of_nonneg A.le, smul_zero] diff --git a/Mathlib/Analysis/Distribution/SchwartzSpace.lean b/Mathlib/Analysis/Distribution/SchwartzSpace.lean index 5f251ff7d496f..ebf8c47f4f03f 100644 --- a/Mathlib/Analysis/Distribution/SchwartzSpace.lean +++ b/Mathlib/Analysis/Distribution/SchwartzSpace.lean @@ -693,7 +693,6 @@ integral in terms of suitable seminorms of `f`. -/ lemma pow_mul_le_of_le_of_pow_mul_le {C₁ C₂ : ℝ} {k l : ℕ} {x f : ℝ} (hx : 0 ≤ x) (hf : 0 ≤ f) (h₁ : f ≤ C₁) (h₂ : x ^ (k + l) * f ≤ C₂) : x ^ k * f ≤ 2 ^ l * (C₁ + C₂) * (1 + x) ^ (- (l : ℝ)) := by - have : 0 ≤ C₁ := le_trans (by positivity) h₁ have : 0 ≤ C₂ := le_trans (by positivity) h₂ have : 2 ^ l * (C₁ + C₂) * (1 + x) ^ (- (l : ℝ)) = ((1 + x) / 2) ^ (-(l:ℝ)) * (C₁ + C₂) := by rw [Real.div_rpow (by linarith) zero_le_two] diff --git a/Mathlib/Tactic.lean b/Mathlib/Tactic.lean index eedcd3e9c8fbc..44470b3d232fa 100644 --- a/Mathlib/Tactic.lean +++ b/Mathlib/Tactic.lean @@ -97,11 +97,16 @@ import Mathlib.Tactic.Lift import Mathlib.Tactic.LiftLets import Mathlib.Tactic.Linarith import Mathlib.Tactic.Linarith.Datatypes -import Mathlib.Tactic.Linarith.Elimination import Mathlib.Tactic.Linarith.Frontend import Mathlib.Tactic.Linarith.Lemmas +import Mathlib.Tactic.Linarith.Oracle.FourierMotzkin +import Mathlib.Tactic.Linarith.Oracle.SimplexAlgorithm import Mathlib.Tactic.Linarith.Parsing import Mathlib.Tactic.Linarith.Preprocessing +import Mathlib.Tactic.Linarith.SimplexAlgorithm.Datatypes +import Mathlib.Tactic.Linarith.SimplexAlgorithm.Gauss +import Mathlib.Tactic.Linarith.SimplexAlgorithm.PositiveVector +import Mathlib.Tactic.Linarith.SimplexAlgorithm.SimplexAlgorithm import Mathlib.Tactic.Linarith.Verification import Mathlib.Tactic.LinearCombination import Mathlib.Tactic.Lint diff --git a/Mathlib/Tactic/Linarith/Datatypes.lean b/Mathlib/Tactic/Linarith/Datatypes.lean index 6cc9f4106c55b..6544f88a8d603 100644 --- a/Mathlib/Tactic/Linarith/Datatypes.lean +++ b/Mathlib/Tactic/Linarith/Datatypes.lean @@ -298,18 +298,19 @@ instance GlobalPreprocessorToGlobalBranchingPreprocessor : ⟨GlobalPreprocessor.branching⟩ /-- -A `CertificateOracle` is a function +A `CertificateOracle` provides a function `produceCertificate : List Comp → Nat → MetaM (HashMap Nat Nat)`. -`produceCertificate hyps max_var` tries to derive a contradiction from the comparisons in `hyps` -by eliminating all variables ≤ `max_var`. -If successful, it returns a map `coeff : Nat → Nat` as a certificate. -This map represents that we can find a contradiction by taking the sum `∑ (coeff i) * hyps[i]`. The default `CertificateOracle` used by `linarith` is -`Linarith.FourierMotzkin.produceCertificate`. +`Linarith.CertificateOracle.simplexAlgorithm`. +`Linarith.CertificateOracle.fourierMotzkin` is also available (though has some bugs). -/ -def CertificateOracle : Type := - List Comp → Nat → MetaM (Batteries.HashMap Nat Nat) +structure CertificateOracle : Type where + /-- `produceCertificate hyps max_var` tries to derive a contradiction from the comparisons in + `hyps` by eliminating all variables ≤ `max_var`. + If successful, it returns a map `coeff : Nat → Nat` as a certificate. + This map represents that we can find a contradiction by taking the sum `∑ (coeff i) * hyps[i]`. -/ + produceCertificate (hyps : List Comp) (max_var : Nat) : MetaM (Batteries.HashMap Nat Nat) open Meta @@ -334,7 +335,7 @@ structure LinarithConfig : Type where /-- Override the list of preprocessors. -/ preprocessors : Option (List GlobalBranchingPreprocessor) := none /-- Specify an oracle for identifying candidate contradictions. - The only implementation here is Fourier-Motzkin elimination. -/ + `.simplexAlgorithm` and `.fourierMotzkin` are both available. -/ oracle : Option CertificateOracle := none /-- diff --git a/Mathlib/Tactic/Linarith/Elimination.lean b/Mathlib/Tactic/Linarith/Oracle/FourierMotzkin.lean similarity index 97% rename from Mathlib/Tactic/Linarith/Elimination.lean rename to Mathlib/Tactic/Linarith/Oracle/FourierMotzkin.lean index bece3e64b4a95..90d2224bcd22f 100644 --- a/Mathlib/Tactic/Linarith/Elimination.lean +++ b/Mathlib/Tactic/Linarith/Oracle/FourierMotzkin.lean @@ -324,14 +324,9 @@ those hypotheses. It produces an initial state for the elimination monad. def mkLinarithData (hyps : List Comp) (maxVar : ℕ) : LinarithData := ⟨maxVar, .ofList (hyps.enum.map fun ⟨n, cmp⟩ => PComp.assump cmp n) _⟩ -/-- -`produceCertificate hyps vars` tries to derive a contradiction from the comparisons in `hyps` -by eliminating all variables ≤ `maxVar`. -If successful, it returns a map `coeff : ℕ → ℕ` as a certificate. -This map represents that we can find a contradiction by taking the sum `∑ (coeff i) * hyps[i]`. --/ -def FourierMotzkin.produceCertificate : CertificateOracle := - fun hyps maxVar => match ExceptT.run +/-- An oracle that uses Fourier-Motzkin elimination. -/ +def CertificateOracle.fourierMotzkin : CertificateOracle where + produceCertificate hyps maxVar := match ExceptT.run (StateT.run (do validate; elimAllVarsM : LinarithM Unit) (mkLinarithData hyps maxVar)) with | (Except.ok _) => failure | (Except.error contr) => return contr.src.flatten diff --git a/Mathlib/Tactic/Linarith/Oracle/SimplexAlgorithm.lean b/Mathlib/Tactic/Linarith/Oracle/SimplexAlgorithm.lean new file mode 100644 index 0000000000000..8ad59ccf9a14e --- /dev/null +++ b/Mathlib/Tactic/Linarith/Oracle/SimplexAlgorithm.lean @@ -0,0 +1,42 @@ +/- +Copyright (c) 2024 Vasily Nesterov. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Vasily Nesterov +-/ +import Mathlib.Tactic.Linarith.Datatypes +import Mathlib.Tactic.Linarith.SimplexAlgorithm.PositiveVector + +/-! +# Hooks to enable the use of the simplex algorithm in `linarith` +-/ + +open Batteries + +namespace Linarith.SimplexAlgorithm + +/-- Preprocess the goal to pass it to `findPositiveVector`. -/ +def preprocess (hyps : List Comp) (maxVar : ℕ) : Matrix (maxVar + 1) (hyps.length) × List Nat := + let mdata : Array (Array ℚ) := Array.ofFn fun i : Fin (maxVar + 1) => + Array.mk <| hyps.map (·.coeffOf i) + let strictIndexes : List ℕ := hyps.findIdxs (·.str == Ineq.lt) + ⟨⟨mdata⟩, strictIndexes⟩ + +/-- Extract the certificate from the `vec` found by `findPositiveVector`. -/ +def postprocess (vec : Array ℚ) : HashMap ℕ ℕ := + let common_den : ℕ := vec.foldl (fun acc item => acc.lcm item.den) 1 + let vecNat : Array ℕ := vec.map (fun x : ℚ => (x * common_den).floor.toNat) + HashMap.ofList <| vecNat.toList.enum.filter (fun ⟨_, item⟩ => item != 0) + + +end SimplexAlgorithm + +open SimplexAlgorithm + +/-- An oracle that uses the simplex algorithm. -/ +def CertificateOracle.simplexAlgorithm : CertificateOracle where + produceCertificate hyps maxVar := do + let ⟨A, strictIndexes⟩ := preprocess hyps maxVar + let vec := findPositiveVector A strictIndexes + return postprocess vec + +end Linarith diff --git a/Mathlib/Tactic/Linarith/SimplexAlgorithm/Datatypes.lean b/Mathlib/Tactic/Linarith/SimplexAlgorithm/Datatypes.lean new file mode 100644 index 0000000000000..9c66604324813 --- /dev/null +++ b/Mathlib/Tactic/Linarith/SimplexAlgorithm/Datatypes.lean @@ -0,0 +1,47 @@ +/- +Copyright (c) 2024 Vasily Nesterov. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Vasily Nesterov +-/ +import Batteries.Data.Rat.Basic + +/-! +# Datatypes for Simplex Algorithm implementation +-/ + +namespace Linarith.SimplexAlgorithm + +/-- +Structure for matrices over ℚ. + +So far it is just a 2d-array carrying dimensions (that are supposed to match with the actual +dimensions of `data`), but the plan is to add some `Prop`-data and make the structure strict and +safe. + +Note: we avoid using the `Matrix` from `Mathlib.Data.Matrix` because it is far more efficient to +store matrix as its entries than as function between `Fin`-s. +-/ +structure Matrix (n m : Nat) where + /-- The content of the matrix. -/ + data : Array (Array Rat) + -- hn_pos : n > 0 + -- hm_pos : m > 0 + -- hn : data.size = n + -- hm (i : Fin n) : data[i].size = m + +instance (n m : Nat) : GetElem (Matrix n m) Nat (Array Rat) fun _ i => i < n where + getElem mat i _ := mat.data[i]! + +/-- +`Table` is a structure Simplex Algorithm operates on. The `i`-th row of `mat` expresses the +variable `basic[i]` as a linear combination of variables from `free`. +-/ +structure Table where + /-- Array containing the basic variables' indexes -/ + basic : Array Nat + /-- Array containing the free variables' indexes -/ + free : Array Nat + /-- Matrix of coefficients the basic variables expressed through the free ones. -/ + mat : Matrix basic.size free.size + +end Linarith.SimplexAlgorithm diff --git a/Mathlib/Tactic/Linarith/SimplexAlgorithm/Gauss.lean b/Mathlib/Tactic/Linarith/SimplexAlgorithm/Gauss.lean new file mode 100644 index 0000000000000..46a036b56a2d6 --- /dev/null +++ b/Mathlib/Tactic/Linarith/SimplexAlgorithm/Gauss.lean @@ -0,0 +1,97 @@ +/- +Copyright (c) 2024 Vasily Nesterov. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Vasily Nesterov +-/ +import Mathlib.Tactic.Linarith.SimplexAlgorithm.Datatypes + +/-! +# Gaussian Elimination algorithm + +The first step of `Linarith.SimplexAlgorithm.findPositiveVector` is finding initial feasible +solution which is done by standard Gaussian Elimination algorithm implemented in this file. +-/ + +namespace Linarith.SimplexAlgorithm.Gauss + +/-- The monad for the Gaussian Elimination algorithm. -/ +abbrev GaussM (n m : Nat) := StateM <| Matrix n m + +/-- Finds the first row starting from the current row with nonzero element in current column. -/ +def findNonzeroRow (row col : Nat) {n m : Nat} : GaussM n m <| Option Nat := do + for i in [row:n] do + if (← get)[i]![col]! != 0 then + return i + return .none + +/-- Swaps two rows. -/ +def swapRows {n m : Nat} (i j : Nat) : GaussM n m Unit := do + if i != j then + modify fun mat => + let swapped : Matrix n m := ⟨mat.data.swap! i j⟩ + swapped + +/-- Subtracts `i`-th row * `coef` from `j`-th row. -/ +def subtractRow {n m : Nat} (i j : Nat) (coef : Rat) : GaussM n m Unit := + modify fun mat => + let newData : Array (Array Rat) := mat.data.modify j fun row => + row.zipWith mat[i]! fun x y => x - coef * y + ⟨newData⟩ + +/-- Divides row by `coef`. -/ +def divideRow {n m : Nat} (i : Nat) (coef : Rat) : GaussM n m Unit := + modify fun mat => + let newData : Array (Array Rat) := mat.data.modify i (·.map (· / coef)) + ⟨newData⟩ + +/-- Implementation of `getTable` in `GaussM` monad. -/ +def getTableImp {n m : Nat} : GaussM n m Table := do + let mut free : Array Nat := #[] + let mut basic : Array Nat := #[] + + let mut row : Nat := 0 + let mut col : Nat := 0 + + while row < n && col < m do + match ← findNonzeroRow row col with + | .none => + free := free.push col + col := col + 1 + continue + | .some rowToSwap => + swapRows row rowToSwap + + divideRow row (← get)[row]![col]! + + for i in [:n] do + if i == row then + continue + let coef := (← get)[i]![col]! + subtractRow row i coef + + basic := basic.push col + row := row + 1 + col := col + 1 + + for i in [col:m] do + free := free.push i + + let ansData : Array (Array Rat) := ← do + let mat := (← get) + return Array.ofFn (fun row : Fin row => free.map fun f => -mat[row]![f]!) + + return { + free := free + basic := basic + mat := ⟨ansData⟩ + } + +/-- +Given matrix `A`, solves the linear equation `A x = 0` and returns the solution as a table where +some variables are free and others (basic) variable are expressed as linear combinations of the free +ones. +-/ +def getTable {n m : Nat} (A : Matrix n m) : Table := Id.run do + return (← getTableImp.run A).fst + +end Linarith.SimplexAlgorithm.Gauss diff --git a/Mathlib/Tactic/Linarith/SimplexAlgorithm/PositiveVector.lean b/Mathlib/Tactic/Linarith/SimplexAlgorithm/PositiveVector.lean new file mode 100644 index 0000000000000..c7b511149116b --- /dev/null +++ b/Mathlib/Tactic/Linarith/SimplexAlgorithm/PositiveVector.lean @@ -0,0 +1,79 @@ +/- +Copyright (c) 2024 Vasily Nesterov. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Vasily Nesterov +-/ +import Mathlib.Tactic.Linarith.SimplexAlgorithm.SimplexAlgorithm +import Mathlib.Tactic.Linarith.SimplexAlgorithm.Gauss + +/-! +# `linarith` certificate search a LP problem + +`linarith` certificate search can easily be reduced to this LP problem: given the matrix `A` and the +list `strictIndexes`, find the non-negative vector `v` such that some of its coordinates from +the `strictIndexes` are positive and `A v = 0`. + +The function `findPositiveVector` solves this problem. +-/ + +namespace Linarith.SimplexAlgorithm + +/-- +Given matrix `A` and list `strictIndexes` of strict inequalities' indexes, we want to state the +Linear Programming problem which solution produces solution for the initial problem (see +`findPositiveVector`). + +As an objective function (that we are trying to maximize) we use sum of coordinates from +`strictIndexes`: it suffices to find the non-negative vector that makes this function positive. + +We introduce two auxiliary variables and one constraint: +* The variable `y` is interpreted as "homogenized" `1`. We need it because dealing with a + homogenized problem is easier, but having some "unit" is necessary. +* To bound the problem we add the constraint `x₁ + ... + xₘ + z = y` introducing new variable `z`. + +The objective function also interpreted as an auxiliary variable with constraint +`f = ∑ i ∈ strictIndexes, xᵢ`. + +The variable `f` has to always be basic while `y` has to be free. Our Gauss method implementation +greedy collects basic variables moving from left to right. So we place `f` before `x`-s and `y` +after them. We place `z` between `f` and `x` because in this case `z` will be basic and +`Gauss.getTable` produce table with non-negative last column, meaning that we are starting from +a feasible point. +-/ +def stateLP {n m : Nat} (A : Matrix n m) (strictIndexes : List Nat) : Matrix (n + 2) (m + 3) := + Id.run do + let mut objectiveRow : Array Rat := #[-1, 0] ++ (Array.mkArray m 0) ++ #[0] + for idx in strictIndexes do + objectiveRow := objectiveRow.set! (idx + 2) 1 -- +2 due to shifting by `f` and `z` + + let constraintRow : Array Rat := #[0, 1] ++ (Array.mkArray m 1) ++ #[-1] + + let data : Array (Array Rat) := #[objectiveRow, constraintRow] + ++ A.data.map (#[0, 0] ++ · ++ #[0]) + + return ⟨data⟩ + +/-- Extracts target vector from the table, putting auxilary variables aside (see `stateLP`). -/ +def extractSolution (table : Table) : Array Rat := Id.run do + let mut ans : Array Rat := Array.mkArray (table.basic.size + table.free.size - 3) 0 + for i in [1:table.mat.data.size] do + ans := ans.set! (table.basic[i]! - 2) table.mat.data[i]!.back + return ans + +/-- +Finds nonnegative vector `v`, such that `A v = 0` and some of its coordinates from `strictCoords` +are positive, in the case such `v` exists. +-/ +def findPositiveVector {n m : Nat} (A : Matrix n m) (strictIndexes : List Nat) : Array Rat := + /- State the linear programming problem. -/ + let B := stateLP A strictIndexes + + /- Using Gaussian elimination split variable into free and basic forming the table that will be + operated by Simplex Algorithm. -/ + let initTable := Gauss.getTable B + + /- Run Simplex Algorithm and extract the solution. -/ + let resTable := runSimplexAlgorithm initTable + extractSolution resTable + +end Linarith.SimplexAlgorithm diff --git a/Mathlib/Tactic/Linarith/SimplexAlgorithm/SimplexAlgorithm.lean b/Mathlib/Tactic/Linarith/SimplexAlgorithm/SimplexAlgorithm.lean new file mode 100644 index 0000000000000..4cf28b12331ae --- /dev/null +++ b/Mathlib/Tactic/Linarith/SimplexAlgorithm/SimplexAlgorithm.lean @@ -0,0 +1,120 @@ +/- +Copyright (c) 2024 Vasily Nesterov. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Vasily Nesterov +-/ +import Mathlib.Tactic.Linarith.SimplexAlgorithm.Datatypes + +/-! +# Simplex Algorithm + +To obtain required vector in `Linarith.SimplexAlgorithm.findPositiveVector` we run the Simplex +Algorithm. We use Bland's rule for pivoting, which guarantees that the algorithm terminates. +-/ + +namespace Linarith.SimplexAlgorithm + +/-- An exception in the `SimplexAlgorithmM` monad. -/ +inductive SimplexAlgorithmException + /-- The solution is infeasible. -/ +| infeasible : SimplexAlgorithmException + +/-- The mutable state for the `SimplexAlgorithmM` monad. -/ +structure SimplexAlgorithmState where + /-- Current table. -/ + table : Table + +/-- The monad for the Simplex Algorithm. -/ +abbrev SimplexAlgorithmM := ExceptT SimplexAlgorithmException <| StateM SimplexAlgorithmState + +/-- +Given indexes `exitIdx` and `enterIdx` of exiting and entering variables in the `basic` and `free` +arrays, performs pivot operation, i.e. expresses one through the other and makes the free one basic +and vice versa. +-/ +def doPivotOperation (exitIdx enterIdx : Nat) : SimplexAlgorithmM Unit := do + let mat := (← get).table.mat + let intersectCoef := mat[exitIdx]![enterIdx]! + + let mut newCurRow := mat[exitIdx]! + newCurRow := newCurRow.set! enterIdx (-1) + newCurRow := newCurRow.map (- · / intersectCoef) + let mut newData : Array (Array Rat) := mat.data.map fun row => + let newRow := row.zipWith mat[exitIdx]! fun x y => x - row[enterIdx]! * y / intersectCoef + newRow.set! enterIdx <| row[enterIdx]! / intersectCoef + newData := newData.set! exitIdx newCurRow + + let newBasic : Array Nat := (← get).table.basic.set! exitIdx (← get).table.free[enterIdx]! + let newFree : Array Nat := (← get).table.free.set! enterIdx (← get).table.basic[exitIdx]! + + let newMat : Matrix newBasic.size newFree.size := ⟨newData⟩ + set ({← get with table := ⟨newBasic, newFree, newMat⟩} : SimplexAlgorithmState) + +/-- +Check if the solution is found: the objective function is positive and all basic variables are +nonnegative. +-/ +def checkSuccess : SimplexAlgorithmM Bool := do + return (← get).table.mat[0]!.back > 0 && (← get).table.mat.data.all (fun row => row.back >= 0) + +/-- +Chooses an entering variable: among the variables with a positive coefficient in the objective +function, the one with the smallest index (in the initial indexing). +-/ +def chooseEnteringVar : SimplexAlgorithmM Nat := do + let mut enterIdxOpt : Option Nat := .none -- index of entering variable in the `free` array + let mut minIdx := 0 + for i in [:(← get).table.mat[0]!.size - 1] do + if (← get).table.mat[0]![i]! > 0 && (enterIdxOpt.isNone || (← get).table.free[i]! < minIdx) then + enterIdxOpt := i + minIdx := (← get).table.free[i]! + + /- If there is no such variable the solution does not exist for sure. -/ + match enterIdxOpt with + | .none => throw SimplexAlgorithmException.infeasible + | .some enterIdx => return enterIdx + +/-- +Chooses an exiting variable: the variable imposing the strictest limit on the increase of the +entering variable, breaking ties by choosing the variable with smallest index. +-/ +def chooseExitingVar (enterIdx : Nat) : SimplexAlgorithmM Nat := do + let mut exitIdxOpt : Option Nat := .none -- index of entering variable in the `basic` array + let mut minCoef := 0 + let mut minIdx := 0 + for i in [1:(← get).table.mat.data.size] do + if (← get).table.mat[i]![enterIdx]! >= 0 then + continue + let coef := -(← get).table.mat[i]!.back / (← get).table.mat[i]![enterIdx]! + if exitIdxOpt.isNone || coef < minCoef || + (coef == minCoef && (← get).table.basic[i]! < minIdx) then + exitIdxOpt := i + minCoef := coef + minIdx := (← get).table.basic[i]! + return exitIdxOpt.get! -- such variable always exists because our problem is bounded + +/-- +Chooses entering and exiting variables using Bland's rule that guarantees that the Simplex +Algorithm terminates. +-/ +def choosePivots : SimplexAlgorithmM (Nat × Nat) := do + let enterIdx ← chooseEnteringVar + let exitIdx ← chooseExitingVar enterIdx + return ⟨exitIdx, enterIdx⟩ + +/-- Implementation of `runSimplexAlgorithm` in `SimplexAlgorithmM` monad. -/ +def runSimplexAlgorithmImp : SimplexAlgorithmM Unit := do + while !(← checkSuccess) do + let ⟨exitIdx, enterIdx⟩ ← try + choosePivots + catch | .infeasible => return + doPivotOperation exitIdx enterIdx + +/-- +Runs Simplex Algorithm starting with `initTable`. It always terminates, finding solution if +such exists. Returns the table obtained at the last step. +-/ +def runSimplexAlgorithm (initTable : Table) : Table := Id.run do + return (← runSimplexAlgorithmImp.run ⟨initTable⟩).snd.table + +end Linarith.SimplexAlgorithm diff --git a/Mathlib/Tactic/Linarith/Verification.lean b/Mathlib/Tactic/Linarith/Verification.lean index 9f4091b9c9f41..ac81c2ec63808 100644 --- a/Mathlib/Tactic/Linarith/Verification.lean +++ b/Mathlib/Tactic/Linarith/Verification.lean @@ -4,7 +4,8 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Robert Y. Lewis -/ -import Mathlib.Tactic.Linarith.Elimination +-- import Mathlib.Tactic.Linarith.Oracle.FourierMotzkin +import Mathlib.Tactic.Linarith.Oracle.SimplexAlgorithm import Mathlib.Tactic.Linarith.Parsing import Mathlib.Util.Qq @@ -202,10 +203,10 @@ def proveFalseByLinarith (cfg : LinarithConfig) : MVarId → List Expr → MetaM let (comps, max_var) ← linearFormsAndMaxVar cfg.transparency inputs trace[linarith.detail] "... finished `linearFormsAndMaxVar`." trace[linarith.detail] "{comps}" - let oracle := cfg.oracle.getD FourierMotzkin.produceCertificate + let oracle := cfg.oracle.getD (.simplexAlgorithm) -- perform the elimination and fail if no contradiction is found. let certificate : Batteries.HashMap Nat Nat ← try - oracle comps max_var + oracle.produceCertificate comps max_var catch e => trace[linarith] e.toMessageData throwError "linarith failed to find a contradiction" diff --git a/test/linarith.lean b/test/linarith.lean index 18944c556e196..954126dae6d68 100644 --- a/test/linarith.lean +++ b/test/linarith.lean @@ -1,4 +1,5 @@ import Mathlib.Tactic.Linarith +import Mathlib.Tactic.Linarith.Oracle.FourierMotzkin import Mathlib.Algebra.BigOperators.Basic import Mathlib.Algebra.Order.Ring.Int import Mathlib.Data.Nat.Interval @@ -7,6 +8,7 @@ import Mathlib.Data.Rat.Order private axiom test_sorry : ∀ {α}, α set_option linter.unusedVariables false set_option autoImplicit true +set_option pp.mvars false example [LinearOrderedCommRing α] {a b : α} (h : a < b) (w : b < a) : False := by linarith @@ -431,12 +433,19 @@ lemma norm_nonpos_left (x y : ℚ) (h1 : x * x + y * y ≤ 0) : x = 0 := by variable {E : Type _} [AddGroup E] example (f : ℤ → E) (h : 0 = f 0) : 1 ≤ 2 := by nlinarith -example (a : E) (h : a = a) : 1 ≤ 2 := by nlinarith +example (a : E) (h : a = a) : 1 ≤ 2 := by nlinarith example (p q r s t u v w : ℕ) (h1 : p + u = q + t) (h2 : r + w = s + v) : p * r + q * s + (t * w + u * v) = p * s + q * r + (t * v + u * w) := by nlinarith +-- note: much faster than the simplex algorithm (the default oracle for `linarith`) +-- TODO: make the simplex algorithm able to work with sparse matrices. This should speed up +-- `nlinarith` because it passes large and sparse matrices to the oracle. +example (p q r s t u v w : ℕ) (h1 : p + u = q + t) (h2 : r + w = s + v) : + p * r + q * s + (t * w + u * v) = p * s + q * r + (t * v + u * w) := +by nlinarith (config := { oracle := some .fourierMotzkin }) + -- Tests involving a norm, including that squares in a type where `sq_nonneg` does not apply -- do not cause an exception variable {R : Type _} [Ring R] (abs : R → ℚ) @@ -574,7 +583,7 @@ example (q : Prop) (p : ∀ (x : ℤ), q → 1 = 2) : 1 = 2 := by /-- error: Argument passed to linarith has metavariables: - p ?a + p ?_ -/ #guard_msgs in example (q : Prop) (p : ∀ (x : ℤ), 1 = 2) : 1 = 2 := by @@ -582,7 +591,7 @@ example (q : Prop) (p : ∀ (x : ℤ), 1 = 2) : 1 = 2 := by /-- error: Argument passed to nlinarith has metavariables: - p ?a + p ?_ -/ #guard_msgs in example (q : Prop) (p : ∀ (x : ℤ), 1 = 2) : 1 = 2 := by @@ -601,3 +610,96 @@ example (h : False): True := by example (x : Nat) : 0 ≤ x ^ 9890 := by fail_if_success linarith -- this should not stack overflow apply zero_le + +/-- https://github.com/leanprover-community/mathlib4/issues/8875 -/ +example (a b c d e : ℚ) + (ha : 2 * a + b + c + d + e = 4) + (hb : a + 2 * b + c + d + e = 5) + (hc : a + b + 2 * c + d + e = 6) + (hd : a + b + c + 2 * d + e = 7) + (he : a + b + c + d + 2 * e = 8) : + e = 3 := by + linarith + +/-- https://github.com/leanprover-community/mathlib4/issues/2717 -/ +example : + (3 * x4 - x3 - x2 - x1 : ℚ) < 0 → + x5 - x4 < 0 → + 2 * (x5 - x4) < 0 → + -x6 + x3 < 0 → + -x6 + x2 < 0 → + 2 * (x6 - x5) < 0 → + x8 - x7 < 0 → + -x8 + x2 < 0 → + -x8 + x7 - x5 + x1 < 0 → + x7 - x5 < 0 → + False := by + intros; linarith + +-- TODO: still broken with Fourier-Motzkin +/-- +error: linarith failed to find a contradiction +case h1.h +E : Type _ +inst✝¹ : AddGroup E +R : Type _ +inst✝ : Ring R +abs : R → ℚ +a b c d e : ℚ +ha : 2 * a + b + c + d + e = 4 +hb : a + 2 * b + c + d + e = 5 +hc : a + b + 2 * c + d + e = 6 +hd : a + b + c + 2 * d + e = 7 +he : a + b + c + d + 2 * e = 8 +a✝ : e < 3 +⊢ False +failed +-/ +#guard_msgs in +/-- https://github.com/leanprover-community/mathlib4/issues/8875 -/ +example (a b c d e : ℚ) + (ha : 2 * a + b + c + d + e = 4) + (hb : a + 2 * b + c + d + e = 5) + (hc : a + b + 2 * c + d + e = 6) + (hd : a + b + c + 2 * d + e = 7) + (he : a + b + c + d + 2 * e = 8) : + e = 3 := by + linarith (config := { oracle := some .fourierMotzkin }) + +-- TODO: still broken with Fourier-Motzkin +/-- +error: linarith failed to find a contradiction +E : Type _ +inst✝¹ : AddGroup E +R : Type _ +inst✝ : Ring R +abs : R → ℚ +x4 x3 x2 x1 x5 x6 x8 x7 : ℚ +a✝⁹ : 3 * x4 - x3 - x2 - x1 < 0 +a✝⁸ : x5 - x4 < 0 +a✝⁷ : 2 * (x5 - x4) < 0 +a✝⁶ : -x6 + x3 < 0 +a✝⁵ : -x6 + x2 < 0 +a✝⁴ : 2 * (x6 - x5) < 0 +a✝³ : x8 - x7 < 0 +a✝² : -x8 + x2 < 0 +a✝¹ : -x8 + x7 - x5 + x1 < 0 +a✝ : x7 - x5 < 0 +⊢ False +failed +-/ +#guard_msgs in +/-- https://github.com/leanprover-community/mathlib4/issues/2717 -/ +example : + (3 * x4 - x3 - x2 - x1 : ℚ) < 0 → + x5 - x4 < 0 → + 2 * (x5 - x4) < 0 → + -x6 + x3 < 0 → + -x6 + x2 < 0 → + 2 * (x6 - x5) < 0 → + x8 - x7 < 0 → + -x8 + x2 < 0 → + -x8 + x7 - x5 + x1 < 0 → + x7 - x5 < 0 → False := by + intros + linarith (config := { oracle := some .fourierMotzkin })