Skip to content

Commit

Permalink
refactor: fuse nested mkCongrArg calls (leanprover#3203)
Browse files Browse the repository at this point in the history
Encouraged by the performance gains from making `rewrite` produce
smaller proof objects
(leanprover#3121) I am here looking for low-hanging fruit in `simp`.

Consider this typical example:

```
set_option pp.explicit true

theorem test
  (a : Nat)
  (b : Nat)
  (c : Nat)
  (heq : a = b)
  (h : (c.add (c.add ((c.add b).add c))).add c = c)
  : (c.add (c.add ((c.add a).add c))).add c = c
```
We get a rather nice proof term when using
```
  := by rw [heq]; assumption
```
namely
```
theorem test : ∀ (a b c : Nat),
  @eq Nat a b →
    @eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c b) c))) c) c →
      @eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c a) c))) c) c :=
fun a b c heq h =>
  @Eq.mpr (@eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c a) c))) c) c)
    (@eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c b) c))) c) c)
    (@congrArg Nat Prop a b (fun _a => @eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c _a) c))) c) c) heq) h
```
(this is with leanprover#3121).

But with `by simp only [heq]; assumption`, it looks rather different:

```
theorem test : ∀ (a b c : Nat),
  @eq Nat a b →
    @eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c b) c))) c) c →
      @eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c a) c))) c) c :=
fun a b c heq h =>
  @Eq.mpr (@eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c a) c))) c) c)
    (@eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c b) c))) c) c)
    (@id
      (@eq Prop (@eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c a) c))) c) c)
        (@eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c b) c))) c) c))
      (@congrFun Nat (fun a => Prop) (@eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c a) c))) c))
        (@eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c b) c))) c))
        (@congrArg Nat (Nat → Prop) (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c a) c))) c)
          (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c b) c))) c) (@eq Nat)
          (@congrFun Nat (fun a => Nat) (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c a) c))))
            (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c b) c))))
            (@congrArg Nat (Nat → Nat) (Nat.add c (Nat.add c (Nat.add (Nat.add c a) c)))
              (Nat.add c (Nat.add c (Nat.add (Nat.add c b) c))) Nat.add
              (@congrArg Nat Nat (Nat.add c (Nat.add (Nat.add c a) c)) (Nat.add c (Nat.add (Nat.add c b) c)) (Nat.add c)
                (@congrArg Nat Nat (Nat.add (Nat.add c a) c) (Nat.add (Nat.add c b) c) (Nat.add c)
                  (@congrFun Nat (fun a => Nat) (Nat.add (Nat.add c a)) (Nat.add (Nat.add c b))
                    (@congrArg Nat (Nat → Nat) (Nat.add c a) (Nat.add c b) Nat.add
                      (@congrArg Nat Nat a b (Nat.add c) heq))
                    c))))
            c))
        c))
    h
```
Since simp uses only single-step `congrArg`/`congrFun` congruence lemmas
here, the proof
term grows very large, likely quadratic in this case.

Can we do better? Every nesting of `congrArg` (and it's little brother
`congrFun`) can be
turned into a single `congrArg` call. 

In this PR I make making the smart app builders `Meta.mkCongrArg` and
`Meta.mkCongrFun` a bit
smarter and not only fuse with `Eq.refl`, but also with
`congrArg`/`congrFun`.

Now we get, in this simple example,
```
theorem test : ∀ (a b c : Nat),
  @eq Nat a b →
    @eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c b) c))) c) c →
      @eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c a) c))) c) c :=
fun a b c heq h =>
  @Eq.mpr (@eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c a) c))) c) c)
    (@eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c b) c))) c) c)
    (@congrArg Nat Prop a b (fun x => @eq Nat (Nat.add (Nat.add c (Nat.add c (Nat.add (Nat.add c x) c))) c) c) heq) h
```

Let’s see if it works and how much we gain.
  • Loading branch information
nomeata authored Jan 25, 2024
1 parent 550fa69 commit de23226
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 13 deletions.
46 changes: 41 additions & 5 deletions src/Lean/Meta/AppBuilder.lean
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,41 @@ def mkEqOfHEq (h : Expr) : MetaM Expr := do
| _ =>
throwAppBuilderException ``HEq.trans m!"heterogeneous equality proof expected{indentExpr h}"

/--
If `e` is `@Eq.refl α a`, return `a`.
-/
def isRefl? (e : Expr) : Option Expr := do
if e.isAppOfArity ``Eq.refl 2 then
some e.appArg!
else
none

/--
If `e` is `@congrArg α β a b f h`, return `α`, `f` and `h`.
Also works if `e` can be turned into such an application (e.g. `congrFun`).
-/
def congrArg? (e : Expr) : MetaM (Option (Expr × Expr × Expr )) := do
if e.isAppOfArity ``congrArg 6 then
let #[α, _β, _a, _b, f, h] := e.getAppArgs | unreachable!
return some (α, f, h)
if e.isAppOfArity ``congrFun 6 then
let #[α, β, _f, _g, h, a] := e.getAppArgs | unreachable!
let α' ← withLocalDecl `x .default α fun x => do
mkForallFVars #[x] (β.beta #[x])
let f' ← withLocalDecl `x .default α' fun f => do
mkLambdaFVars #[f] (f.app a)
return some (α', f', h)
return none

/-- Given `f : α → β` and `h : a = b`, returns a proof of `f a = f b`.-/
def mkCongrArg (f h : Expr) : MetaM Expr := do
if h.isAppOf ``Eq.refl then
mkEqRefl (mkApp f h.appArg!)
partial def mkCongrArg (f h : Expr) : MetaM Expr := do
if let some a := isRefl? h then
mkEqRefl (mkApp f a)
else if let some (α, f₁, h₁) ← congrArg? h then
-- Fuse nested `congrArg` for smaller proof terms, e.g. when using simp
let f' ← withLocalDecl `x .default α fun x => do
mkLambdaFVars #[x] (f.beta #[f₁.beta #[x]])
mkCongrArg f' h₁
else
let hType ← infer h
let fType ← infer f
Expand All @@ -181,8 +212,13 @@ def mkCongrArg (f h : Expr) : MetaM Expr := do

/-- Given `h : f = g` and `a : α`, returns a proof of `f a = g a`.-/
def mkCongrFun (h a : Expr) : MetaM Expr := do
if h.isAppOf ``Eq.refl then
mkEqRefl (mkApp h.appArg! a)
if let some f := isRefl? h then
mkEqRefl (mkApp f a)
else if let some (α, f₁, h₁) ← congrArg? h then
-- Fuse nested `congrArg` for smaller proof terms, e.g. when using simp
let f' ← withLocalDecl `x .default α fun x => do
mkLambdaFVars #[x] (f₁.beta #[x, a])
mkCongrArg f' h₁
else
let hType ← infer h
match hType.eq? with
Expand Down
14 changes: 6 additions & 8 deletions tests/lean/simpZetaFalse.lean.expected.out
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@ theorem ex1 : ∀ (x : Nat),
fun x h =>
Eq.mpr
(id
(congrFun
(congrArg Eq
(let_congr (Eq.refl (x * x)) fun y =>
ite_congr (Eq.trans (congrFun (congrArg Eq h) x) (eq_self x)) (fun a => Eq.refl 1) fun a =>
Eq.refl (y + 1)))
1))
(congrArg (fun x => x = 1)
(let_congr (Eq.refl (x * x)) fun y =>
ite_congr (Eq.trans (congrArg (fun x_1 => x_1 = x) h) (eq_self x)) (fun a => Eq.refl 1) fun a =>
Eq.refl (y + 1))))
(of_eq_true (eq_self 1))
x z : Nat
h : f (f x) = x
Expand All @@ -31,7 +29,7 @@ theorem ex2 : ∀ (x z : Nat),
y) =
z :=
fun x z h h' =>
Eq.mpr (id (congrFun (congrArg Eq (let_val_congr (fun y => y) h)) z))
Eq.mpr (id (congrArg (fun x => x = z) (let_val_congr (fun y => y) h)))
(of_eq_true (Eq.trans (congrArg (Eq x) h') (eq_self x)))
x z : Nat
⊢ (let α := Nat;
Expand All @@ -48,5 +46,5 @@ theorem ex4 : ∀ (p : Prop),
fun x => x = x) =
fun z => p :=
fun p h =>
Eq.mpr (id (congrFun (congrArg Eq (let_body_congr 10 fun n => funext fun x => eq_self x)) fun z => p))
Eq.mpr (id (congrArg (fun x => x = fun z => p) (let_body_congr 10 fun n => funext fun x => eq_self x)))
(of_eq_true (Eq.trans (congrArg (Eq fun x => True) (funext fun z => eq_true h)) (eq_self fun x => True)))

0 comments on commit de23226

Please sign in to comment.