Skip to content

Commit

Permalink
Parse ASSUME declarations names
Browse files Browse the repository at this point in the history
  • Loading branch information
fan-tom committed Jan 20, 2024
1 parent 10dcc1e commit 6e469bd
Show file tree
Hide file tree
Showing 23 changed files with 113 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,10 @@ class JsonToTla[T <: JsonRepresentation](
opDecl

case "TlaAssumeDecl" =>
val definedName = declJson.getFieldOpt("name").map(scalaFactory.asStr)
val bodyField = getOrThrow(declJson, "body")
val body = asTlaEx(bodyField)
TlaAssumeDecl(body)(typeTag)
TlaAssumeDecl(definedName, body)(typeTag)
case _ => throw new JsonDeserializationError(s"$kind is not a valid TlaDecl kind")
}
setLoc(decl, getSourceLocationOpt(declJson))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,10 @@ class JsonToTlaViaBuilder[T <: JsonRepresentation](
opDecl

case "TlaAssumeDecl" =>
val definedName = declJson.getFieldOpt("name").map(scalaFactory.asStr)
val bodyField = getOrThrow(declJson, "body")
val body = asTBuilderInstruction(bodyField)
TlaAssumeDecl(body)(typeTag)
TlaAssumeDecl(definedName, body)(typeTag)
case _ => throw new JsonDeserializationError(s"$kind is not a valid TlaDecl kind")
}
setLoc(decl, getSourceLocationOpt(declJson))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,20 @@ class TlaToJson[T <: JsonRepresentation](
"body" -> bodyJson,
)

