-
Notifications
You must be signed in to change notification settings - Fork 330
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(Tactic/Linarith): Simplex Algorithm oracle (#12014)
- Reduce the `linarith` certificate search problem to some Linear Programming problem and implement the Simplex Algorithm to solve it. - Set the default oracle for `linarith` to this. - Remove unnecessary hypotheses in `Mathlib/Analysis/Calculus/BumpFunction/FiniteDimension.lean` and `Mathlib/Analysis/Distribution/SchwartzSpace.lean` which were needed with the Fourier-Motzkin oracle. - Adjust the definition of `CerficateOracle` to enable dot notation to choose between oracles. This addresses #2717 and #8875 (except when the user overrides the oracle) Also, this oracle is far more efficient: The example below takes lots of time with the Fourier-Motzkin oracle: I waited 5 minutes and still didn't get it. But with the just implemented `Linarith.SimplexAlgo.produceCertificate` oracle the proof succeeds in less than a second. ```lean import Mathlib.Tactic.Linarith example (x0 x1 x2 x3 x4 : ℚ) : 3 * x0 - 15 * x1 - 30 * x2 + 20 * x3 + 12 * x4 ≤ 0 → 35 * x0 - 30 * x1 + 12 * x2 - 15 * x3 + 18 * x4 ≤ 0 → 5 * x0 + 20 * x1 + 24 * x2 + 20 * x3 + 9 * x4 ≤ 0 → -2 * x0 - 30 * x1 + 30 * x2 - 10 * x3 - 12 * x4 ≤ 0 → -4 * x0 - 25 * x1 + 6 * x2 - 20 * x3 ≤ 0 → 25 * x1 + 30 * x2 - 25 * x3 + 12 * x4 ≤ 0 → 10 * x1 - 18 * x2 - 30 * x3 + 18 * x4 ≤ 0 → 4 * x0 + 10 * x1 - 18 * x2 - 15 * x3 + 15 * x4 ≤ 0 → -4 * x0 + 15 * x1 - 30 * x2 + 15 * x3 + 6 * x4 ≤ 0 → -2 * x0 - 5 * x1 + 18 * x2 - 25 * x3 - 161 * x4 ≤ 0 → -6 * x0 + 30 * x1 + 6 * x2 - 15 * x3 ≤ 0 → 3 * x0 + 10 * x1 - 30 * x2 + 25 * x3 + 12 * x4 ≤ 0 → 2 * x0 + 10 * x1 - 24 * x2 - 15 * x3 + 3 * x4 ≤ 0 → 82 * x1 + 36 * x2 + 20 * x3 + 9 * x4 ≤ 0 → 2 * x0 - 30 * x1 - 30 * x2 - 15 * x3 + 6 * x4 ≤ 0 → 4 * x0 - 15 * x1 + 25 * x2 < 0 → -4 * x0 - 10 * x1 + 30 * x2 - 15 * x3 ≤ 0 → 2 * x0 + 6 * x2 + 133 * x3 + 12 * x4 ≤ 0 → 3 * x0 + 15 * x1 - 6 * x2 - 15 * x3 - 15 * x4 ≤ 0 → 10 * x1 + 6 * x2 - 25 * x3 + 3 * x4 ≤ 0 → -2 * x0 + 5 * x1 + 12 * x2 - 20 * x3 + 12 * x4 ≤ 0 → -5 * x0 - 25 * x1 + 30 * x3 - 12 * x4 ≤ 0 → -6 * x0 - 30 * x1 - 36 * x2 + 20 * x3 
+ 12 * x4 ≤ 0 → 5 * x0 - 5 * x1 + 6 * x2 - 25 * x3 ≤ 0 → -3 * x0 - 20 * x1 - 30 * x2 + 5 * x3 + 3 * x4 ≤ 0 → False := by intros; linarith (config := {oracle := Linarith.FourierMotzkin.produceCertificate}) ``` I am planning to prove the "completeness" of the oracle in the next PRs, but so far I have run a stress test on randomly generated examples of various sizes, and it seems that everything is OK. Co-authored-by: Eric Wieser <[email protected]>
- Loading branch information
1 parent
f67cc81
commit f993b41
Showing
13 changed files
with
519 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
/- | ||
Copyright (c) 2024 Vasily Nesterov. All rights reserved. | ||
Released under Apache 2.0 license as described in the file LICENSE. | ||
Authors: Vasily Nesterov | ||
-/ | ||
import Mathlib.Tactic.Linarith.Datatypes | ||
import Mathlib.Tactic.Linarith.SimplexAlgorithm.PositiveVector | ||
|
||
/-! | ||
# Hooks to enable the use of the simplex algorithm in `linarith` | ||
-/ | ||
|
||
open Batteries | ||
|
||
namespace Linarith.SimplexAlgorithm | ||
|
||
/--
Preprocess the goal to pass it to `findPositiveVector`.

Returns a pair:
* a matrix with one row per variable index `0, ..., maxVar`, where the row of variable `i` lists
  the coefficient of `i` in each hypothesis of `hyps` (in the order of `hyps`);
* the list of positions in `hyps` of the hypotheses whose `str` is `Ineq.lt`.
-/
def preprocess (hyps : List Comp) (maxVar : ℕ) : Matrix (maxVar + 1) (hyps.length) × List Nat :=
  let strictPositions : List ℕ := hyps.findIdxs fun h => h.str == Ineq.lt
  let rows : Array (Array ℚ) := Array.ofFn fun var : Fin (maxVar + 1) =>
    Array.mk (hyps.map fun h => h.coeffOf var)
  (⟨rows⟩, strictPositions)
|
||
/--
Extract the certificate from the `vec` found by `findPositiveVector`.

Every entry of `vec` is multiplied by the least common multiple of the denominators of all entries
and converted to a natural number with `floor.toNat`. The result maps each index to its scaled
entry; indexes whose scaled entry is `0` are left out.
-/
def postprocess (vec : Array ℚ) : HashMap ℕ ℕ :=
  let denLcm : ℕ := vec.foldl (fun l q => l.lcm q.den) 1
  let scaled : Array ℕ := vec.map fun q : ℚ => (q * denLcm).floor.toNat
  let nonzero : List (ℕ × ℕ) := scaled.toList.enum.filter fun ⟨_, c⟩ => c != 0
  HashMap.ofList nonzero
|
||
|
||
end SimplexAlgorithm | ||
|
||
open SimplexAlgorithm | ||
|
||
/--
An oracle that uses the simplex algorithm: the hypotheses are turned into a matrix by `preprocess`,
`findPositiveVector` is run on it, and the resulting vector is converted into a certificate by
`postprocess`.
-/
def CertificateOracle.simplexAlgorithm : CertificateOracle where
  produceCertificate hyps maxVar := do
    let (mat, strict) := preprocess hyps maxVar
    return postprocess (findPositiveVector mat strict)
|
||
end Linarith |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/- | ||
Copyright (c) 2024 Vasily Nesterov. All rights reserved. | ||
Released under Apache 2.0 license as described in the file LICENSE. | ||
Authors: Vasily Nesterov | ||
-/ | ||
import Batteries.Data.Rat.Basic | ||
|
||
/-! | ||
# Datatypes for Simplex Algorithm implementation | ||
-/ | ||
|
||
namespace Linarith.SimplexAlgorithm | ||
|
||
/--
A matrix over `ℚ`, stored as an array of rows.

The parameters `n` and `m` are the intended numbers of rows and columns. Nothing in the structure
enforces that `data` really has these dimensions yet: the candidate invariants are only listed as
comments below, and the plan is to add them as `Prop`-valued fields to make the structure strict
and safe.

Note: we avoid using the `Matrix` from `Mathlib.Data.Matrix` because it is far more efficient to
store a matrix as its entries than as a function between `Fin`-s.
-/
structure Matrix (n m : Nat) where
  /-- The entries of the matrix, row by row. -/
  data : Array (Array Rat)
  -- Planned invariants (not enforced yet):
  -- hn_pos : n > 0
  -- hm_pos : m > 0
  -- hn : data.size = n
  -- hm (i : Fin n) : data[i].size = m
|
||
/--
Indexing a `Matrix n m` by a natural number `i < n` yields the `i`-th entry of `data`. The bound is
not used by the implementation, which reads `data` with the panicking accessor `[i]!`.
-/
instance (n m : Nat) : GetElem (Matrix n m) Nat (Array Rat) (fun _ i => i < n) :=
  { getElem := fun mat i _ => mat.data[i]! }
|
||
/--
`Table` is the structure the Simplex Algorithm operates on. The `i`-th row of `mat` expresses the
variable `basic[i]` as a linear combination of the variables listed in `free`.
-/
structure Table where
  /-- Indexes of the basic variables. -/
  basic : Array Nat
  /-- Indexes of the free variables. -/
  free : Array Nat
  /-- Coefficients expressing the basic variables through the free ones. -/
  mat : Matrix basic.size free.size
|
||
end Linarith.SimplexAlgorithm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
/- | ||
Copyright (c) 2024 Vasily Nesterov. All rights reserved. | ||
Released under Apache 2.0 license as described in the file LICENSE. | ||
Authors: Vasily Nesterov | ||
-/ | ||
import Mathlib.Tactic.Linarith.SimplexAlgorithm.Datatypes | ||
|
||
/-! | ||
# Gaussian Elimination algorithm | ||
The first step of `Linarith.SimplexAlgorithm.findPositiveVector` is finding an initial feasible | ||
solution, which is done by the standard Gaussian Elimination algorithm implemented in this file. | ||
-/ | ||
|
||
namespace Linarith.SimplexAlgorithm.Gauss | ||
|
||
/--
The monad for the Gaussian Elimination algorithm: a state monad whose state is the `n × m` matrix
being transformed.
-/
abbrev GaussM (n m : Nat) := StateM (Matrix n m)
|
||
/--
Finds the first row with index in `[row, n)` whose entry in column `col` is nonzero.
Returns `none` if there is no such row. The matrix in the state is only read.
-/
def findNonzeroRow (row col : Nat) {n m : Nat} : GaussM n m <| Option Nat := do
  let mat ← get
  for r in [row:n] do
    if mat[r]![col]! != 0 then
      return some r
  return none
|
||
/-- Swaps the `i`-th and `j`-th rows of the matrix in the state; does nothing when `i = j`. -/
def swapRows {n m : Nat} (i j : Nat) : GaussM n m Unit := do
  unless i == j do
    modify fun mat => ⟨mat.data.swap! i j⟩
|
||
/-- Subtracts the `i`-th row multiplied by `coef` from the `j`-th row. -/
def subtractRow {n m : Nat} (i j : Nat) (coef : Rat) : GaussM n m Unit :=
  modify fun mat =>
    -- The row being subtracted is read from the matrix as it was before the update.
    let source : Array Rat := mat[i]!
    ⟨mat.data.modify j fun target => target.zipWith source fun t s => t - coef * s⟩
|
||
/-- Divides every entry of the `i`-th row by `coef`. -/
def divideRow {n m : Nat} (i : Nat) (coef : Rat) : GaussM n m Unit :=
  modify fun mat => ⟨mat.data.modify i fun r => r.map fun x => x / coef⟩
|
||
/--
Implementation of `getTable` in `GaussM` monad.

Sweeps the columns from left to right. A column in which `findNonzeroRow` finds a pivot becomes
basic: the pivot row is moved into place, scaled so that the pivot equals `1`, and subtracted from
all other rows to clear the column. A column without a pivot becomes free, as do all columns left
over once the rows are exhausted. Row `i` of the resulting table lists the negated entries of the
`i`-th pivot row in the free columns.
-/
def getTableImp {n m : Nat} : GaussM n m Table := do
  let mut freeCols : Array Nat := #[]
  let mut basicCols : Array Nat := #[]
  let mut pivotRow : Nat := 0
  let mut curCol : Nat := 0

  while pivotRow < n && curCol < m do
    match ← findNonzeroRow pivotRow curCol with
    | none =>
      -- No pivot available in this column: the variable stays free.
      freeCols := freeCols.push curCol
    | some src =>
      swapRows pivotRow src
      -- Scale the pivot row so that the pivot entry becomes `1`.
      let pivot := (← get)[pivotRow]![curCol]!
      divideRow pivotRow pivot
      -- Clear the current column in every other row.
      for r in [:n] do
        if r != pivotRow then
          let coef := (← get)[r]![curCol]!
          subtractRow pivotRow r coef
      basicCols := basicCols.push curCol
      pivotRow := pivotRow + 1
    curCol := curCol + 1

  -- Columns that were never inspected are free.
  for c in [curCol:m] do
    freeCols := freeCols.push c

  let reduced ← get
  let rows : Array (Array Rat) :=
    Array.ofFn fun r : Fin pivotRow => freeCols.map fun c => -reduced[r]![c]!

  return { free := freeCols, basic := basicCols, mat := ⟨rows⟩ }
|
||
/--
Given a matrix `A`, solves the linear equation `A x = 0` and returns the solution as a table in
which some variables are free and the other (basic) variables are expressed as linear combinations
of the free ones.
-/
def getTable {n m : Nat} (A : Matrix n m) : Table :=
  (Id.run (getTableImp.run A)).fst
|
||
end Linarith.SimplexAlgorithm.Gauss |
79 changes: 79 additions & 0 deletions
79
Mathlib/Tactic/Linarith/SimplexAlgorithm/PositiveVector.lean
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/- | ||
Copyright (c) 2024 Vasily Nesterov. All rights reserved. | ||
Released under Apache 2.0 license as described in the file LICENSE. | ||
Authors: Vasily Nesterov | ||
-/ | ||
import Mathlib.Tactic.Linarith.SimplexAlgorithm.SimplexAlgorithm | ||
import Mathlib.Tactic.Linarith.SimplexAlgorithm.Gauss | ||
|
||
/-! | ||
# `linarith` certificate search as an LP problem | ||
`linarith` certificate search can easily be reduced to this LP problem: given the matrix `A` and the | ||
list `strictIndexes`, find the non-negative vector `v` such that some of its coordinates from | ||
the `strictIndexes` are positive and `A v = 0`. | ||
The function `findPositiveVector` solves this problem. | ||
-/ | ||
|
||
namespace Linarith.SimplexAlgorithm | ||
|
||
/--
Given a matrix `A` and the list `strictIndexes` of the strict inequalities' indexes, states the
Linear Programming problem whose solution produces a solution of the initial problem (see
`findPositiveVector`).

The objective function (to be maximized) is the sum of the coordinates from `strictIndexes`: it
suffices to find a non-negative vector that makes this function positive. Two auxiliary variables
and one constraint are introduced:
* The variable `y` is interpreted as a "homogenized" `1`. Dealing with a homogenized problem is
  easier, but some "unit" is still necessary.
* To bound the problem, the constraint `x₁ + ... + xₘ + z = y` is added, introducing a new
  variable `z`.

The objective function is also treated as an auxiliary variable `f`, with the constraint
`f = ∑ i ∈ strictIndexes, xᵢ`.

The variable `f` always has to be basic, while `y` has to be free. Our Gauss method implementation
greedily collects basic variables moving from left to right, so `f` is placed before the `x`-s and
`y` after them. `z` is placed between `f` and the `x`-s, because then `z` will be basic and
`Gauss.getTable` produces a table with a non-negative last column, meaning that we start from a
feasible point.
-/
def stateLP {n m : Nat} (A : Matrix n m) (strictIndexes : List Nat) : Matrix (n + 2) (m + 3) :=
  -- Column layout: `f`, `z`, `x₁ … xₘ`, `y`.
  let objectiveInit : Array Rat := #[-1, 0] ++ Array.mkArray m 0 ++ #[0]
  -- The `+ 2` skips the columns of `f` and `z`.
  let objectiveRow : Array Rat :=
    strictIndexes.foldl (fun acc idx => acc.set! (idx + 2) 1) objectiveInit
  let constraintRow : Array Rat := #[0, 1] ++ Array.mkArray m 1 ++ #[-1]
  -- Rows of `A` get zero coefficients at the auxiliary variables `f`, `z` and `y`.
  let paddedRows : Array (Array Rat) := A.data.map fun r => #[0, 0] ++ r ++ #[0]
  ⟨#[objectiveRow, constraintRow] ++ paddedRows⟩
|
||
/--
Extracts the target vector from the table, putting the auxiliary variables aside (see `stateLP`).

The result has `table.basic.size + table.free.size - 3` entries, all initially `0`. For every row
`i ≥ 1` of the table, the last entry of that row is stored at position `table.basic[i] - 2`
(the shift by `2` undoes the shift by `f` and `z` made in `stateLP`). Row `0` is skipped; it is
presumably the row of the objective variable `f`.

Basic variables with index below `2` (the auxiliary `f` and `z`) are ignored. `Nat` subtraction is
truncated, so without this check such a variable would be written to position `0` and could
overwrite the value of the first real variable. Positions past the end of the result (e.g. the one
`y` would get) are ignored by `Array.set!`.
-/
def extractSolution (table : Table) : Array Rat := Id.run do
  let mut ans : Array Rat := Array.mkArray (table.basic.size + table.free.size - 3) 0
  for i in [1:table.mat.data.size] do
    let var := table.basic[i]!
    -- Skip the auxiliary variables `f` (index `0`) and `z` (index `1`): `var - 2` would be `0`.
    if 2 ≤ var then
      ans := ans.set! (var - 2) table.mat.data[i]!.back
  return ans
|
||
/--
Finds a nonnegative vector `v` such that `A v = 0` and some of its coordinates from
`strictIndexes` are positive, in the case such `v` exists.

The computation is a pipeline of three steps:
1. `stateLP` states the linear programming problem;
2. `Gauss.getTable` uses Gaussian elimination to split the variables into free and basic ones,
   forming the table the Simplex Algorithm operates on;
3. `runSimplexAlgorithm` runs the Simplex Algorithm, and `extractSolution` reads the answer off
   the resulting table.
-/
def findPositiveVector {n m : Nat} (A : Matrix n m) (strictIndexes : List Nat) : Array Rat :=
  extractSolution <| runSimplexAlgorithm <| Gauss.getTable <| stateLP A strictIndexes
|
||
end Linarith.SimplexAlgorithm |
Oops, something went wrong.