diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/BoolRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/BoolRule.scala index 4c577c06ef..0ebcb2e005 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/BoolRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/BoolRule.scala @@ -1,9 +1,8 @@ package at.forsyte.apalache.tla.bmcmt.rules.vmt import at.forsyte.apalache.tla.bmcmt.RewriterException import at.forsyte.apalache.tla.lir.formulas.Booleans._ -import at.forsyte.apalache.tla.lir.formulas.Term -import at.forsyte.apalache.tla.lir.{OperEx, TlaEx} import at.forsyte.apalache.tla.lir.oper.TlaBoolOper +import at.forsyte.apalache.tla.lir.{OperEx, TlaEx} /** * BoolRule defines translations for reTLA patterns which use operators from propositional logic. @@ -20,16 +19,24 @@ class BoolRule(rewriter: ToTermRewriter) extends FormulaRule { } // convenience shorthand - private def rewrite: TlaEx => Term = rewriter.rewrite + private def rewrite: TlaEx => TermBuilderT = rewriter.rewrite // Assume isApplicable - override def apply(ex: TlaEx): BoolExpr = + override def apply(ex: TlaEx): TermBuilderT = ex match { - case OperEx(TlaBoolOper.and, args @ _*) => And(args.map(rewrite): _*) - case OperEx(TlaBoolOper.or, args @ _*) => Or(args.map(rewrite): _*) - case OperEx(TlaBoolOper.not, arg) => Neg(rewrite(arg)) - case OperEx(TlaBoolOper.implies, lhs, rhs) => Impl(rewrite(lhs), rewrite(rhs)) - case OperEx(TlaBoolOper.equiv, lhs, rhs) => Equiv(rewrite(lhs), rewrite(rhs)) - case _ => throw new RewriterException(s"BoolRule not applicable to $ex", ex) + case OperEx(TlaBoolOper.and, args @ _*) => cmpSeq(args.map(rewrite)).map { seq => And(seq: _*) } + case OperEx(TlaBoolOper.or, args @ _*) => cmpSeq(args.map(rewrite)).map { seq => Or(seq: _*) } + case OperEx(TlaBoolOper.not, arg) => rewrite(arg).map(Neg) + case OperEx(TlaBoolOper.implies, lhs, rhs) => + for { + lhsTerm <- rewrite(lhs) + rhsTerm <- rewrite(rhs) + } yield Impl(lhsTerm, rhsTerm) + case OperEx(TlaBoolOper.equiv, lhs, rhs) => + for { + lhsTerm <- rewrite(lhs) + rhsTerm <- rewrite(rhs) + } yield Equiv(lhsTerm, rhsTerm) + case _ => throw new RewriterException(s"BoolRule not applicable to $ex", ex) } } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/EUFRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/EUFRule.scala index 31ef9b576a..2bcce9db19 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/EUFRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/EUFRule.scala @@ -38,16 +38,16 @@ class EUFRule(rewriter: ToTermRewriter, restrictedSetJudgement: RestrictedSetJud private def isRestrictedSet(ex: TlaEx) = restrictedSetJudgement.isRestrictedSet(ex) /** - * When translating g = [f EXCEPT ![x1,...,xn] = a], we need to construct a VMT function representation, which differs + * When translating g = [f EXCEPT ![x1,...,xn] = a], we need to construct a function representation, which differs * from that of `f` at exactly one point. * * Given a function f: (S1,...,Sn) => S, constructed with the rule f(y1,...,yn) = ef, arguments `x1: S1, ..., xn: Sn`, * and an expression `a`, constructs a function definition for g, with the rule: - * ``` + * {{{ * g(y1, ... yn) = if y1 = x1 /\ ... /\ yn = xn * then a * else ef - * ``` + * }}} * @param fnArgTerms * the values `x1, ..., xn` * @param newCaseTerm @@ -59,10 +59,10 @@ class EUFRule(rewriter: ToTermRewriter, restrictedSetJudgement: RestrictedSetJud * @return */ private def exceptAsNewFunDef( - fnArgTerms: List[Term], + fnArgTerms: Seq[Term], newCaseTerm: Term, - )(args: List[(String, Sort)], - baseCaseTerm: Term): FunDef = { + args: Seq[(String, Sort)], + baseCaseTerm: Term): TermBuilderT = { // sanity check assert(args.length == fnArgTerms.length) @@ -79,88 +79,45 @@ class EUFRule(rewriter: ToTermRewriter, restrictedSetJudgement: RestrictedSetJud case _ => And(matchConds: _*) } - FunDef(gen.newName(), args, ITE(ifCondition, newCaseTerm, baseCaseTerm)) + // We store the new definition, but return a Term (computation) + defineAndUse(gen.newName(), args, ITE(ifCondition, newCaseTerm, baseCaseTerm)) } // Convenience shorthand - private def rewrite: TlaEx => Term = rewriter.rewrite - - // Applies a fixed renaming scheme to a term tree (e.g. renames all instances of "a" to "b") - private def replaceFixedLeaf(replacement: Map[Term, Term])(t: Term): Term = { - val replace = replaceFixedLeaf(replacement) _ - t match { - case ITE(ifTerm, thenTerm, elseTerm) => ITE(replace(ifTerm), replace(thenTerm), replace(elseTerm)) - case Apply(term, terms @ _*) => Apply(replace(term), terms.map(replace): _*) - case And(terms @ _*) => And(terms.map(replace): _*) - case Or(terms @ _*) => Or(terms.map(replace): _*) - case Equiv(lhs, rhs) => Equiv(replace(lhs), replace(rhs)) - case Equal(lhs, rhs) => Equal(replace(lhs), replace(rhs)) - case Neg(term) => Neg(replace(term)) - case Impl(lhs, rhs) => Impl(replace(lhs), replace(rhs)) - case Forall(boundVars, term) => Forall(boundVars, replace(term)) - case Exists(boundVars, term) => Exists(boundVars, replace(term)) - case FunDef(name, args, body) => FunDef(name, args, replace(body)) - case _ => replacement.getOrElse(t, t) - } - } + private def rewrite: TlaEx => TermBuilderT = rewriter.rewrite // Creates a function application from a list of arguments private def mkApp(fnTerm: Term, args: List[(String, Sort)]): Term = Apply(fnTerm, args.map { case (name, sort) => mkVariable(name, sort) }: _*) - // Creates axiomatic function equality. Functions f and g are equal over a domain S1 \X ... \X Sn, iff - // \A (x1,...,xn) \in S1 \X ... \X Sn: f(x1,...,xn) = g(x1,...xn) - // domain ... List((x1,S1), ..., (xn,Sn)) - // lhs ... f(x1,...,xn) - // rhs ... g(x1,...,xn) - private def mkAxiomaticFunEq(domain: List[(String, Sort)], lhs: Term, rhs: Term): Term = - Forall(domain, Equal(lhs, rhs)) - // Assume isApplicable - override def apply(ex: TlaEx): Term = + override def apply(ex: TlaEx): TermBuilderT = ex match { case OperEx(TlaOper.eq | ApalacheOper.assign, lhs, rhs) => - // := is just = in VMT - val lhsTerm = rewrite(lhs) - val rhsTerm = rewrite(rhs) - - // Assume sorts are equal, so we just match the left sort. - // If not, this will get caught by Equal's `require` anyhow. - lhsTerm.sort match { - // For functions, do axiomatic equality, i.e. - // f = g <=> \A x \in DOMAIN f: f[x] = g[x] - case FunctionSort(_, from @ _*) => - (lhsTerm, rhsTerm) match { - case (FunDef(_, largs, lbody), FunDef(_, rargs, rbody)) => - // sanity check - assert(largs.length == rargs.length) - // We arbitrarily pick one set of argnames (the left), and rename the other body. - // We rename all instances of rargs in rbody to same-indexed largs - val renameMap: Map[Term, Term] = largs - .zip(rargs) - .map { case ((largName, lsort), (rargName, rsort)) => - mkVariable(rargName, rsort) -> mkVariable(largName, lsort) - } - .toMap - val renamedRBody = replaceFixedLeaf(renameMap)(rbody) - mkAxiomaticFunEq(largs, lbody, renamedRBody) - - case (fv: FunctionVar, FunDef(_, args, body)) => - mkAxiomaticFunEq(args, mkApp(fv, args), body) - case (FunDef(_, args, body), fv: FunctionVar) => - mkAxiomaticFunEq(args, mkApp(fv, args), body) - case (lvar: FunctionVar, rvar: FunctionVar) => - // we just invent formal argument names - val inventedVars = from.toList.map { s => - (gen.newName(), s) - } - mkAxiomaticFunEq(inventedVars, mkApp(lvar, inventedVars), mkApp(rvar, inventedVars)) - - case _ => Equal(lhsTerm, rhsTerm) - } - - // Otherwise, do direct equality - case _ => Equal(lhsTerm, rhsTerm) + // := is just = in SMT + for { + lhsTerm <- rewrite(lhs) + rhsTerm <- rewrite(rhs) + } yield { + require(lhsTerm.sort == rhsTerm.sort, "Equality requires terms of equal Sorts.") + // Assume sorts are equal, so we just match the left sort. + lhsTerm.sort match { + // For functions, do axiomatic equality, i.e. + // f = g <=> \A x \in DOMAIN f: f[x] = g[x] + case FunctionSort(_, from @ _*) => + // we just invent formal argument names + val inventedVars = from.toList.map { s => + (gen.newName(), s) + } + // Creates axiomatic function equality. Functions f and g are equal over a domain S1 \X ... \X Sn, iff + // \A (x1,...,xn) \in S1 \X ... \X Sn: f(x1,...,xn) = g(x1,...xn) + // domain ... Seq((x1,S1), ..., (xn,Sn)) + // lhs ... f(x1,...,xn) + // rhs ... g(x1,...,xn) + Forall(inventedVars, Equal(mkApp(lhsTerm, inventedVars), mkApp(rhsTerm, inventedVars))) + // Otherwise, do direct equality + case _ => Equal(lhsTerm, rhsTerm) + } } case OperEx(TlaFunOper.funDef, e, varsAndSets @ _*) @@ -173,72 +130,64 @@ class EUFRule(rewriter: ToTermRewriter, restrictedSetJudgement: RestrictedSetJud case (NameEx(name), sort) => (name, sort) case (ex, _) => throw new RewriterException(s"$ex must be a name.", ex) } - FunDef(gen.newName(), argList, rewrite(e)) + for { + rewrittenE <- rewrite(e) + funTerm <- defineAndUse(gen.newName(), argList, rewrittenE) + } yield funTerm case OperEx(TlaFunOper.app, fn, arg) => - val fnTerm = rewrite(fn) + val fnTermCmp = rewrite(fn) // Arity 2+ functions pack their arguments as a single tuple, which we might need to unpack - val appArgs = arg match { + val appArgsCmp = arg match { case OperEx(TlaFunOper.tuple, args @ _*) => args.map(rewrite) case _ => Seq(rewrite(arg)) } - // When applying a FunDef, we inline it - fnTerm match { - case FunDef(_, args, body) => - // sanity check - assert(args.length == appArgs.length) - - val replacementMap: Map[Term, Term] = args - .zip(appArgs) - .map { case ((argName, argSort), concrete) => - mkVariable(argName, argSort) -> concrete - } - .toMap - // For a function with a rule f(x1,...,xn) = e - // we inline f(a1,...,an) to e[x1\a1,...,xn\an] - replaceFixedLeaf(replacementMap)(body) - case _ => Apply(fnTerm, appArgs: _*) - } + for { + fnTerm <- fnTermCmp + args <- cmpSeq(appArgsCmp) + } yield Apply(fnTerm, args: _*) case OperEx(TlaFunOper.except, fn, arg, newVal) => - val valTerm = rewrite(newVal) + val valTermCmp = rewrite(newVal) // Toplevel, arg is always a TLaFunOper.tuple. Within, it's either a single value, or another // tuple, in the case of arity 2+ functions - val fnArgTerms = arg match { + val fnArgTermsCmp = arg match { // ![a,b,...] case case OperEx(TlaFunOper.tuple, OperEx(TlaFunOper.tuple, args @ _*)) => - args.toList.map(rewrite) + args.map(rewrite) // ![a] case case OperEx(TlaFunOper.tuple, arg) => - List(rewrite(arg)) + Seq(rewrite(arg)) case invalidArg => throw new IllegalArgumentException(s"Invalid arg for TlaFunOper.except in EUFRule: ${invalidArg}") } - val exceptTermFn = exceptAsNewFunDef(fnArgTerms, valTerm) _ - - // Assume fnArgTerms is nonempty. - // We have two scenarios: the original function is either defined, or symbolic - // If it is defined, then we have a rule and arguments, which we can just reuse - // If it is symbolic, we need to invent new variable names and apply it. - // If it is ever the case, in the future, that this is slow, we can change this code - // to use Apply in both cases. - rewrite(fn) match { - case FunDef(_, args, oldFnBody) => - exceptTermFn(args, oldFnBody) - case fVar @ FunctionVar(_, FunctionSort(_, from @ _*)) => - val fnArgPairs = from.toList.map { sort => (gen.newName(), sort) } - val appArgs = fnArgPairs.map { case (varName, varSort) => - mkVariable(varName, varSort) + for { + valTerm <- valTermCmp + fnArgTerms <- cmpSeq(fnArgTermsCmp) + fnTerm <- rewrite(fn) + newFnTerm <- + // Assume fnArgTerms is nonempty. + // We need to invent new variable names and apply the function. + fnTerm match { + case fVar @ FunctionVar(_, FunctionSort(_, from @ _*)) => + val fnArgPairs = from.toList.map { sort => (gen.newName(), sort) } + val appArgs = fnArgPairs.map { case (varName, varSort) => + mkVariable(varName, varSort) + } + exceptAsNewFunDef(fnArgTerms, valTerm, fnArgPairs, Apply(fVar, appArgs: _*)) + case _ => throw new RewriterException(s"$fn must be a function term.", fn) } - exceptTermFn(fnArgPairs, Apply(fVar, appArgs: _*)) - case _ => throw new RewriterException(s"$fn must be a function term.", fn) - } + } yield newFnTerm case OperEx(TlaControlOper.ifThenElse, condEx, thenEx, elseEx) => - ITE(rewrite(condEx), rewrite(thenEx), rewrite(elseEx)) + for { + condTerm <- rewrite(condEx) + thenTerm <- rewrite(thenEx) + elseTerm <- rewrite(elseEx) + } yield ITE(condTerm, thenTerm, elseTerm) case _ => throw new RewriterException(s"EUFRule not applicable to $ex", ex) } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/FormulaRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/FormulaRule.scala index c60c631469..6d9825ef15 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/FormulaRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/FormulaRule.scala @@ -1,11 +1,10 @@ package at.forsyte.apalache.tla.bmcmt.rules.vmt import at.forsyte.apalache.tla.lir.TlaEx -import at.forsyte.apalache.tla.lir.formulas.Term /** - * FormulaRule is analogous to RewritingRule, except that it produces a Term translation directly. It is side-effect - * free, instead of mutating the arena and solver context. + * FormulaRule is analogous to RewritingRule, except that it produces a Term translation directly, while possibly + * discharging declarations. It is side-effect free, instead of mutating the arena and solver context. * * @author * Jure Kukovec @@ -13,5 +12,5 @@ import at.forsyte.apalache.tla.lir.formulas.Term trait FormulaRule { def isApplicable(ex: TlaEx): Boolean - def apply(ex: TlaEx): Term + def apply(ex: TlaEx): TermBuilderT } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/QuantifierRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/QuantifierRule.scala index 45f2609545..06c6b1f9e6 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/QuantifierRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/QuantifierRule.scala @@ -1,9 +1,9 @@ package at.forsyte.apalache.tla.bmcmt.rules.vmt import at.forsyte.apalache.tla.bmcmt.RewriterException -import at.forsyte.apalache.tla.lir.formulas.Booleans.{BoolExpr, Exists, Forall} -import at.forsyte.apalache.tla.lir.formulas.Term -import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx} +import at.forsyte.apalache.tla.lir.formulas.Booleans.{Exists, Forall} +import at.forsyte.apalache.tla.lir.formulas.{Sort, Term} import at.forsyte.apalache.tla.lir.oper.TlaBoolOper +import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx} /** * QuantifierRule defines translations for quantified expressions in reTLA. @@ -22,15 +22,24 @@ class QuantifierRule(rewriter: ToTermRewriter, restrictedSetJudgement: Restricte private def isRestrictedSet(ex: TlaEx) = restrictedSetJudgement.isRestrictedSet(ex) // Convenience shorthand - private def rewrite: TlaEx => Term = rewriter.rewrite + private def rewrite: TlaEx => TermBuilderT = rewriter.rewrite + + // Both \E and \A translate the same, up to the constructor name + private def mk(Ctor: (Seq[(String, Sort)], Term) => Term)(name: String, set: TlaEx, pred: TlaEx): TermBuilderT = { + val setSort = restrictedSetJudgement.getSort(set) + for { + _ <- storeUninterpretedSort(setSort) + predTerm <- rewrite(pred) + } yield Ctor(Seq((name, setSort)), predTerm) + } // No magic here, all quantifiers in reTLA have fixed arity and are 1-to-1 with SMT quantifiers - override def apply(ex: TlaEx): BoolExpr = + override def apply(ex: TlaEx): TermBuilderT = ex match { case OperEx(TlaBoolOper.exists, NameEx(name), set, pred) if isRestrictedSet(set) => - Exists(List((name, restrictedSetJudgement.getSort(set))), rewrite(pred)) + mk(Exists)(name, set, pred) case OperEx(TlaBoolOper.forall, NameEx(name), set, pred) if isRestrictedSet(set) => - Forall(List((name, restrictedSetJudgement.getSort(set))), rewrite(pred)) + mk(Forall)(name, set, pred) case _ => throw new RewriterException(s"QuantifierRule not applicable to $ex", ex) } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/RestrictedSetJudgement.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/RestrictedSetJudgement.scala index 61a405168a..e82d9d8f3e 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/RestrictedSetJudgement.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/RestrictedSetJudgement.scala @@ -31,8 +31,8 @@ class RestrictedSetJudgement(constSets: Map[String, UninterpretedSort]) { ex match { case ValEx(s: TlaPredefSet) => s match { - case TlaIntSet | TlaNatSet => IntSort() - case TlaBoolSet => BoolSort() + case TlaIntSet | TlaNatSet => IntSort + case TlaBoolSet => BoolSort case _ => throw new RewriterException(s"$s not supported in reTLA", ex) } case NameEx(name) if constSets.contains(name) => constSets(name) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TermToVMTWriter.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TermToVMTWriter.scala index ebed06d547..0210f4585f 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TermToVMTWriter.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TermToVMTWriter.scala @@ -26,7 +26,7 @@ object TermToVMTWriter { s"($head ${argStrings.mkString(" ")})" } - private def mkQuant(quant: String, boundVars: List[(String, Sort)], p: Term): String = { + private def mkQuant(quant: String, boundVars: Seq[(String, Sort)], p: Term): String = { val pairs = boundVars.map { case (name, sort) => s"($name ${sortStringForQuant(sort)})" } @@ -36,8 +36,8 @@ object TermToVMTWriter { // In quantifiers, complex sorts aren't permitted. private def sortStringForQuant(sort: Sort): String = sort match { - case IntSort() => "Int" - case BoolSort() => "Bool" + case IntSort => "Int" + case BoolSort => "Bool" case UninterpretedSort(name) => name // We should never have function sorts or untyped in quantifiers case s => throw new IllegalArgumentException(s"Sort of quantified variable cannot be $s") @@ -62,7 +62,6 @@ object TermToVMTWriter { case False => "false" case True => "true" case UninterpretedLiteral(s, sort) => s"${s}_${sort.sortName}" - case FunDef(name, _, _) => name case And(args @ _*) => mkAndOrArgs("and", "true", args) case Or(args @ _*) => mkAndOrArgs("or", "false", args) case Neg(x) => s"(not ${tr(x)})" @@ -103,12 +102,12 @@ object TermToVMTWriter { } // Constructs an SMT function definition from FunDef - def mkFunDef(fd: FunDef): String = { - val FunDef(name, args, body) = fd + def mkFunDef(fd: DefineFun): String = { + val DefineFun(name, args, body) = fd val pairs = args.map { case (name, argSort) => s"($name ${sortStringForQuant(argSort)})" } - val toSortString = sortStringForQuant(fd.sort.to) + val toSortString = sortStringForQuant(fd.asVar.sort.to) s"(define-fun $name (${pairs.mkString(" ")}) $toSortString ${tr(body)})" } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TlaExToVMTWriter.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TlaExToVMTWriter.scala index eaf992fb7a..fe0ac48e61 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TlaExToVMTWriter.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TlaExToVMTWriter.scala @@ -1,12 +1,10 @@ package at.forsyte.apalache.tla.bmcmt.rules.vmt import at.forsyte.apalache.io.OutputManager -import at.forsyte.apalache.tla.lir.formulas.Booleans.{And, Equiv, Exists, Forall, Impl, Neg, Or} -import at.forsyte.apalache.tla.lir.formulas.EUF.{Apply, Equal, FunDef, ITE, UninterpretedLiteral, UninterpretedVar} +import at.forsyte.apalache.tla.lir.TypedPredefs.TypeTagAsTlaType1 import at.forsyte.apalache.tla.lir.formulas._ -import at.forsyte.apalache.tla.lir.{ConstT1, SetT1, StrT1, TlaConstDecl, TlaEx, TlaVarDecl, Typed} +import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.pp.UniqueNameGenerator -import at.forsyte.apalache.tla.lir.TypedPredefs.TypeTagAsTlaType1 import scalaz.unused /** @@ -20,69 +18,6 @@ import scalaz.unused * Jure Kukovec */ class TlaExToVMTWriter(gen: UniqueNameGenerator) { - - // Collector is used to aggregate all function definitions, uninterpreted literals and uninterpreted sorts - // that appear in any operator anywhere, so we can declare them in the VMT file. - private class Collector { - var fnDefs: List[FunDef] = List.empty - var uninterpLits: Set[UninterpretedLiteral] = Set.empty - var uninterpSorts: Set[UninterpretedSort] = Set.empty - - private def addFnDef(fd: FunDef): Unit = - fnDefs = fd :: fnDefs - private def addUL(ul: UninterpretedLiteral): Unit = { - uninterpLits += ul - uninterpSorts += ul.sort - } - - private def addUS(us: UninterpretedSort): Unit = - uninterpSorts += us - - def collectAll(t: Term): Unit = t match { - case fd @ FunDef(_, _, body) => - addFnDef(fd) - collectAll(body) - case ITE(i, t, e) => - collectAll(i) - collectAll(t) - collectAll(e) - case Apply(fn, args @ _*) => - collectAll(fn) - args.foreach(collectAll) - case And(args @ _*) => args.foreach(collectAll) - case Or(args @ _*) => args.foreach(collectAll) - case Equiv(lhs, rhs) => - collectAll(lhs) - collectAll(rhs) - case Equal(lhs, rhs) => - collectAll(lhs) - collectAll(rhs) - case Impl(lhs, rhs) => - collectAll(lhs) - collectAll(rhs) - case Neg(arg) => collectAll(arg) - case Forall(boundVars, body) => - boundVars.foreach { case (_, setSort) => - setSort match { - case us: UninterpretedSort => addUS(us) - case _ => () - } - } - collectAll(body) - case Exists(boundVars, body) => - boundVars.foreach { case (_, setSort) => - setSort match { - case us: UninterpretedSort => addUS(us) - case _ => () - } - } - collectAll(body) - case UninterpretedVar(_, uvSort) => addUS(uvSort) - case ul: UninterpretedLiteral => addUL(ul) - case _ => () - } - } - // Main entry point. def annotateAndWrite( varDecls: Seq[TlaVarDecl], @@ -111,26 +46,32 @@ class TlaExToVMTWriter(gen: UniqueNameGenerator) { // val cinitStrs = cinits.map(TermToVMTWriter.mkSMT2String) // convenience shorthand - def rewrite: TlaEx => Term = rewriter.rewrite + def rewrite: TlaEx => TermBuilderT = rewriter.rewrite // Each transition in initTransitions needs the VMT wrapper Init - val inits = initTransitions.map { case (name, ex) => - Init(name, rewrite(ex)) - } - - val initStrs = inits.map(TermToVMTWriter.mkVMTString) + val initCmps = cmpSeq(initTransitions.map { case (name, ex) => + rewrite(ex).map { Init(name, _) } + }) // Each transition in nextTransitions needs the VMT wrapper Trans - val transitions = nextTransitions.map { case (name, ex) => - Trans(name, rewrite(ex)) - } - - val transStrs = transitions.map(TermToVMTWriter.mkVMTString) + val transitionCmps = cmpSeq(nextTransitions.map { case (name, ex) => + rewrite(ex).map { Trans(name, _) } + }) // Each invariant in invariants needs the VMT wrapper Invar - val invs = invariants.zipWithIndex.map { case ((name, ex), i) => - Invar(name, i, rewrite(ex)) - } + val invCmps = cmpSeq(invariants.zipWithIndex.map { case ((name, ex), i) => + rewrite(ex).map { Invar(name, i, _) } + }) + + val (smtDecls, (inits, transitions, invs)) = (for { + initTerms <- initCmps + transitionTerms <- transitionCmps + invTerms <- invCmps + } yield (initTerms, transitionTerms, invTerms)).run(SmtDeclarations.init) + + val initStrs = inits.map(TermToVMTWriter.mkVMTString) + + val transStrs = transitions.map(TermToVMTWriter.mkVMTString) val invStrs = invs.map(TermToVMTWriter.mkVMTString) @@ -145,20 +86,14 @@ class TlaExToVMTWriter(gen: UniqueNameGenerator) { // Variable declarations val smtVarDecls = varDecls.map(TermToVMTWriter.mkSMTDecl) - // Now we declare constants and define functions aggregated by Collector - val collector = new Collector - inits.foreach { i => collector.collectAll(i.initExpr) } - transitions.foreach { t => collector.collectAll(t.transExpr) } - invs.foreach { i => collector.collectAll(i.invExpr) } - // Sort declaration - val allSorts = (setConstants.values ++ collector.uninterpSorts).toSet + val allSorts = setConstants.values.toSet ++ smtDecls.uninterpretedSorts.map(UninterpretedSort) val sortDecls = allSorts.map(TermToVMTWriter.mkSortDecl) // Uninterpreted literal declaration and uniqueness assert - val ulitDecls = collector.uninterpLits.map(TermToVMTWriter.mkConstDecl) + val ulitDecls = smtDecls.uninterpretedLiterals.values.map(TermToVMTWriter.mkConstDecl) val disticntAsserts = allSorts.flatMap { s => - val litsForSortS = collector.uninterpLits.filter { + val litsForSortS = smtDecls.uninterpretedLiterals.values.filter { _.sort == s } (if (litsForSortS.size > 1) Some(litsForSortS) else None).map(TermToVMTWriter.assertDistinct) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TlaType1ToSortConverter.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TlaType1ToSortConverter.scala index 2cc91cb7ec..acb0469afa 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TlaType1ToSortConverter.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TlaType1ToSortConverter.scala @@ -12,8 +12,8 @@ import at.forsyte.apalache.tla.lir.{BoolT1, ConstT1, FunT1, IntT1, StrT1, TlaTyp object TlaType1ToSortConverter { def sortFromType(tt: TlaType1): Sort = tt match { - case IntT1 => IntSort() - case BoolT1 => BoolSort() + case IntT1 => IntSort + case BoolT1 => BoolSort case StrT1 => UninterpretedSort(tt.toString) case ConstT1(name) => UninterpretedSort(name) case FunT1(TupT1(args @ _*), res) => FunctionSort(sortFromType(res), args.map(sortFromType): _*) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/ToTermRewriter.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/ToTermRewriter.scala index 142574fc5c..c6d8dbfaf6 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/ToTermRewriter.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/ToTermRewriter.scala @@ -1,7 +1,6 @@ package at.forsyte.apalache.tla.bmcmt.rules.vmt import at.forsyte.apalache.tla.lir.TlaEx -import at.forsyte.apalache.tla.lir.formulas.Term /** * ToTermRewriter defines a translation from TLA+ to SMT Terms. @@ -10,5 +9,5 @@ import at.forsyte.apalache.tla.lir.formulas.Term * Jure Kukovec */ abstract class ToTermRewriter { - def rewrite(ex: TlaEx): Term + def rewrite(ex: TlaEx): TermBuilderT } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/ToTermRewriterImpl.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/ToTermRewriterImpl.scala index 4c8d7bf947..a5f114438b 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/ToTermRewriterImpl.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/ToTermRewriterImpl.scala @@ -1,11 +1,10 @@ package at.forsyte.apalache.tla.bmcmt.rules.vmt import at.forsyte.apalache.tla.bmcmt.RewriterException import at.forsyte.apalache.tla.lir.TlaEx -import at.forsyte.apalache.tla.lir.formulas.Term import at.forsyte.apalache.tla.pp.UniqueNameGenerator /** - * The ToTermRewriter implementation for reTLA to VMT. + * The ToTermRewriter implementation from reTLA to SMT Terms. * * @author * Jure Kukovec @@ -20,7 +19,7 @@ class ToTermRewriterImpl(constSets: ConstSetMapT, gen: UniqueNameGenerator) exte new ValueRule, ) - override def rewrite(ex: TlaEx): Term = + override def rewrite(ex: TlaEx): TermBuilderT = rules.find(r => r.isApplicable(ex)) match { case Some(r) => r(ex) diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/VMTExpr.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/VMTExpr.scala index 483f2ce29e..256f92f287 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/VMTExpr.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/VMTExpr.scala @@ -15,11 +15,11 @@ sealed case class Next(name: String, current: Variable, next: Variable) extends require(current.sort == next.sort, s"Variable binding $name must bind two variables of the same sort.") } sealed case class Init(name: String, initExpr: Term) extends VMTExpr { - require(initExpr.sort == BoolSort(), s"Initial state predicate $name must have Boolean sort.") + require(initExpr.sort == BoolSort, s"Initial state predicate $name must have Boolean sort.") } sealed case class Trans(name: String, transExpr: Term) extends VMTExpr { - require(transExpr.sort == BoolSort(), s"Transition predicate $name must have Boolean sort.") + require(transExpr.sort == BoolSort, s"Transition predicate $name must have Boolean sort.") } sealed case class Invar(name: String, idx: Int, invExpr: Term) extends VMTExpr { - require(invExpr.sort == BoolSort(), s"Invariant $name must have Boolean sort.") + require(invExpr.sort == BoolSort, s"Invariant $name must have Boolean sort.") } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/ValueRule.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/ValueRule.scala index 406d8f4ea4..5f867579be 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/ValueRule.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/ValueRule.scala @@ -1,13 +1,13 @@ package at.forsyte.apalache.tla.bmcmt.rules.vmt import at.forsyte.apalache.tla.bmcmt.RewriterException +import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.lir.formulas.Booleans.{False, True} import at.forsyte.apalache.tla.lir.formulas.EUF.UninterpretedLiteral import at.forsyte.apalache.tla.lir.formulas.Integers.IntLiteral import at.forsyte.apalache.tla.lir.formulas._ import at.forsyte.apalache.tla.lir.oper.TlaActionOper import at.forsyte.apalache.tla.lir.values.{TlaBool, TlaInt, TlaStr} -import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.types.ModelValueHandler /** @@ -30,22 +30,24 @@ class ValueRule extends FormulaRule { import ValueRule._ - def apply(ex: TlaEx): Term = ex match { - case ValEx(v) => - v match { - case TlaInt(i) => IntLiteral(i) - case TlaStr(s) => - val (tlaType, id) = ModelValueHandler.typeAndIndex(s).getOrElse((StrT1, s)) - UninterpretedLiteral(id, UninterpretedSort(tlaType.toString)) - case TlaBool(b) => if (b) True else False - case _ => throwOn(ex) - } - case nameEx: NameEx => termFromNameEx(nameEx) - case OperEx(TlaActionOper.prime, nEx: NameEx) => - // Rename x' to x^ for VMT - termFromNameEx(renamePrimesForVMT(nEx)) - case _ => throwOn(ex) - + def apply(ex: TlaEx): TermBuilderT = { + val term = ex match { + case ValEx(v) => + v match { + case TlaInt(i) => IntLiteral(i) + case TlaStr(s) => + val (tlaType, id) = ModelValueHandler.typeAndIndex(s).getOrElse((StrT1, s)) + UninterpretedLiteral(id, UninterpretedSort(tlaType.toString)) + case TlaBool(b) => if (b) True else False + case _ => throwOn(ex) + } + case nameEx: NameEx => termFromNameEx(nameEx) + case OperEx(TlaActionOper.prime, nEx: NameEx) => + // Rename x' to x^ for VMT + termFromNameEx(renamePrimesForVMT(nEx)) + case _ => throwOn(ex) + } + storeUninterpretedLiteralOrVar(term).map { _ => term } } } @@ -62,7 +64,7 @@ object ValueRule { val sort = TlaType1ToSortConverter.sortFromType(tt) mkVariable(ex.name, sort) case Untyped => - mkVariable(ex.name, UntypedSort()) + mkVariable(ex.name, UntypedSort) case Typed(other) => throw new RewriterException(s"Term construction is not supported: $other is not in TlaType1", ex) } diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/package.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/package.scala index 59f9d3d232..3d7220c08f 100644 --- a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/package.scala +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/package.scala @@ -2,16 +2,73 @@ package at.forsyte.apalache.tla.bmcmt.rules import at.forsyte.apalache.tla.lir.NameEx import at.forsyte.apalache.tla.lir.formulas.Booleans.BoolVar -import at.forsyte.apalache.tla.lir.formulas.EUF.{FunctionVar, UninterpretedVar} +import at.forsyte.apalache.tla.lir.formulas.EUF.{DefineFun, FunctionVar, UninterpretedLiteral, UninterpretedVar} import at.forsyte.apalache.tla.lir.formulas.Integers.IntVar -import at.forsyte.apalache.tla.lir.formulas.{Sort, Variable} import at.forsyte.apalache.tla.lir.formulas._ +import scalaz.Scalaz._ +import scalaz._ package object vmt { type ConstSetMapT = Map[String, UninterpretedSort] + // collects all definitions/declarations that rules may discharge. In principle, this could be a single bucket + // of Declarations, but for clarity, it's nicer to split them. + // This should be future-proof, so any State modifications should always use _copy_, as to allow for the addition of + // other declaration fields later. + sealed case class SmtDeclarations( + definedFunctions: Map[String, DefineFun], + uninterpretedSorts: Set[String], + uninterpretedLiterals: Map[String, UninterpretedLiteral]) + + object SmtDeclarations { + def init: SmtDeclarations = SmtDeclarations(Map.empty, Set.empty, Map.empty) + } + + type TermBuilderTemplateT[A] = State[SmtDeclarations, A] + type TermBuilderT = TermBuilderTemplateT[Term] + + /** Turns a sequence of States into a single State wrapping list of values */ + def cmpSeq[A, S](args: Iterable[State[S, A]]): State[S, List[A]] = + // Scalaz defines .sequence only over Lists, not Seqs, but we get args (from variadic constructors) + // as Seq, so there's a bit of back-and-forth conversion happening here. + args.toList.sequence + + /** Adds a function definition to the internal state collection, and returns that function's Term representation. */ + def storeDefAndUse(funDef: DefineFun): TermBuilderT = State[SmtDeclarations, Term] { s => + (s.copy(definedFunctions = s.definedFunctions + (funDef.name -> funDef)), funDef.asVar) + } + + /** + * Creates and adds a function definition to the internal state collection, and returns that function's Term + * representation. + */ + def defineAndUse(name: String, args: Seq[(String, Sort)], body: Term): TermBuilderT = { + val funDef = DefineFun(name, args, body) + storeDefAndUse(funDef) + } + + /** Adds an uninterpreted sort declaration to the internal state collection. */ + def storeUninterpretedSort(sort: Sort): TermBuilderTemplateT[Unit] = sort match { + case UninterpretedSort(name) => + modify[SmtDeclarations] { s => s.copy(uninterpretedSorts = s.uninterpretedSorts + name) } + case _ => ().point[TermBuilderTemplateT] + } + + /** + * Adds an uninterpreted literal declaration to the internal state collection. If its Sort is not declared yet, or the + * Term is an uninterpreted variable instead, also adds the sort declaration. + */ + def storeUninterpretedLiteralOrVar(term: Term): TermBuilderTemplateT[Unit] = term match { + case l @ UninterpretedLiteral(name, sort) => + storeUninterpretedSort(sort).flatMap { _ => + modify[SmtDeclarations] { s => s.copy(uninterpretedLiterals = s.uninterpretedLiterals + (name -> l)) } + } + case UninterpretedVar(_, sort) => storeUninterpretedSort(sort) + case _ => ().point[TermBuilderTemplateT] + } + /** - * Since ' is not a legal symbol in SMTLIB, we have to choose a convention for the names of primed variables. If `x` + * Since ['] is not a legal symbol in SMTLIB, we have to choose a convention for the names of primed variables. If `x` * is a variable name, then `x^` is the name used to represent `x'` in SMTLIB. */ def VMTprimeName(s: String) = s"$s^" @@ -30,8 +87,8 @@ package object vmt { * Creates a `Variable` term, of the appropriate subtype, based on the sort. */ def mkVariable(name: String, sort: Sort): Variable = sort match { - case IntSort() => IntVar(name) - case BoolSort() => BoolVar(name) + case IntSort => IntVar(name) + case BoolSort => BoolVar(name) case fs: FunctionSort => FunctionVar(name, fs) case us: UninterpretedSort => UninterpretedVar(name, us) case s => new Variable(name) { override val sort: Sort = s } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestBoolRule.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestBoolRule.scala index bfb62d952a..c8d22a7030 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestBoolRule.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestBoolRule.scala @@ -1,30 +1,31 @@ package at.forsyte.apalache.tla.bmcmt.rules.vmt -import at.forsyte.apalache.tla.lir.TypedPredefs._ -import at.forsyte.apalache.tla.lir.{BoolT1, TlaEx} -import at.forsyte.apalache.tla.lir.convenience.tla +import at.forsyte.apalache.tla.lir.{BoolT1, IntT1, SetT1, TlaEx, TlaType1} import at.forsyte.apalache.tla.lir.formulas.Booleans._ +import at.forsyte.apalache.tla.lir.formulas.Term +import at.forsyte.apalache.tla.typecomp.TBuilderInstruction +import at.forsyte.apalache.tla.types.tla import org.junit.runner.RunWith import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestBoolRule extends AnyFunSuite { - val rewriter = ToTermRewriterImpl() + val rewriter: ToTermRewriter = ToTermRewriterImpl() - val rule = new BoolRule(rewriter) + val rule: FormulaRule = new BoolRule(rewriter) - val b = BoolT1 + val b: TlaType1 = BoolT1 - val p = tla.name("p").as(b) - val pVar = BoolVar("p") - val q = tla.name("q").as(b) - val qVar = BoolVar("q") + val p: TBuilderInstruction = tla.name("p", b) + val pVar: Term = BoolVar("p") + val q: TBuilderInstruction = tla.name("q", b) + val qVar: Term = BoolVar("q") val expected: Map[TlaEx, BoolExpr] = Map( - (tla.and(p, q).as(b)) -> And(pVar, qVar), - (tla.not(p).as(b)) -> Neg(pVar), - (tla.or(tla.impl(p, q).as(b), p).as(b)) -> Or(Impl(pVar, qVar), pVar), + tla.and(p, q).build -> And(pVar, qVar), + tla.not(p).build -> Neg(pVar), + tla.or(tla.impl(p, q), p).build -> Or(Impl(pVar, qVar), pVar), ) test("BoolRule applicability") { @@ -32,25 +33,23 @@ class TestBoolRule extends AnyFunSuite { assert(rule.isApplicable(ex)) } - import at.forsyte.apalache.tla.lir.UntypedPredefs._ - val notApp = List( tla.tuple(tla.int(1), tla.int(2)), - tla.funSet(tla.name("S"), tla.dotdot(tla.int(1), tla.int(42))), - tla.unchanged(tla.name("x")), - tla.forall(tla.name("x"), tla.name("S"), tla.name("p")), + tla.funSet(tla.name("S", SetT1(IntT1)), tla.dotdot(tla.int(1), tla.int(42))), + tla.unchanged(tla.name("x", IntT1)), + tla.forall(tla.name("x", IntT1), tla.name("S", SetT1(IntT1)), tla.name("p", BoolT1)), tla.int(2), tla.bool(true), ) notApp.foreach { ex => - assert(!rule.isApplicable(ex.untyped())) + assert(!rule.isApplicable(ex)) } } test("BoolRule correctness") { expected.foreach { case (k, expected) => - val actual = rule(k) + val actual = rule(k).run(SmtDeclarations.init)._2 assert(actual == expected) } } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestEUFRule.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestEUFRule.scala index 8068f3880a..2a27a1aa8a 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestEUFRule.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestEUFRule.scala @@ -1,12 +1,12 @@ package at.forsyte.apalache.tla.bmcmt.rules.vmt -import at.forsyte.apalache.tla.lir.TypedPredefs._ -import at.forsyte.apalache.tla.lir.{BoolT1, ConstT1, FunT1, IntT1, SetT1, TlaEx, TupT1} -import at.forsyte.apalache.tla.lir.convenience.tla +import at.forsyte.apalache.tla.lir.{BoolT1, ConstT1, FunT1, IntT1, SetT1, TlaEx, TlaType1, TupT1} import at.forsyte.apalache.tla.lir.formulas.Booleans._ -import at.forsyte.apalache.tla.lir.formulas.EUF.{Apply, Equal, FunDef, FunctionVar, ITE} +import at.forsyte.apalache.tla.lir.formulas.EUF.{Apply, DefineFun, Equal, FunctionVar, ITE} import at.forsyte.apalache.tla.lir.formulas._ import at.forsyte.apalache.tla.pp.UniqueNameGenerator +import at.forsyte.apalache.tla.typecomp.TBuilderInstruction +import at.forsyte.apalache.tla.types.tla import org.junit.runner.RunWith import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.junit.JUnitRunner @@ -14,44 +14,44 @@ import org.scalatestplus.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestEUFRule extends AnyFunSuite { - val sType = ConstT1("SSORT") - val sSort = UninterpretedSort("SSORT") + val sType: TlaType1 = ConstT1("SSORT") + val sSort: UninterpretedSort = UninterpretedSort("SSORT") - val constSets = Map("S" -> sSort) + val constSets: ConstSetMapT = Map("S" -> sSort) - val rewriter = ToTermRewriterImpl(constSets) + val rewriter: ToTermRewriter = ToTermRewriterImpl(constSets) - val funName = "f" - val constGen = new UniqueNameGenerator { + val funName: String = "f" + val constGen: UniqueNameGenerator = new UniqueNameGenerator { override def newName(): String = funName } - val fType = FunT1(TupT1(sType, IntT1), sType) - val f = tla.name(funName).as(fType) + val fType: TlaType1 = FunT1(TupT1(sType, IntT1), sType) + val f: TBuilderInstruction = tla.name(funName, fType) - val rule = new EUFRule(rewriter, new RestrictedSetJudgement(constSets), constGen) + val rule: FormulaRule = new EUFRule(rewriter, new RestrictedSetJudgement(constSets), constGen) - val b = BoolT1 + val b: TlaType1 = BoolT1 - val p = tla.name("p").as(b) - val pVar = BoolVar("p") - val q = tla.name("q").as(b) - val qVar = BoolVar("q") + val p: TBuilderInstruction = tla.name("p", b) + val pVar: Term = BoolVar("p") + val q: TBuilderInstruction = tla.name("q", b) + val qVar: Term = BoolVar("q") - val x = tla.name("x").as(sType) - val xVar = mkVariable("x", sSort) - val xPrimeVar = mkVariable(VMTprimeName("x"), sSort) - val y = tla.name("y").as(IntT1) - val set = tla.name("S").as(SetT1(sType)) - val intSet = tla.intSet().as(SetT1(IntT1)) + val x: TBuilderInstruction = tla.name("x", sType) + val xVar: Term = mkVariable("x", sSort) + val xPrimeVar: Term = mkVariable(VMTprimeName("x"), sSort) + val y: TBuilderInstruction = tla.name("y", IntT1) + val set: TBuilderInstruction = tla.name("S", SetT1(sType)) + val intSet: TBuilderInstruction = tla.intSet() val expected: Map[TlaEx, Term] = Map( - tla.assign(tla.prime(x).as(sType), x).as(b) -> Equal(xPrimeVar, xVar), - tla.eql(x, x).as(b) -> Equal(xVar, xVar), - tla.ite(p, p, q).as(b) -> ITE(pVar, pVar, qVar), - tla.funDef(x, x, set, y, intSet).as(fType) -> - FunDef(funName, List(("x", sSort), ("y", IntSort())), xVar), - tla.appFun(f, tla.tuple(x, y).as(fType.arg)).as(fType.res) -> - Apply(FunctionVar(funName, FunctionSort(sSort, sSort, IntSort())), xVar, mkVariable("y", IntSort())), + tla.assign(tla.prime(x), x).build -> Equal(xPrimeVar, xVar), + tla.eql(x, x).build -> Equal(xVar, xVar), + tla.ite(p, p, q).build -> ITE(pVar, pVar, qVar), + tla.funDef(x, x -> set, y -> intSet).build -> + DefineFun(funName, List(("x", sSort), ("y", IntSort)), xVar).asVar, + tla.app(f, tla.tuple(x, y)).build -> + Apply(FunctionVar(funName, FunctionSort(sSort, sSort, IntSort)), xVar, mkVariable("y", IntSort)), ) test("EUFRule applicability") { @@ -59,25 +59,23 @@ class TestEUFRule extends AnyFunSuite { assert(rule.isApplicable(ex)) } - import at.forsyte.apalache.tla.lir.UntypedPredefs._ - val notApp = List( tla.tuple(tla.int(1), tla.int(2)), - tla.funSet(tla.name("S"), tla.dotdot(tla.int(1), tla.int(42))), - tla.unchanged(tla.name("x")), - tla.and(tla.name("x"), tla.name("T"), tla.name("p")), + tla.funSet(tla.name("S", SetT1(IntT1)), tla.dotdot(tla.int(1), tla.int(42))), + tla.unchanged(tla.name("x", IntT1)), + tla.and(tla.name("x", BoolT1), tla.name("T", BoolT1), tla.name("p", BoolT1)), tla.int(2), tla.bool(true), ) notApp.foreach { ex => - assert(!rule.isApplicable(ex.untyped())) + assert(!rule.isApplicable(ex)) } } test("EUFRule correctness") { expected.foreach { case (k, expected) => - val actual = rule(k) + val actual = rule(k).run(SmtDeclarations.init)._2 assert(actual == expected) } } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestJudgement.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestJudgement.scala index a4ba9db66a..f7434f202d 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestJudgement.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestJudgement.scala @@ -1,14 +1,13 @@ package at.forsyte.apalache.tla.bmcmt.rules.vmt import at.forsyte.apalache.tla.bmcmt.RewriterException -import at.forsyte.apalache.tla.lir.{NameEx, TlaEx, ValEx} import at.forsyte.apalache.tla.lir.formulas._ import at.forsyte.apalache.tla.lir.values.{TlaRealSet, TlaStrSet} +import at.forsyte.apalache.tla.lir._ +import at.forsyte.apalache.tla.types.tla import org.junit.runner.RunWith import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.junit.JUnitRunner -import at.forsyte.apalache.tla.lir.UntypedPredefs._ -import at.forsyte.apalache.tla.lir.convenience.tla @RunWith(classOf[JUnitRunner]) class TestJudgement extends AnyFunSuite { @@ -21,21 +20,23 @@ class TestJudgement extends AnyFunSuite { "z" -> UninterpretedSort("ZSORT"), ) - val allowed: Seq[TlaEx] = (Seq( + val constantMapKeyExs: Seq[TlaEx] = constantMap.keys.toSeq.map { tla.name(_, ConstT1("X")) }.map { _.build } + + val allowed: Seq[TlaEx] = Seq( tla.intSet(), tla.natSet(), tla.booleanSet(), - ).map { _.untyped() }) ++ (constantMap.keys.toSeq.map { tla.name(_).untyped() }) + ).map { _.build } ++ constantMapKeyExs val disallowed: Seq[TlaEx] = Seq( - ValEx(TlaRealSet), - ValEx(TlaStrSet), + ValEx(TlaRealSet)(Typed(SetT1(RealT1))), + ValEx(TlaStrSet)(Typed(SetT1(StrT1))), tla.enumSet(tla.int(1), tla.int(2)), tla.dotdot(tla.int(0), tla.int(42)), - NameEx("potato"), + tla.name("potato", SetT1(IntT1)), ) - val judgement = new RestrictedSetJudgement(constantMap) + val judgement: RestrictedSetJudgement = new RestrictedSetJudgement(constantMap) test("Restricted set recognition") { allowed.foreach { ex => @@ -50,10 +51,10 @@ class TestJudgement extends AnyFunSuite { test("Restricted set Sort recognition") { val expected: Map[TlaEx, Sort] = Map( - tla.intSet().untyped() -> IntSort(), - tla.natSet().untyped() -> IntSort(), - tla.booleanSet().untyped() -> BoolSort(), - ) ++ (constantMap.map { case (k, v) => tla.name(k).untyped() -> v }) + tla.intSet().build -> IntSort, + tla.natSet().build -> IntSort, + tla.booleanSet().build -> BoolSort, + ) ++ (constantMap.map { case (k, v) => tla.name(k, ConstT1("X")).build -> v }) allowed.foreach { ex => assert(judgement.getSort(ex) == expected(ex)) diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestQuantRule.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestQuantRule.scala index 41321b22ab..709dbe152b 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestQuantRule.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestQuantRule.scala @@ -1,11 +1,11 @@ package at.forsyte.apalache.tla.bmcmt.rules.vmt import at.forsyte.apalache.tla.bmcmt.RewriterException -import at.forsyte.apalache.tla.lir.TypedPredefs._ -import at.forsyte.apalache.tla.lir.{BoolT1, ConstT1, IntT1, SetT1, TlaEx} -import at.forsyte.apalache.tla.lir.convenience.tla +import at.forsyte.apalache.tla.lir._ import at.forsyte.apalache.tla.lir.formulas.Booleans._ import at.forsyte.apalache.tla.lir.formulas._ +import at.forsyte.apalache.tla.typecomp.TBuilderInstruction +import at.forsyte.apalache.tla.types.tla import org.junit.runner.RunWith import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.junit.JUnitRunner @@ -13,30 +13,30 @@ import org.scalatestplus.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestQuantRule extends AnyFunSuite { - val sType = ConstT1("SSORT") - val sSort = UninterpretedSort("SSORT") + val sType: TlaType1 = ConstT1("SSORT") + val sSort: UninterpretedSort = UninterpretedSort("SSORT") - val constSets = Map("S" -> sSort) + val constSets: ConstSetMapT = Map("S" -> sSort) - val rewriter = ToTermRewriterImpl(constSets) + val rewriter: ToTermRewriter = ToTermRewriterImpl(constSets) - val rule = new QuantifierRule(rewriter, new RestrictedSetJudgement(constSets)) + val rule: FormulaRule = new QuantifierRule(rewriter, new RestrictedSetJudgement(constSets)) - val b = BoolT1 + val b: TlaType1 = BoolT1 - val p = tla.name("p").as(b) - val pVar = BoolVar("p") - val q = tla.name("q").as(b) - val qVar = BoolVar("q") + val p: TBuilderInstruction = tla.name("p", b) + val pVar: Term = BoolVar("p") + val q: TBuilderInstruction = tla.name("q", b) + val qVar: Term = BoolVar("q") - val x = tla.name("x").as(sType) - val y = tla.name("y").as(IntT1) - val set = tla.name("S").as(SetT1(sType)) - val intSet = tla.intSet().as(SetT1(IntT1)) + val x: TBuilderInstruction = tla.name("x", sType) + val y: TBuilderInstruction = tla.name("y", IntT1) + val set: TBuilderInstruction = tla.name("S", SetT1(sType)) + val intSet: TBuilderInstruction = tla.intSet() val expected: Map[TlaEx, BoolExpr] = Map( - (tla.exists(x, set, p).as(b)) -> Exists(List(("x", sSort)), pVar), - (tla.forall(y, intSet, q).as(b)) -> Forall(List(("y", IntSort())), qVar), + tla.exists(x, set, p).build -> Exists(List(("x", sSort)), pVar), + tla.forall(y, intSet, q).build -> Forall(List(("y", IntSort)), qVar), ) test("QuantRule applicability") { @@ -46,28 +46,26 @@ class TestQuantRule extends AnyFunSuite { assertThrows[RewriterException] { val tType = ConstT1("TSORT") - rule(tla.exists(tla.name("t").as(tType), tla.name("T").as(tType), p).as(b)) + rule(tla.exists(tla.name("t", tType), tla.name("T", SetT1(tType)), p)) } - import at.forsyte.apalache.tla.lir.UntypedPredefs._ - val notApp = List( tla.tuple(tla.int(1), tla.int(2)), - tla.funSet(tla.name("S"), tla.dotdot(tla.int(1), tla.int(42))), - tla.unchanged(tla.name("x")), - tla.and(tla.name("x"), tla.name("T"), tla.name("p")), + tla.funSet(tla.name("S", SetT1(IntT1)), tla.dotdot(tla.int(1), tla.int(42))), + tla.unchanged(tla.name("x", IntT1)), + tla.and(tla.name("x", BoolT1), tla.name("T", BoolT1), tla.name("p", BoolT1)), tla.int(2), tla.bool(true), ) notApp.foreach { ex => - assert(!rule.isApplicable(ex.untyped())) + assert(!rule.isApplicable(ex)) } } test("QuantRule correctness") { expected.foreach { case (k, expected) => - val actual = rule(k) + val actual = rule(k).run(SmtDeclarations.init)._2 assert(actual == expected) } } diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestTermConstruction.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestTermConstruction.scala index 93054800c6..74ba7350c0 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestTermConstruction.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestTermConstruction.scala @@ -11,19 +11,19 @@ import at.forsyte.apalache.tla.lir.formulas._ @RunWith(classOf[JUnitRunner]) class TestTermConstruction extends AnyFunSuite { - val int1 = IntLiteral(1) - val int2 = IntLiteral(2) - val intV1 = IntVar("a") - val intV2 = IntVar("b") + val int1: Term = IntLiteral(1) + val int2: Term = IntLiteral(2) + val intV1: Term = IntVar("a") + val intV2: Term = IntVar("b") - val usort1 = UninterpretedSort("A") - val usort2 = UninterpretedSort("B") + val usort1: UninterpretedSort = UninterpretedSort("A") + val usort2: UninterpretedSort = UninterpretedSort("B") - val uval1 = UninterpretedLiteral("x", usort1) - val uval2 = UninterpretedLiteral("y", usort2) + val uval1: Term = UninterpretedLiteral("x", usort1) + val uval2: Term = UninterpretedLiteral("y", usort2) - val uvar1 = UninterpretedVar("c", usort1) - val uvar2 = UninterpretedVar("d", usort2) + val uvar1: Term = UninterpretedVar("c", usort1) + val uvar2: Term = UninterpretedVar("d", usort2) test("Equal requirements") { // Does not throw: @@ -61,10 +61,10 @@ class TestTermConstruction extends AnyFunSuite { test("Apply requirements.") { - val fNullary = FunctionSort(IntSort()) - val fUnary1 = FunctionSort(IntSort(), BoolSort()) + val fNullary = FunctionSort(IntSort) + val fUnary1 = FunctionSort(IntSort, BoolSort) val fUnary2 = FunctionSort(usort1, usort2) - val fNary1 = FunctionSort(usort2, IntSort(), IntSort(), BoolSort()) + val fNary1 = FunctionSort(usort2, IntSort, IntSort, BoolSort) val fnTermNullary = FunctionVar("f", fNullary) val fnTermUnary1 = FunctionVar("f", fUnary1) diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestValueRule.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestValueRule.scala index f1563c2fa1..e718faf211 100644 --- a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestValueRule.scala +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/rules/vmt/TestValueRule.scala @@ -1,12 +1,12 @@ package at.forsyte.apalache.tla.bmcmt.rules.vmt -import at.forsyte.apalache.tla.lir.TypedPredefs._ -import at.forsyte.apalache.tla.lir.convenience.tla import at.forsyte.apalache.tla.lir.formulas.Booleans.{BoolVar, True} import at.forsyte.apalache.tla.lir.formulas.EUF.{UninterpretedLiteral, UninterpretedVar} import at.forsyte.apalache.tla.lir.formulas.Integers.{IntLiteral, IntVar} import at.forsyte.apalache.tla.lir.formulas.{Term, UninterpretedSort} -import at.forsyte.apalache.tla.lir.{BoolT1, ConstT1, IntT1, StrT1, TlaEx} +import at.forsyte.apalache.tla.lir.{BoolT1, ConstT1, IntT1, SetT1, StrT1, TlaEx, TlaType1} +import at.forsyte.apalache.tla.typecomp.TBuilderInstruction +import at.forsyte.apalache.tla.types.tla import org.junit.runner.RunWith import org.scalatest.funsuite.AnyFunSuite import org.scalatestplus.junit.JUnitRunner @@ -14,33 +14,33 @@ import org.scalatestplus.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) class TestValueRule extends AnyFunSuite { - val rule = new ValueRule + val rule: FormulaRule = new ValueRule - val b = BoolT1 - val i = IntT1 - val aType = ConstT1("A") - val bType = ConstT1("B") + val b: TlaType1 = BoolT1 + val i: TlaType1 = IntT1 + val aType: TlaType1 = ConstT1("A") + val bType: TlaType1 = ConstT1("B") - val intEx1 = tla.int(1).as(i) - val intEx2 = tla.name("x").as(i) - val intEx3 = tla.plus(intEx1, intEx2).as(i) + val intEx1: TBuilderInstruction = tla.int(1) + val intEx2: TBuilderInstruction = tla.name("x", i) + val intEx3: TBuilderInstruction = tla.plus(intEx1, intEx2) - val boolEx1 = tla.bool(true).as(b) - val boolEx2 = tla.name("p").as(b) - val boolEx3 = tla.and(boolEx1, boolEx2).as(b) + val boolEx1: TBuilderInstruction = tla.bool(true) + val boolEx2: TBuilderInstruction = tla.name("p", b) + val boolEx3: TBuilderInstruction = tla.and(boolEx1, boolEx2) - val uninterpEx1 = tla.str("1_OF_A").as(aType) - val uninterpEx2 = tla.name("v").as(bType) - val uninterpEx3 = tla.str("string").as(StrT1) + val uninterpEx1: TBuilderInstruction = tla.constParsed("1_OF_A") + val uninterpEx2: TBuilderInstruction = tla.name("v", bType) + val uninterpEx3: TBuilderInstruction = tla.str("string") val expected: Map[TlaEx, Term] = Map( - intEx1 -> IntLiteral(1), - intEx2 -> IntVar("x"), - boolEx1 -> True, - boolEx2 -> BoolVar("p"), - uninterpEx1 -> UninterpretedLiteral("1", UninterpretedSort("A")), - uninterpEx2 -> UninterpretedVar("v", UninterpretedSort("B")), - uninterpEx3 -> UninterpretedLiteral("string", UninterpretedSort(StrT1.toString)), + intEx1.build -> IntLiteral(1), + intEx2.build -> IntVar("x"), + boolEx1.build -> True, + boolEx2.build -> BoolVar("p"), + uninterpEx1.build -> UninterpretedLiteral("1", UninterpretedSort("A")), + uninterpEx2.build -> UninterpretedVar("v", UninterpretedSort("B")), + uninterpEx3.build -> UninterpretedLiteral("string", UninterpretedSort(StrT1.toString)), ) test("ValueRule applicability") { @@ -48,25 +48,23 @@ class TestValueRule extends AnyFunSuite { assert(rule.isApplicable(ex)) } - import at.forsyte.apalache.tla.lir.UntypedPredefs._ - val notApp = List( tla.and(tla.bool(true), tla.bool(false)), - tla.impl(tla.name("x"), tla.name("q")), + tla.impl(tla.name("x", BoolT1), tla.name("q", BoolT1)), tla.tuple(tla.int(1), tla.int(2)), - tla.funSet(tla.name("S"), tla.dotdot(tla.int(1), tla.int(42))), - tla.unchanged(tla.name("x")), - tla.forall(tla.name("x"), tla.name("S"), tla.name("p")), + tla.funSet(tla.name("S", SetT1(IntT1)), tla.dotdot(tla.int(1), tla.int(42))), + tla.unchanged(tla.name("x", IntT1)), + tla.forall(tla.name("x", IntT1), tla.name("S", SetT1(IntT1)), tla.name("p", BoolT1)), ) notApp.foreach { ex => - assert(!rule.isApplicable(ex.untyped())) + assert(!rule.isApplicable(ex)) } } test("ValueRune correctness") { expected.foreach { case (k, expected) => - val actual = rule(k) + val actual = rule(k).run(SmtDeclarations.init)._2 assert(actual == expected) } } diff --git a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/Base.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/Base.scala index bd856906b4..7ae5d5b395 100644 --- a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/Base.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/Base.scala @@ -1,6 +1,6 @@ package at.forsyte.apalache.tla.lir.formulas -import scalaz.unused +import scala.annotation.unused /** * A representation of an SMT/VMT sort. We only support non-parametric sorts at the moment. @@ -11,7 +11,7 @@ import scalaz.unused abstract class Sort(val sortName: String) /** - * A representation of an SMT/VMT term. Each term has a singular sort. + * A representation of an SMT term. Each term has a singular sort. * * @author * Jure Kukovec @@ -22,10 +22,20 @@ trait Term { abstract class Variable(@unused name: String) extends Term -sealed case class BoolSort() extends Sort("Boolean") -sealed case class IntSort() extends Sort("Integer") -sealed case class UntypedSort() extends Sort("Untyped") +case object BoolSort extends Sort("Boolean") +case object IntSort extends Sort("Integer") +case object UntypedSort extends Sort("Untyped") sealed case class UninterpretedSort(override val sortName: String) extends Sort(sortName) sealed case class FunctionSort(to: Sort, from: Sort*) extends Sort("Function") { def arity: Int = from.size } + +/** + * A representation of an SMT declaration. + * + * @author + * Jure Kukovec + */ +abstract class Declaration + +sealed case class DeclareConst(name: String, sort: Sort) extends Declaration diff --git a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/Booleans.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/Booleans.scala index 5e55c95661..e379fcc895 100644 --- a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/Booleans.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/Booleans.scala @@ -8,33 +8,33 @@ package at.forsyte.apalache.tla.lir.formulas */ object Booleans { trait BoolExpr extends Term { - val sort: Sort = BoolSort() + val sort: Sort = BoolSort } object False extends BoolExpr object True extends BoolExpr sealed case class And(args: Term*) extends BoolExpr { - require(args.forall(_.sort == BoolSort()), "All arguments of a conjunction must have Boolean sorts.") + require(args.forall(_.sort == BoolSort), "All arguments of a conjunction must have Boolean sorts.") } sealed case class Or(args: Term*) extends BoolExpr { - require(args.forall(_.sort == BoolSort()), "All arguments of a disjunction must have Boolean sorts.") + require(args.forall(_.sort == BoolSort), "All arguments of a disjunction must have Boolean sorts.") } sealed case class Neg(arg: Term) extends BoolExpr { - require(arg.sort == BoolSort(), "Negation is only applicable to arguments with Boolean sorts.") + require(arg.sort == BoolSort, "Negation is only applicable to arguments with Boolean sorts.") } sealed case class Impl(lhs: Term, rhs: Term) extends BoolExpr { - require(Seq(lhs, rhs).forall { _.sort == BoolSort() }, + require(Seq(lhs, rhs).forall { _.sort == BoolSort }, "Implication is only applicable to arguments with Boolean sorts.") } sealed case class Equiv(lhs: Term, rhs: Term) extends BoolExpr { - require(Seq(lhs, rhs).forall { _.sort == BoolSort() }, + require(Seq(lhs, rhs).forall { _.sort == BoolSort }, "Equivalence is only applicable to arguments with Boolean sorts.") } - sealed case class Forall(boundVars: List[(String, Sort)], arg: Term) extends BoolExpr { - require(arg.sort == BoolSort(), "Quantification condition must be Boolean.") + sealed case class Forall(boundVars: Seq[(String, Sort)], arg: Term) extends BoolExpr { + require(arg.sort == BoolSort, "Quantification condition must be Boolean.") } - sealed case class Exists(boundVars: List[(String, Sort)], arg: Term) extends BoolExpr { - require(arg.sort == BoolSort(), "Quantification condition must be Boolean.") + sealed case class Exists(boundVars: Seq[(String, Sort)], arg: Term) extends BoolExpr { + require(arg.sort == BoolSort, "Quantification condition must be Boolean.") } sealed case class BoolVar(name: String) extends Variable(name) with BoolExpr } diff --git a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/EUF.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/EUF.scala index 5e702ed187..9ae9534dca 100644 --- a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/EUF.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/EUF.scala @@ -1,7 +1,7 @@ package at.forsyte.apalache.tla.lir.formulas /** - * EUF defines constructors for terms in the fragment of (E)quality and (U)ninterpreted (f)unctions. + * EUF defines constructors for terms in the fragment of (E)quality and (U)ninterpreted (F)unctions. * * @author * Jure Kukovec @@ -22,23 +22,27 @@ object EUF { } sealed case class ITE(cond: Term, thenTerm: Term, elseTerm: Term) extends Term { // Sanity check - require(cond.sort == BoolSort(), "IF-condition must have Boolean sort.") + require(cond.sort == BoolSort, "IF-condition must have Boolean sort.") require(thenTerm.sort == elseTerm.sort, "ITE is only defined for branches of matching sorts.") val sort: Sort = thenTerm.sort } - /** - * A function term. FunDef plays a dual role, because it conceptually represents side-effects: SMT requires that each - * function is defined separately from where it is used, unlike TLA. If we want to translate a TLA syntax-tree to - * s-expressions, we either need side-effects (for introducing definitions), or as is the case with FunDef, we pack - * the definition with the term, and recover it later (see VMTWriter::Collector) - * - * In terms of s-expressions (and when translated to a string), it is equivalent to FunctionVar(name, sort). - */ - sealed case class FunDef(name: String, args: List[(String, Sort)], body: Term) extends FnExpr { - val sort: FunctionSort = FunctionSort(body.sort, args.map { _._2 }: _*) - } sealed case class FunctionVar(name: String, sort: FunctionSort) extends Variable(name) with FnExpr + + sealed case class DeclareFun(name: String, fnSort: FunctionSort) extends Declaration { + def asVar: FunctionVar = FunctionVar(name, fnSort) + } + + sealed case class DefineFun(name: String, args: Seq[(String, Sort)], body: Term) extends Declaration { + def asVar: FunctionVar = { + val sort: FunctionSort = FunctionSort(body.sort, + args.map { + _._2 + }: _*) + FunctionVar(name, sort) + } + } + sealed case class Apply(fn: Term, args: Term*) extends Term { require(hasFnSort(fn), "Apply is only defined for terms with function sorts.") private val asFnSort = fn.sort.asInstanceOf[FunctionSort] diff --git a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/Integers.scala b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/Integers.scala index 2d1f08deee..68461800d8 100644 --- a/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/Integers.scala +++ b/tlair/src/main/scala/at/forsyte/apalache/tla/lir/formulas/Integers.scala @@ -3,23 +3,43 @@ package at.forsyte.apalache.tla.lir.formulas object Integers { trait IntExpr extends Term { - val sort: Sort = IntSort() + val sort: Sort = IntSort } - // TODO: Before introducing integers, refactor static type requirement of IntExpr to require(_.sort == IntSort) - - sealed case class Plus(lhs: IntExpr, rhs: IntExpr) extends IntExpr - sealed case class Minus(lhs: IntExpr, rhs: IntExpr) extends IntExpr - sealed case class Uminus(arg: IntExpr) extends IntExpr - sealed case class Mult(lhs: IntExpr, rhs: IntExpr) extends IntExpr - sealed case class Div(lhs: IntExpr, rhs: IntExpr) extends IntExpr - sealed case class Mod(lhs: IntExpr, rhs: IntExpr) extends IntExpr - sealed case class Abs(arg: IntExpr) extends IntExpr + sealed case class Plus(lhs: Term, rhs: Term) extends IntExpr { + require(Seq(lhs, rhs).forall { _.sort == IntSort }, "Plus is only applicable to arguments with Integer sorts.") + } + sealed case class Minus(lhs: Term, rhs: Term) extends IntExpr { + require(Seq(lhs, rhs).forall { _.sort == IntSort }, "Minus is only applicable to arguments with Integer sorts.") + } + sealed case class Uminus(arg: Term) extends IntExpr { + require(arg.sort == IntSort, "Uminus is only applicable to arguments with Integer sorts.") + } + sealed case class Mult(lhs: Term, rhs: Term) extends IntExpr { + require(Seq(lhs, rhs).forall { _.sort == IntSort }, "Mult is only applicable to arguments with Integer sorts.") + } + sealed case class Div(lhs: Term, rhs: Term) extends IntExpr { + require(Seq(lhs, rhs).forall { _.sort == IntSort }, "Div is only applicable to arguments with Integer sorts.") + } + sealed case class Mod(lhs: Term, rhs: Term) extends IntExpr { + require(Seq(lhs, rhs).forall { _.sort == IntSort }, "Mod is only applicable to arguments with Integer sorts.") + } + sealed case class Abs(arg: Term) extends IntExpr { + require(arg.sort == IntSort, "Abs is only applicable to arguments with Integer sorts.") + } - sealed case class Lt(lhs: IntExpr, rhs: IntExpr) extends Booleans.BoolExpr - sealed case class Le(lhs: IntExpr, rhs: IntExpr) extends Booleans.BoolExpr - sealed case class Gt(lhs: IntExpr, rhs: IntExpr) extends Booleans.BoolExpr - sealed case class Ge(lhs: IntExpr, rhs: IntExpr) extends Booleans.BoolExpr + sealed case class Lt(lhs: Term, rhs: Term) extends Booleans.BoolExpr { + require(Seq(lhs, rhs).forall { _.sort == IntSort }, "[<] is only applicable to arguments with Integer sorts.") + } + sealed case class Le(lhs: Term, rhs: Term) extends Booleans.BoolExpr { + require(Seq(lhs, rhs).forall { _.sort == IntSort }, "[<=] is only applicable to arguments with Integer sorts.") + } + sealed case class Gt(lhs: Term, rhs: Term) extends Booleans.BoolExpr { + require(Seq(lhs, rhs).forall { _.sort == IntSort }, "[>] is only applicable to arguments with Integer sorts.") + } + sealed case class Ge(lhs: Term, rhs: Term) extends Booleans.BoolExpr { + require(Seq(lhs, rhs).forall { _.sort == IntSort }, "[>=] is only applicable to arguments with Integer sorts.") + } sealed case class IntLiteral(i: BigInt) extends IntExpr sealed case class IntVar(name: String) extends Variable(name) with IntExpr