Skip to content

Commit

Permalink
ModelValueCache rework (#2675)
Browse files Browse the repository at this point in the history
* Add cache + tests

* Apply suggestions from code review

Co-authored-by: Shon Feder <[email protected]>

* PR comments

* Pr comments

* test fix

* Update tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/caches/UninterpretedLiteralCache.scala

Co-authored-by: Thomas Pani <[email protected]>

---------

Co-authored-by: Shon Feder <[email protected]>
Co-authored-by: Thomas Pani <[email protected]>
  • Loading branch information
3 people authored Aug 10, 2023
1 parent b374361 commit d9a9e90
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches

import at.forsyte.apalache.tla.bmcmt.smt.SolverContext
import at.forsyte.apalache.tla.bmcmt.types.{CellT, CellTFrom}
import at.forsyte.apalache.tla.bmcmt.{ArenaCell, PureArena}
import at.forsyte.apalache.tla.lir.{ConstT1, StrT1, TlaType1}
import at.forsyte.apalache.tla.types.tla

/**
* A cache for uninterpreted literals, that are translated to uninterpreted SMT constants, with a unique sort per
* uninterpreted type. Since two values are equal iff they are literally the same literal, we force inequality between
* all the respective SMT constants.
*
* Note that Strings are just a special kind of uninterpreted type.
*
* @author
* Jure Kukovec
*/
class UninterpretedLiteralCache extends Cache[PureArena, (TlaType1, String), ArenaCell] {

/**
* Given a pair `(utype,idx)`, where `utype` represents an uninterpreted type name (possibly "Str") and `idx` some
* unique index within that type, returns an extension of `arena`, containing a cell, which represents "idx_OF_utype"
* (or "idx", if utype = "Str"), and said cell.
*
* Note that two values are equal (and get cached to the same cell) iff they have the same type and the same index, so
* e.g. "1_OF_A" and "1_OF_B" (passed here as ("A", "1") and ("B", "1")) get cached to different, incomparable cells,
* despite having the same index "1".
*/
protected def create(
arena: PureArena,
typeAndIndex: (TlaType1, String)): (PureArena, ArenaCell) = {
val (utype, _) = typeAndIndex
require(utype == StrT1 || utype.isInstanceOf[ConstT1], "Type must be Str, or an uninterpreted type.")
// introduce a new cell
val newArena = arena.appendCell(CellT.fromType1(utype))
(newArena, newArena.topCell)
}

/**
* The UninterpretedLiteralCache maintains that a cell cache for a value `idx` of type `tp` is distinct from all other
* values of type `tp` (defined so far).
*
* Whenever possible, try to use [[addAllConstraints]] instead of this method, for performance reasons instead:
*
* If we consider a naive implementation of `distinct(a1,..., an)` as `a1 != a2 /\ a1 != a3 /\ ... /\ a{n-1} != an`, a
* `distinct` with `n` elements is equivalent to `dn = n(n-1)/2` disequalities. Suppose we end up with a collection of
* `N` cache values (of a given type). If we called `addConstaintsForElem` after each addition, we'd end up with `d1 +
* d2 + ... + dN` disequalities, i.e. {{{\sum_{n=1}^N n(n-1)/2 = N(N^2 -1)/6}}} In contrast, `addAllConstraints`
* produces `dN = N(N-1)/2` disequalities, which is `O(N^2)`, instead of `O(N^3)`.
*/
override def addConstraintsForElem(ctx: SolverContext): (((TlaType1, String), ArenaCell)) => Unit = {
case ((utype, _), v) =>
require(utype == StrT1 || utype.isInstanceOf[ConstT1], "Type must be Str, or an uninterpreted type.")
val others = values().withFilter { c => c.cellType == CellTFrom(utype) && c != v }.map(_.toBuilder).toSeq
// The cell should differ from the previously created cells.
// We use the SMT constraint (distinct ...).
ctx.assertGroundExpr(tla.distinct(v.toBuilder +: others: _*))
}

/**
* A more efficient implementation, compared to the default one, as it introduces exactly one SMT `distinct` for each
* uninterpreted type instead of one `distinct` per cell.
*/
override def addAllConstraints(ctx: SolverContext): Unit = {
val utypes = cache.keySet.map { _._1 }

val initMap = utypes.map { _ -> Set.empty[ArenaCell] }.toMap

val cellsByUtype = cache.foldLeft(initMap) { case (map, ((utype, _), (cell, _))) =>
map + (utype -> (map(utype) + cell))
}

// For each utype, all cells of that type are distinct
cellsByUtype.foreach { case (_, cells) =>
ctx.assertGroundExpr(tla.distinct(cells.toSeq.map { _.toBuilder }: _*))
}

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux

import at.forsyte.apalache.tla.bmcmt.PureArena
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches.UninterpretedLiteralCache
import at.forsyte.apalache.tla.lir.{StrT1, TlaType1}
import at.forsyte.apalache.tla.types.{tla, ModelValueHandler}
import org.junit.runner.RunWith
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.scalatestplus.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class UninterpretedLiteralCacheTest extends AnyFunSuite with BeforeAndAfterEach {

var cache: UninterpretedLiteralCache = new UninterpretedLiteralCache

def tpAndIdx(s: String): (TlaType1, String) = {
val (utype, idx) = ModelValueHandler.typeAndIndex(s).getOrElse((StrT1, s))
(utype, idx)
}

override def beforeEach(): Unit = {
cache = new UninterpretedLiteralCache
}

test("Cache returns stored values after the first call to getOrCreate") {
val str: String = "idx"

val utypeAndIdx = tpAndIdx(str)

val arena = PureArena.empty

// No cached value for the pair
assert(cache.get(utypeAndIdx).isEmpty)

val (newArena, iCell) = cache.getOrCreate(arena, utypeAndIdx)

// pair now cached, arena has changed
assert(cache.get(utypeAndIdx).nonEmpty && newArena != arena)

val (newArena2, iCell2) = cache.getOrCreate(newArena, utypeAndIdx)

// 2nd call returns the _same_ arena and the previously computed cell
assert(newArena == newArena2 && iCell == iCell2)
}

test("Same index of different types is cached separately") {
val str1: String = "idx"
val str2: String = "idx_OF_A"
val str3: String = "idx_OF_B"

val pa1 = tpAndIdx(str1)
val pa2 = tpAndIdx(str2)
val pa3 = tpAndIdx(str3)

val arena = PureArena.empty

val (newArena1, cell1) = cache.getOrCreate(arena, pa1)

assert(arena != newArena1)

val (newArena2, cell2) = cache.getOrCreate(newArena1, pa2)

assert(newArena2 != newArena1 && cell2 != cell1)

val (newArena3, cell3) = cache.getOrCreate(newArena2, pa3)

assert(newArena3 != newArena2 && cell3 != cell2)
}

test("Constraints are only added when addAllConstraints is explicitly called, and only once per value") {
val mockCtx: MockZ3SolverContext = new MockZ3SolverContext

val str1: String = "1_OF_A"
val str2: String = "2_OF_A"
val str3: String = "3_OF_A"

val pa1 = tpAndIdx(str1)
val pa2 = tpAndIdx(str2)
val pa3 = tpAndIdx(str3)

val a0 = PureArena.empty
val (a1, c1) = cache.getOrCreate(a0, pa1)
// Some extra calls, which shouldn't affect constraint generation
cache.getOrCreate(a0, pa1)
cache.getOrCreate(a0, pa1)
val (a2, c2) = cache.getOrCreate(a1, pa2)
// Some extra calls, which shouldn't affect constraint generation
cache.getOrCreate(a1, pa2)
cache.getOrCreate(a1, pa2)
val (_, c3) = cache.getOrCreate(a2, pa3)
// Some extra calls, which shouldn't affect constraint generation
cache.getOrCreate(a2, pa3)
cache.getOrCreate(a2, pa3)

assert(mockCtx.constraints.isEmpty)

cache.addAllConstraints(mockCtx)

// Due to the optimized `addAllConstraints` override, we only have 1 "distinct"
assert(mockCtx.constraints == Seq(
tla.distinct(c3.toBuilder, c2.toBuilder, c1.toBuilder).build
))
}

test("Constraints are only added when addConstraintsForElem is explicitly called, and only once per value") {
val mockCtx: MockZ3SolverContext = new MockZ3SolverContext

val str1: String = "1_OF_A"
val str2: String = "2_OF_A"
val str3: String = "3_OF_A"

val pa1 = tpAndIdx(str1)
val pa2 = tpAndIdx(str2)
val pa3 = tpAndIdx(str3)

val a0 = PureArena.empty
val (a1, c1) = cache.getOrCreate(a0, pa1)
// Some extra calls, which shouldn't affect constraint generation
cache.getOrCreate(a0, pa1)
cache.getOrCreate(a0, pa1)

cache.addConstraintsForElem(mockCtx)(pa1, c1)

val (a2, c2) = cache.getOrCreate(a1, pa2)
// Some extra calls, which shouldn't affect constraint generation
cache.getOrCreate(a1, pa2)
cache.getOrCreate(a1, pa2)

cache.addConstraintsForElem(mockCtx)(pa2, c2)

val (_, c3) = cache.getOrCreate(a2, pa3)
// Some extra calls, which shouldn't affect constraint generation
cache.getOrCreate(a2, pa3)
cache.getOrCreate(a2, pa3)

cache.addConstraintsForElem(mockCtx)(pa3, c3)

// -ForElem creates 3 "distinct" constraints
assert(mockCtx.constraints == Seq(
tla.distinct(c1.toBuilder).build,
tla.distinct(c2.toBuilder, c1.toBuilder).build,
tla.distinct(c3.toBuilder, c1.toBuilder, c2.toBuilder).build,
))
}

}

0 comments on commit d9a9e90

Please sign in to comment.