Skip to content

Commit

Permalink
PI-week: Refactor Transpile (#2720)
Browse files Browse the repository at this point in the history
* State- refactor

* fmt-fix
  • Loading branch information
Kukovec authored Sep 14, 2023
1 parent 5c0e659 commit a487206
Show file tree
Hide file tree
Showing 23 changed files with 441 additions and 458 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package at.forsyte.apalache.tla.bmcmt.rules.vmt
import at.forsyte.apalache.tla.bmcmt.RewriterException
import at.forsyte.apalache.tla.lir.formulas.Booleans._
import at.forsyte.apalache.tla.lir.formulas.Term
import at.forsyte.apalache.tla.lir.{OperEx, TlaEx}
import at.forsyte.apalache.tla.lir.oper.TlaBoolOper
import at.forsyte.apalache.tla.lir.{OperEx, TlaEx}

/**
* BoolRule defines translations for reTLA patterns which use operators from propositional logic.
Expand All @@ -20,16 +19,24 @@ class BoolRule(rewriter: ToTermRewriter) extends FormulaRule {
}

// convenience shorthand
private def rewrite: TlaEx => Term = rewriter.rewrite
private def rewrite: TlaEx => TermBuilderT = rewriter.rewrite

// Assume isApplicable
override def apply(ex: TlaEx): BoolExpr =
override def apply(ex: TlaEx): TermBuilderT =
ex match {
case OperEx(TlaBoolOper.and, args @ _*) => And(args.map(rewrite): _*)
case OperEx(TlaBoolOper.or, args @ _*) => Or(args.map(rewrite): _*)
case OperEx(TlaBoolOper.not, arg) => Neg(rewrite(arg))
case OperEx(TlaBoolOper.implies, lhs, rhs) => Impl(rewrite(lhs), rewrite(rhs))
case OperEx(TlaBoolOper.equiv, lhs, rhs) => Equiv(rewrite(lhs), rewrite(rhs))
case _ => throw new RewriterException(s"BoolRule not applicable to $ex", ex)
case OperEx(TlaBoolOper.and, args @ _*) => cmpSeq(args.map(rewrite)).map { seq => And(seq: _*) }
case OperEx(TlaBoolOper.or, args @ _*) => cmpSeq(args.map(rewrite)).map { seq => Or(seq: _*) }
case OperEx(TlaBoolOper.not, arg) => rewrite(arg).map(Neg)
case OperEx(TlaBoolOper.implies, lhs, rhs) =>
for {
lhsTerm <- rewrite(lhs)
rhsTerm <- rewrite(rhs)
} yield Impl(lhsTerm, rhsTerm)
case OperEx(TlaBoolOper.equiv, lhs, rhs) =>
for {
lhsTerm <- rewrite(lhs)
rhsTerm <- rewrite(rhs)
} yield Equiv(lhsTerm, rhsTerm)
case _ => throw new RewriterException(s"BoolRule not applicable to $ex", ex)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ class EUFRule(rewriter: ToTermRewriter, restrictedSetJudgement: RestrictedSetJud
private def isRestrictedSet(ex: TlaEx) = restrictedSetJudgement.isRestrictedSet(ex)

/**
* When translating g = [f EXCEPT ![x1,...,xn] = a], we need to construct a VMT function representation, which differs
* When translating g = [f EXCEPT ![x1,...,xn] = a], we need to construct a function representation, which differs
* from that of `f` at exactly one point.
*
* Given a function f: (S1,...,Sn) => S, constructed with the rule f(y1,...,yn) = ef, arguments `x1: S1, ..., xn: Sn`,
* and an expression `a`, constructs a function definition for g, with the rule:
* ```
* {{{
* g(y1, ... yn) = if y1 = x1 /\ ... /\ yn = xn
* then a
* else ef
* ```
* }}}
* @param fnArgTerms
* the values `x1, ..., xn`
* @param newCaseTerm
Expand All @@ -59,10 +59,10 @@ class EUFRule(rewriter: ToTermRewriter, restrictedSetJudgement: RestrictedSetJud
* @return
*/
private def exceptAsNewFunDef(
fnArgTerms: List[Term],
fnArgTerms: Seq[Term],
newCaseTerm: Term,
)(args: List[(String, Sort)],
baseCaseTerm: Term): FunDef = {
args: Seq[(String, Sort)],
baseCaseTerm: Term): TermBuilderT = {
// sanity check
assert(args.length == fnArgTerms.length)

Expand All @@ -79,88 +79,45 @@ class EUFRule(rewriter: ToTermRewriter, restrictedSetJudgement: RestrictedSetJud
case _ => And(matchConds: _*)
}

FunDef(gen.newName(), args, ITE(ifCondition, newCaseTerm, baseCaseTerm))
// We store the new definition, but return a Term (computation)
defineAndUse(gen.newName(), args, ITE(ifCondition, newCaseTerm, baseCaseTerm))
}

// Convenience shorthand
private def rewrite: TlaEx => Term = rewriter.rewrite

// Applies a fixed renaming scheme to a term tree (e.g. renames all instances of "a" to "b")
private def replaceFixedLeaf(replacement: Map[Term, Term])(t: Term): Term = {
val replace = replaceFixedLeaf(replacement) _
t match {
case ITE(ifTerm, thenTerm, elseTerm) => ITE(replace(ifTerm), replace(thenTerm), replace(elseTerm))
case Apply(term, terms @ _*) => Apply(replace(term), terms.map(replace): _*)
case And(terms @ _*) => And(terms.map(replace): _*)
case Or(terms @ _*) => Or(terms.map(replace): _*)
case Equiv(lhs, rhs) => Equiv(replace(lhs), replace(rhs))
case Equal(lhs, rhs) => Equal(replace(lhs), replace(rhs))
case Neg(term) => Neg(replace(term))
case Impl(lhs, rhs) => Impl(replace(lhs), replace(rhs))
case Forall(boundVars, term) => Forall(boundVars, replace(term))
case Exists(boundVars, term) => Exists(boundVars, replace(term))
case FunDef(name, args, body) => FunDef(name, args, replace(body))
case _ => replacement.getOrElse(t, t)
}
}
private def rewrite: TlaEx => TermBuilderT = rewriter.rewrite

// Creates a function application from a list of arguments
private def mkApp(fnTerm: Term, args: List[(String, Sort)]): Term =
Apply(fnTerm, args.map { case (name, sort) => mkVariable(name, sort) }: _*)

// Creates axiomatic function equality. Functions f and g are equal over a domain S1 \X ... \X Sn, iff
// \A (x1,...,xn) \in S1 \X ... \X Sn: f(x1,...,xn) = g(x1,...xn)
// domain ... List((x1,S1), ..., (xn,Sn))
// lhs ... f(x1,...,xn)
// rhs ... g(x1,...,xn)
private def mkAxiomaticFunEq(domain: List[(String, Sort)], lhs: Term, rhs: Term): Term =
Forall(domain, Equal(lhs, rhs))

// Assume isApplicable
override def apply(ex: TlaEx): Term =
override def apply(ex: TlaEx): TermBuilderT =
ex match {
case OperEx(TlaOper.eq | ApalacheOper.assign, lhs, rhs) =>
// := is just = in VMT
val lhsTerm = rewrite(lhs)
val rhsTerm = rewrite(rhs)

// Assume sorts are equal, so we just match the left sort.
// If not, this will get caught by Equal's `require` anyhow.
lhsTerm.sort match {
// For functions, do axiomatic equality, i.e.
// f = g <=> \A x \in DOMAIN f: f[x] = g[x]
case FunctionSort(_, from @ _*) =>
(lhsTerm, rhsTerm) match {
case (FunDef(_, largs, lbody), FunDef(_, rargs, rbody)) =>
// sanity check
assert(largs.length == rargs.length)
// We arbitrarily pick one set of argnames (the left), and rename the other body.
// We rename all instances of rargs in rbody to same-indexed largs
val renameMap: Map[Term, Term] = largs
.zip(rargs)
.map { case ((largName, lsort), (rargName, rsort)) =>
mkVariable(rargName, rsort) -> mkVariable(largName, lsort)
}
.toMap
val renamedRBody = replaceFixedLeaf(renameMap)(rbody)
mkAxiomaticFunEq(largs, lbody, renamedRBody)

case (fv: FunctionVar, FunDef(_, args, body)) =>
mkAxiomaticFunEq(args, mkApp(fv, args), body)
case (FunDef(_, args, body), fv: FunctionVar) =>
mkAxiomaticFunEq(args, mkApp(fv, args), body)
case (lvar: FunctionVar, rvar: FunctionVar) =>
// we just invent formal argument names
val inventedVars = from.toList.map { s =>
(gen.newName(), s)
}
mkAxiomaticFunEq(inventedVars, mkApp(lvar, inventedVars), mkApp(rvar, inventedVars))

case _ => Equal(lhsTerm, rhsTerm)
}

// Otherwise, do direct equality
case _ => Equal(lhsTerm, rhsTerm)
// := is just = in SMT
for {
lhsTerm <- rewrite(lhs)
rhsTerm <- rewrite(rhs)
} yield {
require(lhsTerm.sort == rhsTerm.sort, "Equality requires terms of equal Sorts.")
// Assume sorts are equal, so we just match the left sort.
lhsTerm.sort match {
// For functions, do axiomatic equality, i.e.
// f = g <=> \A x \in DOMAIN f: f[x] = g[x]
case FunctionSort(_, from @ _*) =>
// we just invent formal argument names
val inventedVars = from.toList.map { s =>
(gen.newName(), s)
}
// Creates axiomatic function equality. Functions f and g are equal over a domain S1 \X ... \X Sn, iff
// \A (x1,...,xn) \in S1 \X ... \X Sn: f(x1,...,xn) = g(x1,...xn)
// domain ... Seq((x1,S1), ..., (xn,Sn))
// lhs ... f(x1,...,xn)
// rhs ... g(x1,...,xn)
Forall(inventedVars, Equal(mkApp(lhsTerm, inventedVars), mkApp(rhsTerm, inventedVars)))
// Otherwise, do direct equality
case _ => Equal(lhsTerm, rhsTerm)
}
}

case OperEx(TlaFunOper.funDef, e, varsAndSets @ _*)
Expand All @@ -173,72 +130,64 @@ class EUFRule(rewriter: ToTermRewriter, restrictedSetJudgement: RestrictedSetJud
case (NameEx(name), sort) => (name, sort)
case (ex, _) => throw new RewriterException(s"$ex must be a name.", ex)
}
FunDef(gen.newName(), argList, rewrite(e))
for {
rewrittenE <- rewrite(e)
funTerm <- defineAndUse(gen.newName(), argList, rewrittenE)
} yield funTerm

case OperEx(TlaFunOper.app, fn, arg) =>
val fnTerm = rewrite(fn)
val fnTermCmp = rewrite(fn)
// Arity 2+ functions pack their arguments as a single tuple, which we might need to unpack
val appArgs = arg match {
val appArgsCmp = arg match {
case OperEx(TlaFunOper.tuple, args @ _*) => args.map(rewrite)
case _ => Seq(rewrite(arg))
}

// When applying a FunDef, we inline it
fnTerm match {
case FunDef(_, args, body) =>
// sanity check
assert(args.length == appArgs.length)

val replacementMap: Map[Term, Term] = args
.zip(appArgs)
.map { case ((argName, argSort), concrete) =>
mkVariable(argName, argSort) -> concrete
}
.toMap
// For a function with a rule f(x1,...,xn) = e
// we inline f(a1,...,an) to e[x1\a1,...,xn\an]
replaceFixedLeaf(replacementMap)(body)
case _ => Apply(fnTerm, appArgs: _*)
}
for {
fnTerm <- fnTermCmp
args <- cmpSeq(appArgsCmp)
} yield Apply(fnTerm, args: _*)

case OperEx(TlaFunOper.except, fn, arg, newVal) =>
val valTerm = rewrite(newVal)
val valTermCmp = rewrite(newVal)
// Toplevel, arg is always a TLaFunOper.tuple. Within, it's either a single value, or another
// tuple, in the case of arity 2+ functions
val fnArgTerms = arg match {
val fnArgTermsCmp = arg match {
// ![a,b,...] case
case OperEx(TlaFunOper.tuple, OperEx(TlaFunOper.tuple, args @ _*)) =>
args.toList.map(rewrite)
args.map(rewrite)
// ![a] case
case OperEx(TlaFunOper.tuple, arg) =>
List(rewrite(arg))
Seq(rewrite(arg))

case invalidArg =>
throw new IllegalArgumentException(s"Invalid arg for TlaFunOper.except in EUFRule: ${invalidArg}")
}

val exceptTermFn = exceptAsNewFunDef(fnArgTerms, valTerm) _

// Assume fnArgTerms is nonempty.
// We have two scenarios: the original function is either defined, or symbolic
// If it is defined, then we have a rule and arguments, which we can just reuse
// If it is symbolic, we need to invent new variable names and apply it.
// If it is ever the case, in the future, that this is slow, we can change this code
// to use Apply in both cases.
rewrite(fn) match {
case FunDef(_, args, oldFnBody) =>
exceptTermFn(args, oldFnBody)
case fVar @ FunctionVar(_, FunctionSort(_, from @ _*)) =>
val fnArgPairs = from.toList.map { sort => (gen.newName(), sort) }
val appArgs = fnArgPairs.map { case (varName, varSort) =>
mkVariable(varName, varSort)
for {
valTerm <- valTermCmp
fnArgTerms <- cmpSeq(fnArgTermsCmp)
fnTerm <- rewrite(fn)
newFnTerm <-
// Assume fnArgTerms is nonempty.
// We need to invent new variable names and apply the function.
fnTerm match {
case fVar @ FunctionVar(_, FunctionSort(_, from @ _*)) =>
val fnArgPairs = from.toList.map { sort => (gen.newName(), sort) }
val appArgs = fnArgPairs.map { case (varName, varSort) =>
mkVariable(varName, varSort)
}
exceptAsNewFunDef(fnArgTerms, valTerm, fnArgPairs, Apply(fVar, appArgs: _*))
case _ => throw new RewriterException(s"$fn must be a function term.", fn)
}
exceptTermFn(fnArgPairs, Apply(fVar, appArgs: _*))
case _ => throw new RewriterException(s"$fn must be a function term.", fn)
}
} yield newFnTerm

case OperEx(TlaControlOper.ifThenElse, condEx, thenEx, elseEx) =>
ITE(rewrite(condEx), rewrite(thenEx), rewrite(elseEx))
for {
condTerm <- rewrite(condEx)
thenTerm <- rewrite(thenEx)
elseTerm <- rewrite(elseEx)
} yield ITE(condTerm, thenTerm, elseTerm)

case _ => throw new RewriterException(s"EUFRule not applicable to $ex", ex)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
package at.forsyte.apalache.tla.bmcmt.rules.vmt

import at.forsyte.apalache.tla.lir.TlaEx
import at.forsyte.apalache.tla.lir.formulas.Term

/**
* FormulaRule is analogous to RewritingRule, except that it produces a Term translation directly. It is side-effect
* free, instead of mutating the arena and solver context.
* FormulaRule is analogous to RewritingRule, except that it produces a Term translation directly, while possibly
* discharging declarations. It is side-effect free, instead of mutating the arena and solver context.
*
* @author
* Jure Kukovec
*/
trait FormulaRule {
def isApplicable(ex: TlaEx): Boolean

def apply(ex: TlaEx): Term
def apply(ex: TlaEx): TermBuilderT
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package at.forsyte.apalache.tla.bmcmt.rules.vmt
import at.forsyte.apalache.tla.bmcmt.RewriterException
import at.forsyte.apalache.tla.lir.formulas.Booleans.{BoolExpr, Exists, Forall}
import at.forsyte.apalache.tla.lir.formulas.Term
import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx}
import at.forsyte.apalache.tla.lir.formulas.Booleans.{Exists, Forall}
import at.forsyte.apalache.tla.lir.formulas.{Sort, Term}
import at.forsyte.apalache.tla.lir.oper.TlaBoolOper
import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx}

/**
* QuantifierRule defines translations for quantified expressions in reTLA.
Expand All @@ -22,15 +22,24 @@ class QuantifierRule(rewriter: ToTermRewriter, restrictedSetJudgement: Restricte
private def isRestrictedSet(ex: TlaEx) = restrictedSetJudgement.isRestrictedSet(ex)

// Convenience shorthand
private def rewrite: TlaEx => Term = rewriter.rewrite
private def rewrite: TlaEx => TermBuilderT = rewriter.rewrite

// Both \E and \A translate the same, up to the constructor name
private def mk(Ctor: (Seq[(String, Sort)], Term) => Term)(name: String, set: TlaEx, pred: TlaEx): TermBuilderT = {
val setSort = restrictedSetJudgement.getSort(set)
for {
_ <- storeUninterpretedSort(setSort)
predTerm <- rewrite(pred)
} yield Ctor(Seq((name, setSort)), predTerm)
}

// No magic here, all quantifiers in reTLA have fixed arity and are 1-to-1 with SMT quantifiers
override def apply(ex: TlaEx): BoolExpr =
override def apply(ex: TlaEx): TermBuilderT =
ex match {
case OperEx(TlaBoolOper.exists, NameEx(name), set, pred) if isRestrictedSet(set) =>
Exists(List((name, restrictedSetJudgement.getSort(set))), rewrite(pred))
mk(Exists)(name, set, pred)
case OperEx(TlaBoolOper.forall, NameEx(name), set, pred) if isRestrictedSet(set) =>
Forall(List((name, restrictedSetJudgement.getSort(set))), rewrite(pred))
mk(Forall)(name, set, pred)
case _ =>
throw new RewriterException(s"QuantifierRule not applicable to $ex", ex)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class RestrictedSetJudgement(constSets: Map[String, UninterpretedSort]) {
ex match {
case ValEx(s: TlaPredefSet) =>
s match {
case TlaIntSet | TlaNatSet => IntSort()
case TlaBoolSet => BoolSort()
case TlaIntSet | TlaNatSet => IntSort
case TlaBoolSet => BoolSort
case _ => throw new RewriterException(s"$s not supported in reTLA", ex)
}
case NameEx(name) if constSets.contains(name) => constSets(name)
Expand Down
Loading

0 comments on commit a487206

Please sign in to comment.