case TlaAssumeDecl(body) =>
case TlaAssumeDecl(definedName, body) =>
val bodyJson = apply(body)
withLoc(
val fields = Array[(String, T)](
typeFieldName -> typeTagPrinter(decl.typeTag),
kindFieldName -> "TlaAssumeDecl",
"body" -> bodyJson,
)

val f = definedName match {
case None => fields
case Some(name) => fields :+ ("name" -> name: (String, T))
}

withLoc(f.toIndexedSeq: _*)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,12 @@ class PrettyWriter(
"VARIABLE" <> nest(line <> wrapWithComment(annotations.get) <> line <> parseableName(name))
}

case TlaAssumeDecl(body) =>
val doc = group("ASSUME" <> parens(exToDoc((0, 0), body, nameResolver)))
case TlaAssumeDecl(definedName, body) =>
val doc = definedName match {
case None => group("ASSUME" <> parens(exToDoc((0, 0), body, nameResolver)))
case Some(name) => group("ASSUME" <+> name <+> "==" <+> parens(exToDoc((0, 0), body, nameResolver)))
}

if (annotations.isEmpty) {
doc
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -737,10 +737,11 @@ class Quint(quintOutput: QuintOutput) {
// no methods for them are provided by the ScopedBuilder.
case QuintConst(id, name, _) => Some(None, TlaConstDecl(name)(typeTagOfId(id)))
case QuintVar(id, name, _) => Some(None, TlaVarDecl(name)(typeTagOfId(id)))
case QuintAssume(_, _, quintEx) =>
case d @ QuintAssume(_, name, quintEx) =>
val tlaEx = build(tlaExpression(quintEx).run(nullaryOps))
val definedName = Option.unless(d.isUnnamed)(name)
// assume declarations have no entry in the type map, and are always typed bool
Some(None, TlaAssumeDecl(tlaEx)(Typed(BoolT1)))
Some(None, TlaAssumeDecl(definedName, tlaEx)(Typed(BoolT1)))
case op: QuintOpDef if op.qualifier == "run" =>
// We don't currently support run definitions
None
Expand Down
11 changes: 10 additions & 1 deletion tla-io/src/main/scala/at/forsyte/apalache/io/quint/QuintIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,16 @@ private[quint] object QuintDef {
name: String,
/** an expression to associate with the name */
assumption: QuintEx)
extends QuintDef {}
extends QuintDef {

/**
* @return
* true if this ASSUME clause has no user-defined name, false otherwise
*
* unnamed ASSUME clauses use `_` as a name
*/
def isUnnamed: Boolean = name == "_"
}
object QuintAssume {
implicit val rw: RW[QuintAssume] = macroRW
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class AssumeTranslator(
context,
OutsideRecursion(),
).translate(node.getAssume)
TlaAssumeDecl(body)(Untyped)
TlaAssumeDecl(Option(node.getDef).map(_.getName.toString), body)(Untyped)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,32 @@ class TestPrettyWriterWithTypes extends AnyFunSuite with BeforeAndAfterEach {
|""".stripMargin
assert(expected == stringWriter.toString)
}

test("unnamed assume declaration") {
val decl = TlaAssumeDecl(None, tla.eql(tla.name("x"), tla.bool(true)))
val store = createAnnotationStore()

val writer = new PrettyWriterWithAnnotations(store, printWriter, layout80)
writer.write(decl)
printWriter.flush()
val expected =
"""ASSUME(x = TRUE)
|
|""".stripMargin
assert(expected == stringWriter.toString)
}

test("named assume declaration") {
val decl = TlaAssumeDecl(Some("myAssume"), tla.eql(tla.name("x"), tla.bool(true)))
val store = createAnnotationStore()

val writer = new PrettyWriterWithAnnotations(store, printWriter, layout80)
writer.write(decl)
printWriter.flush()
val expected =
"""ASSUME myAssume == (x = TRUE)
|
|""".stripMargin
assert(expected == stringWriter.toString)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class TestUJsonToTla extends AnyFunSuite with Checkers {

val decls: Seq[TlaDecl] = Seq(
tla.declOp("X", tla.eql(tla.name("a"), tla.int(1)), OperParam("a")),
TlaAssumeDecl(tla.eql(tla.int(1), tla.int(0))),
TlaAssumeDecl(None, tla.eql(tla.int(1), tla.int(0))),
TlaConstDecl("c"),
TlaVarDecl("v"),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class TestUJsonToTlaViaBuilder extends AnyFunSuite with Checkers {

val decls: Seq[TlaDecl] = Seq(
tla.decl("X", tla.eql(tla.name("a", IntT1), tla.int(1)), tla.param("a", IntT1)),
TlaAssumeDecl(tla.eql(tla.int(1), tla.int(0))),
TlaAssumeDecl(None, tla.eql(tla.int(1), tla.int(0))),
TlaConstDecl("c"),
TlaVarDecl("v"),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1808,12 +1808,14 @@ class TestSanyImporter extends SanyImporterTestBase {

test("assumptions") {
// checking that the assumptions are imported properly
val assumptionName = "nonZero"
val text =
"""
s"""
|---- MODULE assumptions ----
|CONSTANT N
|ASSUME N = 4
|ASSUME N /= 10
|ASSUME $assumptionName == N /= 0
|================================
|""".stripMargin

Expand All @@ -1825,20 +1827,28 @@ class TestSanyImporter extends SanyImporterTestBase {
expectSourceInfoInDefs(root)

modules(rootName).declarations(1) match {
case TlaAssumeDecl(e) => assert(eql(name("N"), int(4)).untyped() == e)
case TlaAssumeDecl(_, e) => assert(eql(name("N"), int(4)).untyped() == e)

case e @ _ => fail("expected an assumption, found: " + e)
}

modules(rootName).declarations(2) match {
case TlaAssumeDecl(e) => assert(neql(name("N"), int(10)).untyped() == e)
case TlaAssumeDecl(_, e) => assert(neql(name("N"), int(10)).untyped() == e)

case e @ _ => fail("expected an assumption, found: " + e)
}

modules(rootName).declarations(3) match {
case TlaAssumeDecl(definedName, e) =>
assert(neql(name("N"), int(0)).untyped() == e)
assert(definedName contains assumptionName)

case e @ _ => fail("expected an assumption, found: " + e)
}

// regression test for issue #25
val names = HashSet(modules(rootName).assumeDeclarations.map(_.name): _*)
assert(2 == names.size) // all assumptions must have unique names
assert(3 == names.size) // all assumptions must have unique names
}

test("ignore theorems") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ class ToEtcExpr(
val operType = OperT1(Seq(BoolT1), BoolT1)
val application = mkUniqApp(Seq(operType), this(d.body))
// We have to introduce a lambda abstraction, as the type checker is expecting this form.
mkLet(BlameRef(d.ID), "__Assume_" + d.ID, mkAbs(ExactRef(d.ID), application), inScopeEx)
mkLet(
BlameRef(d.ID),
"__Assume_" + d.definedName.getOrElse(d.ID.toString),
mkAbs(ExactRef(d.ID), application),
inScopeEx,
)

case d: TlaOperDecl =>
// Foo(x) == ...
Expand Down Expand Up @@ -131,7 +136,7 @@ class ToEtcExpr(
OperT1(nBools, BoolT1)
}

// Valid when the input seq has two items, the first of which is a VlaEx(TlaStr(_))
// Valid when the input seq has two items, the first of which is a ValEx(TlaStr(_))
private val validateRecordPair: Seq[TlaEx] => (String, TlaEx) = {
// Only pairs coordinating pairs and sets are valid. See TlaSetOper.recSet
case Seq(ValEx(TlaStr(name)), set) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ class TypeRewriter(tracker: TransformationTracker, defaultTag: UID => TypeTag)(t
case d @ TlaVarDecl(_) =>
decl.withTag(getOrDefault(d.ID))

case d @ TlaAssumeDecl(body) =>
TlaAssumeDecl(this(body))(getOrDefault(d.ID))
case d @ TlaAssumeDecl(definedName, body) =>
TlaAssumeDecl(definedName, this(body))(getOrDefault(d.ID))

case d @ TlaTheoremDecl(name, body) =>
TlaTheoremDecl(name, this(body))(getOrDefault(d.ID))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class TestToEtcExprDecls extends AnyFunSuite with ToEtcExprBase with BeforeAndAf
}

test("assumes") {
val assume = TlaAssumeDecl(tla.name("x"))
val assume = TlaAssumeDecl(None, tla.name("x"))
val terminal = mkUniqConst(BoolT1)
// becomes:
// let Assume1 == ((Bool => Bool) "x") in
Expand All @@ -116,7 +116,8 @@ class TestToEtcExprDecls extends AnyFunSuite with ToEtcExprBase with BeforeAndAf
// The body is wrapped with the application of an operator that has the signature Bool => Bool.
// This allows us to check that the assumption has Boolean type.
val application = mkUniqApp(Seq(parser("Bool => Bool")), assumption)
val expected = mkUniqLet("__Assume_" + assume.ID, mkUniqAbs(application), terminal)
val expected =
mkUniqLet("__Assume_" + assume.definedName.getOrElse(assume.ID.toString), mkUniqAbs(application), terminal)
// Translate the declaration of positive.
// We have to pass the next expression in scope, which is just TRUE in this case.
assert(expected == mkToEtcExpr(Map())(assume, terminal))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,16 @@ case class TlaVarDecl(name: String)(implicit typeTag: TypeTag) extends TlaDecl w
/**
* An assumption defined by ASSUME(...)
*
* @param definedName
* optional assumption name, like name in `ASSUME name == x = 4`, or none, like in `ASSUME x = 4`
* @param body
* the assumption body
*/
case class TlaAssumeDecl(body: TlaEx)(implicit typeTag: TypeTag) extends TlaDecl with Serializable {
val name: String = "ASSUME" + body.ID
case class TlaAssumeDecl(definedName: Option[String], body: TlaEx)(implicit typeTag: TypeTag)
extends TlaDecl with Serializable {
override val name: String = definedName.getOrElse("ASSUME" + body.ID)

override def withTag(newTypeTag: TypeTag): TlaAssumeDecl = TlaAssumeDecl(body)(newTypeTag)
override def withTag(newTypeTag: TypeTag): TlaAssumeDecl = TlaAssumeDecl(definedName, body)(newTypeTag)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class TlaLevelFinder(module: TlaModule) {
case TlaVarDecl(_) =>
TlaLevelState

case TlaAssumeDecl(_) =>
case TlaAssumeDecl(_, _) =>
TlaLevelConst

case TlaOperDecl(name, _, body) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,11 @@ object UTFPrinter extends Printer {
case TlaVarDecl(name) =>
"VARIABLE " + name

case TlaAssumeDecl(body) =>
apply(body)
case TlaAssumeDecl(definedName, body) =>
definedName match {
case None => s"ASSUME " + apply(body)
case Some(name) => s"ASSUME $name == " + apply(body)
}

case TlaOperDecl(name, formalParams, body) =>
val ps =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class DeclarationSorter extends TlaModuleTransformation with LazyLogging {
case (map, d @ TlaVarDecl(_)) =>
map + (d.ID -> Set.empty[UID])

case (map, d @ TlaAssumeDecl(body)) =>
case (map, d @ TlaAssumeDecl(_, body)) =>
val uses = findExprUses(nameToId)(body) - d.ID
updateDependencies(map, d.ID, uses)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ class DeepCopy(tracker: TransformationTracker) {
def deepCopyDecl[T <: TlaDecl](decl: T): T = deepCopyDeclInternal(decl).asInstanceOf[T]

private def deepCopyDeclInternal: TlaDeclTransformation = tracker.trackDecl {
case d @ TlaAssumeDecl(bodyEx) => TlaAssumeDecl(deepCopyEx(bodyEx))(d.typeTag)
case d @ TlaTheoremDecl(name, body) => TlaTheoremDecl(name, deepCopyEx(body))(d.typeTag)
case d @ TlaVarDecl(name) => TlaVarDecl(name)(d.typeTag)
case d @ TlaAssumeDecl(name, bodyEx) => TlaAssumeDecl(name, deepCopyEx(bodyEx))(d.typeTag)
case d @ TlaTheoremDecl(name, body) => TlaTheoremDecl(name, deepCopyEx(body))(d.typeTag)
case d @ TlaVarDecl(name) => TlaVarDecl(name)(d.typeTag)
case d @ TlaOperDecl(name, formalParams, body) =>
val decl = TlaOperDecl(name, formalParams, deepCopyEx(body))(d.typeTag)
decl.isRecursive = d.isRecursive
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ class ModuleByExTransformer(
d.copy(body = newBody)
}

case d @ TlaAssumeDecl(body) =>
case d @ TlaAssumeDecl(definedName, body) =>
val newBody = exTrans(body)
if (newBody.ID == body.ID) {
d
} else {
TlaAssumeDecl(newBody)(d.typeTag)
TlaAssumeDecl(definedName, newBody)(d.typeTag)
}

case d => d
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ trait IrGenerators extends TlaType1Gen {
for {
ex <- exGen
tt <- genTypeTag
} yield TlaAssumeDecl(ex).withTag(tt)
} yield TlaAssumeDecl(None, ex).withTag(tt)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ class TestDeclarationSorter extends AnyFunSuite with BeforeAndAfterEach {

test("Assume uses Foo out of order") {
val foo = tla.declOp("Foo", tla.int(1))
val assume = TlaAssumeDecl(tla.appOp(tla.name("Foo")))
val input = new TlaModule("test", List(assume, foo))
val expected = new TlaModule("test", List(foo, assume))
val assume = TlaAssumeDecl(None, tla.appOp(tla.name("Foo")))
val input = TlaModule("test", List(assume, foo))
val expected = TlaModule("test", List(foo, assume))
assert(expected == DeclarationSorter.instance(input))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class TestDeepCopy extends AnyFunSuite with BeforeAndAfter with Checkers {
case (l @ TlaVarDecl(lname), r @ TlaVarDecl(rname)) =>
l.ID != r.ID && lname == rname && l.typeTag == r.typeTag

case (l @ TlaAssumeDecl(lbody), r @ TlaAssumeDecl(rbody)) =>
l.ID != r.ID && equalExCopies(lbody, rbody) && l.typeTag == r.typeTag
case (l @ TlaAssumeDecl(lname, lbody), r @ TlaAssumeDecl(rname, rbody)) =>
l.ID != r.ID && lname == rname && equalExCopies(lbody, rbody) && l.typeTag == r.typeTag

case (l @ TlaTheoremDecl(lname, lbody), r @ TlaTheoremDecl(rname, rbody)) =>
l.ID != r.ID &&
Expand Down

0 comments on commit 6e469bd

Please sign in to comment.