Skip to content

Commit

Permalink
feat(Tactic/Linarith): Simplex Algorithm oracle (#12014)
Browse files Browse the repository at this point in the history
- Reduce the `linarith` certificate search problem to some Linear Programming problem and implement the Simplex Algorithm to solve it. 
- Set the default oracle for `linarith` to this.
- Remove unnecessary hypotheses in `Mathlib/Analysis/Calculus/BumpFunction/FiniteDimension.lean` and `Mathlib/Analysis/Distribution/SchwartzSpace.lean` which were needed with the Fourier-Motzkin oracle.
- Adjust the definition of `CerficateOracle` to enable dot notation to choose between oracles.

This addresses #2717 and #8875 (except when the user overrides the oracle)

Also, this oracle is far more efficient:
The example below takes lots of time with the Fourier-Motzkin oracle: I waited 5 minutes and still didn't get it. But with the just implemented `Linarith.SimplexAlgo.produceCertificate` oracle the proof succeeds in less than a second.
```lean
import Mathlib.Tactic.Linarith

example (x0 x1 x2 x3 x4 : ℚ) :
  3 * x0 - 15 * x1 - 30 * x2 + 20 * x3 + 12 * x4 ≤ 0 →
  35 * x0 - 30 * x1 + 12 * x2 - 15 * x3 + 18 * x4 ≤ 0 →
  5 * x0 + 20 * x1 + 24 * x2 + 20 * x3 + 9 * x4 ≤ 0 →
  -2 * x0 - 30 * x1 + 30 * x2 - 10 * x3 - 12 * x4 ≤ 0 →
  -4 * x0 - 25 * x1 + 6 * x2 - 20 * x3 ≤ 0 →
  25 * x1 + 30 * x2 - 25 * x3 + 12 * x4 ≤ 0 →
  10 * x1 - 18 * x2 - 30 * x3 + 18 * x4 ≤ 0 →
  4 * x0 + 10 * x1 - 18 * x2 - 15 * x3 + 15 * x4 ≤ 0 →
  -4 * x0 + 15 * x1 - 30 * x2 + 15 * x3 + 6 * x4 ≤ 0 →
  -2 * x0 - 5 * x1 + 18 * x2 - 25 * x3 - 161 * x4 ≤ 0 →
  -6 * x0 + 30 * x1 + 6 * x2 - 15 * x3 ≤ 0 →
  3 * x0 + 10 * x1 - 30 * x2 + 25 * x3 + 12 * x4 ≤ 0 →
  2 * x0 + 10 * x1 - 24 * x2 - 15 * x3 + 3 * x4 ≤ 0 →
  82 * x1 + 36 * x2 + 20 * x3 + 9 * x4 ≤ 0 →
  2 * x0 - 30 * x1 - 30 * x2 - 15 * x3 + 6 * x4 ≤ 0 →
  4 * x0 - 15 * x1 + 25 * x2 < 0 →
  -4 * x0 - 10 * x1 + 30 * x2 - 15 * x3 ≤ 0 →
  2 * x0 + 6 * x2 + 133 * x3 + 12 * x4 ≤ 0 →
  3 * x0 + 15 * x1 - 6 * x2 - 15 * x3 - 15 * x4 ≤ 0 →
  10 * x1 + 6 * x2 - 25 * x3 + 3 * x4 ≤ 0 →
  -2 * x0 + 5 * x1 + 12 * x2 - 20 * x3 + 12 * x4 ≤ 0 →
  -5 * x0 - 25 * x1 + 30 * x3 - 12 * x4 ≤ 0 →
  -6 * x0 - 30 * x1 - 36 * x2 + 20 * x3 + 12 * x4 ≤ 0 →
  5 * x0 - 5 * x1 + 6 * x2 - 25 * x3 ≤ 0 →
  -3 * x0 - 20 * x1 - 30 * x2 + 5 * x3 + 3 * x4 ≤ 0 →
  False := by intros; linarith (config := {oracle := Linarith.FourierMotzkin.produceCertificate})
```

I am planning to prove the "completeness" of the oracle in the next PRs, but so far I have run a stress test on randomly generated examples of various sizes, and it seems that everything is OK.



Co-authored-by: Eric Wieser <[email protected]>
  • Loading branch information
vasnesterov and eric-wieser committed May 9, 2024
1 parent f67cc81 commit f993b41
Show file tree
Hide file tree
Showing 13 changed files with 519 additions and 27 deletions.
7 changes: 6 additions & 1 deletion Mathlib.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3713,11 +3713,16 @@ import Mathlib.Tactic.Lift
import Mathlib.Tactic.LiftLets
import Mathlib.Tactic.Linarith
import Mathlib.Tactic.Linarith.Datatypes
import Mathlib.Tactic.Linarith.Elimination
import Mathlib.Tactic.Linarith.Frontend
import Mathlib.Tactic.Linarith.Lemmas
import Mathlib.Tactic.Linarith.Oracle.FourierMotzkin
import Mathlib.Tactic.Linarith.Oracle.SimplexAlgorithm
import Mathlib.Tactic.Linarith.Parsing
import Mathlib.Tactic.Linarith.Preprocessing
import Mathlib.Tactic.Linarith.SimplexAlgorithm.Datatypes
import Mathlib.Tactic.Linarith.SimplexAlgorithm.Gauss
import Mathlib.Tactic.Linarith.SimplexAlgorithm.PositiveVector
import Mathlib.Tactic.Linarith.SimplexAlgorithm.SimplexAlgorithm
import Mathlib.Tactic.Linarith.Verification
import Mathlib.Tactic.LinearCombination
import Mathlib.Tactic.Lint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,6 @@ instance (priority := 100) {E : Type*} [NormedAddCommGroup E] [NormedSpace ℝ E
_ = 1 - (R - 1) / (R + 1) := by field_simp; ring
support := fun R hR => by
have A : 0 < (R + 1) / 2 := by linarith
have A' : 0 < R + 1 := by linarith
have C : (R - 1) / (R + 1) < 1 := by apply (div_lt_one _).2 <;> linarith
simp only [hR, if_true, support_comp_inv_smul₀ A.ne', y_support _ (IR R hR) C,
_root_.smul_ball A.ne', Real.norm_of_nonneg A.le, smul_zero]
Expand Down
1 change: 0 additions & 1 deletion Mathlib/Analysis/Distribution/SchwartzSpace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,6 @@ integral in terms of suitable seminorms of `f`. -/
lemma pow_mul_le_of_le_of_pow_mul_le {C₁ C₂ : ℝ} {k l : ℕ} {x f : ℝ} (hx : 0 ≤ x) (hf : 0 ≤ f)
(h₁ : f ≤ C₁) (h₂ : x ^ (k + l) * f ≤ C₂) :
x ^ k * f ≤ 2 ^ l * (C₁ + C₂) * (1 + x) ^ (- (l : ℝ)) := by
have : 0 ≤ C₁ := le_trans (by positivity) h₁
have : 0 ≤ C₂ := le_trans (by positivity) h₂
have : 2 ^ l * (C₁ + C₂) * (1 + x) ^ (- (l : ℝ)) = ((1 + x) / 2) ^ (-(l:ℝ)) * (C₁ + C₂) := by
rw [Real.div_rpow (by linarith) zero_le_two]
Expand Down
7 changes: 6 additions & 1 deletion Mathlib/Tactic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,16 @@ import Mathlib.Tactic.Lift
import Mathlib.Tactic.LiftLets
import Mathlib.Tactic.Linarith
import Mathlib.Tactic.Linarith.Datatypes
import Mathlib.Tactic.Linarith.Elimination
import Mathlib.Tactic.Linarith.Frontend
import Mathlib.Tactic.Linarith.Lemmas
import Mathlib.Tactic.Linarith.Oracle.FourierMotzkin
import Mathlib.Tactic.Linarith.Oracle.SimplexAlgorithm
import Mathlib.Tactic.Linarith.Parsing
import Mathlib.Tactic.Linarith.Preprocessing
import Mathlib.Tactic.Linarith.SimplexAlgorithm.Datatypes
import Mathlib.Tactic.Linarith.SimplexAlgorithm.Gauss
import Mathlib.Tactic.Linarith.SimplexAlgorithm.PositiveVector
import Mathlib.Tactic.Linarith.SimplexAlgorithm.SimplexAlgorithm
import Mathlib.Tactic.Linarith.Verification
import Mathlib.Tactic.LinearCombination
import Mathlib.Tactic.Lint
Expand Down
19 changes: 10 additions & 9 deletions Mathlib/Tactic/Linarith/Datatypes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -298,18 +298,19 @@ instance GlobalPreprocessorToGlobalBranchingPreprocessor :
⟨GlobalPreprocessor.branching⟩

/--
A `CertificateOracle` is a function
A `CertificateOracle` provides a function
`produceCertificate : List Comp → Nat → MetaM (HashMap Nat Nat)`.
`produceCertificate hyps max_var` tries to derive a contradiction from the comparisons in `hyps`
by eliminating all variables ≤ `max_var`.
If successful, it returns a map `coeff : Nat → Nat` as a certificate.
This map represents that we can find a contradiction by taking the sum `∑ (coeff i) * hyps[i]`.
The default `CertificateOracle` used by `linarith` is
`Linarith.FourierMotzkin.produceCertificate`.
`Linarith.CertificateOracle.simplexAlgorithm`.
`Linarith.CertificateOracle.fourierMotzkin` is also available (though has some bugs).
-/
def CertificateOracle : Type :=
List Comp → Nat → MetaM (Batteries.HashMap Nat Nat)
structure CertificateOracle : Type where
/-- `produceCertificate hyps max_var` tries to derive a contradiction from the comparisons in
`hyps` by eliminating all variables ≤ `max_var`.
If successful, it returns a map `coeff : Nat → Nat` as a certificate.
This map represents that we can find a contradiction by taking the sum `∑ (coeff i) * hyps[i]`. -/
produceCertificate (hyps : List Comp) (max_var : Nat) : MetaM (Batteries.HashMap Nat Nat)

open Meta

Expand All @@ -334,7 +335,7 @@ structure LinarithConfig : Type where
/-- Override the list of preprocessors. -/
preprocessors : Option (List GlobalBranchingPreprocessor) := none
/-- Specify an oracle for identifying candidate contradictions.
The only implementation here is Fourier-Motzkin elimination. -/
`.simplexAlgorithm` and `.fourierMotzkin` are both available. -/
oracle : Option CertificateOracle := none

/--
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,14 +324,9 @@ those hypotheses. It produces an initial state for the elimination monad.
def mkLinarithData (hyps : List Comp) (maxVar : ℕ) : LinarithData :=
⟨maxVar, .ofList (hyps.enum.map fun ⟨n, cmp⟩ => PComp.assump cmp n) _⟩

/--
`produceCertificate hyps vars` tries to derive a contradiction from the comparisons in `hyps`
by eliminating all variables ≤ `maxVar`.
If successful, it returns a map `coeff : ℕ → ℕ` as a certificate.
This map represents that we can find a contradiction by taking the sum `∑ (coeff i) * hyps[i]`.
-/
def FourierMotzkin.produceCertificate : CertificateOracle :=
fun hyps maxVar => match ExceptT.run
/-- An oracle that uses Fourier-Motzkin elimination. -/
def CertificateOracle.fourierMotzkin : CertificateOracle where
produceCertificate hyps maxVar := match ExceptT.run
(StateT.run (do validate; elimAllVarsM : LinarithM Unit) (mkLinarithData hyps maxVar)) with
| (Except.ok _) => failure
| (Except.error contr) => return contr.src.flatten
Expand Down
42 changes: 42 additions & 0 deletions Mathlib/Tactic/Linarith/Oracle/SimplexAlgorithm.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/-
Copyright (c) 2024 Vasily Nesterov. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Vasily Nesterov
-/
import Mathlib.Tactic.Linarith.Datatypes
import Mathlib.Tactic.Linarith.SimplexAlgorithm.PositiveVector

/-!
# Hooks to enable the use of the simplex algorithm in `linarith`
-/

open Batteries

namespace Linarith.SimplexAlgorithm

/-- Preprocess the goal to pass it to `findPositiveVector`. -/
def preprocess (hyps : List Comp) (maxVar : ℕ) : Matrix (maxVar + 1) (hyps.length) × List Nat :=
let mdata : Array (Array ℚ) := Array.ofFn fun i : Fin (maxVar + 1) =>
Array.mk <| hyps.map (·.coeffOf i)
let strictIndexes : List ℕ := hyps.findIdxs (·.str == Ineq.lt)
⟨⟨mdata⟩, strictIndexes⟩

/-- Extract the certificate from the `vec` found by `findPositiveVector`. -/
def postprocess (vec : Array ℚ) : HashMap ℕ ℕ :=
let common_den : ℕ := vec.foldl (fun acc item => acc.lcm item.den) 1
let vecNat : Array ℕ := vec.map (fun x : ℚ => (x * common_den).floor.toNat)
HashMap.ofList <| vecNat.toList.enum.filter (fun ⟨_, item⟩ => item != 0)


end SimplexAlgorithm

open SimplexAlgorithm

/-- An oracle that uses the simplex algorithm. -/
def CertificateOracle.simplexAlgorithm : CertificateOracle where
produceCertificate hyps maxVar := do
let ⟨A, strictIndexes⟩ := preprocess hyps maxVar
let vec := findPositiveVector A strictIndexes
return postprocess vec

end Linarith
47 changes: 47 additions & 0 deletions Mathlib/Tactic/Linarith/SimplexAlgorithm/Datatypes.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/-
Copyright (c) 2024 Vasily Nesterov. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Vasily Nesterov
-/
import Batteries.Data.Rat.Basic

/-!
# Datatypes for Simplex Algorithm implementation
-/

namespace Linarith.SimplexAlgorithm

/--
Structure for matrices over ℚ.
So far it is just a 2d-array carrying dimensions (that are supposed to match with the actual
dimensions of `data`), but the plan is to add some `Prop`-data and make the structure strict and
safe.
Note: we avoid using the `Matrix` from `Mathlib.Data.Matrix` because it is far more efficient to
store matrix as its entries than as function between `Fin`-s.
-/
structure Matrix (n m : Nat) where
/-- The content of the matrix. -/
data : Array (Array Rat)
-- hn_pos : n > 0
-- hm_pos : m > 0
-- hn : data.size = n
-- hm (i : Fin n) : data[i].size = m

instance (n m : Nat) : GetElem (Matrix n m) Nat (Array Rat) fun _ i => i < n where
getElem mat i _ := mat.data[i]!

/--
`Table` is a structure Simplex Algorithm operates on. The `i`-th row of `mat` expresses the
variable `basic[i]` as a linear combination of variables from `free`.
-/
structure Table where
/-- Array containing the basic variables' indexes -/
basic : Array Nat
/-- Array containing the free variables' indexes -/
free : Array Nat
/-- Matrix of coefficients the basic variables expressed through the free ones. -/
mat : Matrix basic.size free.size

end Linarith.SimplexAlgorithm
97 changes: 97 additions & 0 deletions Mathlib/Tactic/Linarith/SimplexAlgorithm/Gauss.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/-
Copyright (c) 2024 Vasily Nesterov. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Vasily Nesterov
-/
import Mathlib.Tactic.Linarith.SimplexAlgorithm.Datatypes

/-!
# Gaussian Elimination algorithm
The first step of `Linarith.SimplexAlgorithm.findPositiveVector` is finding initial feasible
solution which is done by standard Gaussian Elimination algorithm implemented in this file.
-/

namespace Linarith.SimplexAlgorithm.Gauss

/-- The monad for the Gaussian Elimination algorithm. -/
abbrev GaussM (n m : Nat) := StateM <| Matrix n m

/-- Finds the first row starting from the current row with nonzero element in current column. -/
def findNonzeroRow (row col : Nat) {n m : Nat} : GaussM n m <| Option Nat := do
for i in [row:n] do
if (← get)[i]![col]! != 0 then
return i
return .none

/-- Swaps two rows. -/
def swapRows {n m : Nat} (i j : Nat) : GaussM n m Unit := do
if i != j then
modify fun mat =>
let swapped : Matrix n m := ⟨mat.data.swap! i j⟩
swapped

/-- Subtracts `i`-th row * `coef` from `j`-th row. -/
def subtractRow {n m : Nat} (i j : Nat) (coef : Rat) : GaussM n m Unit :=
modify fun mat =>
let newData : Array (Array Rat) := mat.data.modify j fun row =>
row.zipWith mat[i]! fun x y => x - coef * y
⟨newData⟩

/-- Divides row by `coef`. -/
def divideRow {n m : Nat} (i : Nat) (coef : Rat) : GaussM n m Unit :=
modify fun mat =>
let newData : Array (Array Rat) := mat.data.modify i (·.map (· / coef))
⟨newData⟩

/-- Implementation of `getTable` in `GaussM` monad. -/
def getTableImp {n m : Nat} : GaussM n m Table := do
let mut free : Array Nat := #[]
let mut basic : Array Nat := #[]

let mut row : Nat := 0
let mut col : Nat := 0

while row < n && col < m do
match ← findNonzeroRow row col with
| .none =>
free := free.push col
col := col + 1
continue
| .some rowToSwap =>
swapRows row rowToSwap

divideRow row (← get)[row]![col]!

for i in [:n] do
if i == row then
continue
let coef := (← get)[i]![col]!
subtractRow row i coef

basic := basic.push col
row := row + 1
col := col + 1

for i in [col:m] do
free := free.push i

let ansData : Array (Array Rat) := ← do
let mat := (← get)
return Array.ofFn (fun row : Fin row => free.map fun f => -mat[row]![f]!)

return {
free := free
basic := basic
mat := ⟨ansData⟩
}

/--
Given matrix `A`, solves the linear equation `A x = 0` and returns the solution as a table where
some variables are free and others (basic) variable are expressed as linear combinations of the free
ones.
-/
def getTable {n m : Nat} (A : Matrix n m) : Table := Id.run do
return (← getTableImp.run A).fst

end Linarith.SimplexAlgorithm.Gauss
79 changes: 79 additions & 0 deletions Mathlib/Tactic/Linarith/SimplexAlgorithm/PositiveVector.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/-
Copyright (c) 2024 Vasily Nesterov. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Vasily Nesterov
-/
import Mathlib.Tactic.Linarith.SimplexAlgorithm.SimplexAlgorithm
import Mathlib.Tactic.Linarith.SimplexAlgorithm.Gauss

/-!
# `linarith` certificate search a LP problem
`linarith` certificate search can easily be reduced to this LP problem: given the matrix `A` and the
list `strictIndexes`, find the non-negative vector `v` such that some of its coordinates from
the `strictIndexes` are positive and `A v = 0`.
The function `findPositiveVector` solves this problem.
-/

namespace Linarith.SimplexAlgorithm

/--
Given matrix `A` and list `strictIndexes` of strict inequalities' indexes, we want to state the
Linear Programming problem which solution produces solution for the initial problem (see
`findPositiveVector`).
As an objective function (that we are trying to maximize) we use sum of coordinates from
`strictIndexes`: it suffices to find the non-negative vector that makes this function positive.
We introduce two auxiliary variables and one constraint:
* The variable `y` is interpreted as "homogenized" `1`. We need it because dealing with a
homogenized problem is easier, but having some "unit" is necessary.
* To bound the problem we add the constraint `x₁ + ... + xₘ + z = y` introducing new variable `z`.
The objective function also interpreted as an auxiliary variable with constraint
`f = ∑ i ∈ strictIndexes, xᵢ`.
The variable `f` has to always be basic while `y` has to be free. Our Gauss method implementation
greedy collects basic variables moving from left to right. So we place `f` before `x`-s and `y`
after them. We place `z` between `f` and `x` because in this case `z` will be basic and
`Gauss.getTable` produce table with non-negative last column, meaning that we are starting from
a feasible point.
-/
def stateLP {n m : Nat} (A : Matrix n m) (strictIndexes : List Nat) : Matrix (n + 2) (m + 3) :=
Id.run do
let mut objectiveRow : Array Rat := #[-1, 0] ++ (Array.mkArray m 0) ++ #[0]
for idx in strictIndexes do
objectiveRow := objectiveRow.set! (idx + 2) 1 -- +2 due to shifting by `f` and `z`

let constraintRow : Array Rat := #[0, 1] ++ (Array.mkArray m 1) ++ #[-1]

let data : Array (Array Rat) := #[objectiveRow, constraintRow]
++ A.data.map (#[0, 0] ++ · ++ #[0])

return ⟨data⟩

/-- Extracts target vector from the table, putting auxilary variables aside (see `stateLP`). -/
def extractSolution (table : Table) : Array Rat := Id.run do
let mut ans : Array Rat := Array.mkArray (table.basic.size + table.free.size - 3) 0
for i in [1:table.mat.data.size] do
ans := ans.set! (table.basic[i]! - 2) table.mat.data[i]!.back
return ans

/--
Finds nonnegative vector `v`, such that `A v = 0` and some of its coordinates from `strictCoords`
are positive, in the case such `v` exists.
-/
def findPositiveVector {n m : Nat} (A : Matrix n m) (strictIndexes : List Nat) : Array Rat :=
/- State the linear programming problem. -/
let B := stateLP A strictIndexes

/- Using Gaussian elimination split variable into free and basic forming the table that will be
operated by Simplex Algorithm. -/
let initTable := Gauss.getTable B

/- Run Simplex Algorithm and extract the solution. -/
let resTable := runSimplexAlgorithm initTable
extractSolution resTable

end Linarith.SimplexAlgorithm
Loading

0 comments on commit f993b41

Please sign in to comment.