diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/caches/RecordDomainCache.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/caches/RecordDomainCache.scala new file mode 100644 index 0000000000..23bef5c036 --- /dev/null +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/caches/RecordDomainCache.scala @@ -0,0 +1,56 @@ +package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches + +import at.forsyte.apalache.tla.bmcmt.smt.SolverContext +import at.forsyte.apalache.tla.bmcmt.types.CellT +import at.forsyte.apalache.tla.bmcmt.{ArenaCell, FixedElemPtr, PureArena} +import at.forsyte.apalache.tla.lir.{SetT1, StrT1} +import at.forsyte.apalache.tla.types.tla + +import scala.collection.immutable.SortedSet + +/** + * Since we have to create record domains many times, we cache them here. + * + * @author + * Jure Kukovec + */ +class RecordDomainCache(strValueCache: UninterpretedLiteralCache) + extends Cache[PureArena, SortedSet[String], (ArenaCell, Seq[ArenaCell])] { + + /** + * Given a set of `keys`, returns a tuple `(rArena, (setCell, allCells))`, where: + * - `setCell` is the cell representing the set `keys` + * - `allCells` is s sequence of cells `c_1, c_2,..., c_|keys|`, representing the set contents (e.g. "a", "b", ...). + * - `rArena` is an extension of `arena`, containing all of the above cells, and a relation + * {{{setCell --(has)--> c_1, c_2,..., c_|keys|}}} + * + * Note that this method internally calls `strValueCache.getOrCreate`, which caches all strings in the set. + */ + override protected def create( + arena: PureArena, + keys: SortedSet[String]): (PureArena, (ArenaCell, Seq[ArenaCell])) = { + + val (arenaWithCachedStrs, allCells) = keys.toList.foldLeft((arena, Seq.empty[ArenaCell])) { + case ((partialArena, partialCells), key) => + val (newArena, cell) = strValueCache.getOrCreate(partialArena, (StrT1, key)) + (newArena, partialCells :+ cell) + } + + // create the domain cell + val arenaWithDomCell = arenaWithCachedStrs.appendCell(CellT.fromType1(SetT1(StrT1))) + val setCell = arenaWithDomCell.topCell + val arenaWithHas = arenaWithDomCell.appendHas(setCell, allCells.map(FixedElemPtr): _*) + + (arenaWithHas, (setCell, allCells)) + } + + /** Return a function to add implementation-specific constraints for a single entry */ + override def addConstraintsForElem(ctx: SolverContext): ((SortedSet[String], (ArenaCell, Seq[ArenaCell]))) => Unit = { + case (_, (setCell, elemCells)) => + elemCells.foreach { elemCell => + // We _know_ the pointer is fixed TRUE by construction, so instead of asserting X == true, we just assert X, where + // X = elemCell \in setCell + ctx.assertGroundExpr(tla.in(elemCell.toBuilder, setCell.toBuilder)) + } + } +} diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/RecordDomainCacheTest.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/RecordDomainCacheTest.scala new file mode 100644 index 0000000000..004243cf2a --- /dev/null +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/RecordDomainCacheTest.scala @@ -0,0 +1,97 @@ +package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux + +import at.forsyte.apalache.tla.bmcmt.PureArena +import at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches.{RecordDomainCache, UninterpretedLiteralCache} +import at.forsyte.apalache.tla.lir.{StrT1, TlaType1} +import at.forsyte.apalache.tla.types.{tla, ModelValueHandler} +import org.junit.runner.RunWith +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite +import org.scalatestplus.junit.JUnitRunner + +import scala.collection.immutable.SortedSet + +@RunWith(classOf[JUnitRunner]) +class RecordDomainCacheTest extends AnyFunSuite with BeforeAndAfterEach { + + var ulitCache: UninterpretedLiteralCache = new UninterpretedLiteralCache + var cache: RecordDomainCache = new RecordDomainCache(ulitCache) + + def tpAndIdx(s: String): (TlaType1, String) = { + val (utype, idx) = ModelValueHandler.typeAndIndex(s).getOrElse((StrT1, s)) + (utype, idx) + } + + override def beforeEach(): Unit = { + ulitCache = new UninterpretedLiteralCache + cache = new RecordDomainCache(ulitCache) + } + + test("Cache returns stored values after the first call to getOrCreate") { + val keys: SortedSet[String] = SortedSet("a", "b", "c") + + val arena = PureArena.empty + + // No cached value for the pair + assert(cache.get(keys).isEmpty) + + val (newArena, (cell, elemCells)) = cache.getOrCreate(arena, keys) + + // set is now cached, arena has changed + assert(cache.get(keys).nonEmpty && newArena != arena) + + val (newArena2, (cell2, elemCells2)) = cache.getOrCreate(newArena, keys) + + // 2nd call returns the _same_ arena and the previously computed cells + assert(newArena == newArena2 && cell == cell2 && elemCells == elemCells2) + } + + test("Constraints are only added when addConstraintsForElem is explicitly called, and only once per value") { + val mockCtx: MockZ3SolverContext = new MockZ3SolverContext + + val k1: SortedSet[String] = SortedSet.empty[String] + val k2: SortedSet[String] = SortedSet("a", "b") + val k3: SortedSet[String] = SortedSet("a", "c") + + val a0 = PureArena.empty + val (a1, (cell1, elemCells1)) = cache.getOrCreate(a0, k1) + // Some extra calls, which shouldn't affect constraint generation + cache.getOrCreate(a0, k1) + cache.getOrCreate(a0, k1) + val (a2, (cell2, elemCells2)) = cache.getOrCreate(a1, k2) + // Some extra calls, which shouldn't affect constraint generation + cache.getOrCreate(a1, k2) + cache.getOrCreate(a1, k2) + val (_, (cell3, elemCells3)) = cache.getOrCreate(a2, k3) + // Some extra calls, which shouldn't affect constraint generation + cache.getOrCreate(a2, k3) + cache.getOrCreate(a2, k3) + + assert(mockCtx.constraints.isEmpty) + + // "a" is a member of 2 sets, but it only gets cached into 1 cell + assert(ulitCache.values().size == 3) + + cache.addAllConstraints(mockCtx) + + // Dependent caches don't discharge constraints unless called, so we should have 0 constraints + // from UninterpretedLiteralCache, only the ones from the domain cache + assert(mockCtx.constraints.size == k1.size + k2.size + k3.size) + assert( + elemCells1.forall { c => + mockCtx.constraints.contains(tla.in(c.toBuilder, cell1.toBuilder).build) + } + ) + assert( + elemCells2.forall { c => + mockCtx.constraints.contains(tla.in(c.toBuilder, cell2.toBuilder).build) + } + ) + assert( + elemCells3.forall { c => + mockCtx.constraints.contains(tla.in(c.toBuilder, cell3.toBuilder).build) + } + ) + } + +}