diff --git a/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/caches/UninterpretedLiteralCache.scala b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/caches/UninterpretedLiteralCache.scala new file mode 100644 index 0000000000..2d1a7b9dac --- /dev/null +++ b/tla-bmcmt/src/main/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/caches/UninterpretedLiteralCache.scala @@ -0,0 +1,80 @@ +package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches + +import at.forsyte.apalache.tla.bmcmt.smt.SolverContext +import at.forsyte.apalache.tla.bmcmt.types.{CellT, CellTFrom} +import at.forsyte.apalache.tla.bmcmt.{ArenaCell, PureArena} +import at.forsyte.apalache.tla.lir.{ConstT1, StrT1, TlaType1} +import at.forsyte.apalache.tla.types.tla + +/** + * A cache for uninterpreted literals, that are translated to uninterpreted SMT constants, with a unique sort per + * uninterpreted type. Since two values are equal iff they are literally the same literal, we force inequality between + * all the respective SMT constants. + * + * Note that Strings are just a special kind of uninterpreted type. + * + * @author + * Jure Kukovec + */ +class UninterpretedLiteralCache extends Cache[PureArena, (TlaType1, String), ArenaCell] { + + /** + * Given a pair `(utype,idx)`, where `utype` represents an uninterpreted type name (possibly "Str") and `idx` some + * unique index within that type, returns an extension of `arena`, containing a cell, which represents "idx_OF_utype" + * (or "idx", if utype = "Str"), and said cell. + * + * Note that two values are equal (and get cached to the same cell) iff they have the same type and the same index, so + * e.g. "1_OF_A" and "1_OF_B" (passed here as ("A", "1") and ("B", "1")) get cached to different, incomparable cells, + * despite having the same index "1". + */ + protected def create( + arena: PureArena, + typeAndIndex: (TlaType1, String)): (PureArena, ArenaCell) = { + val (utype, _) = typeAndIndex + require(utype == StrT1 || utype.isInstanceOf[ConstT1], "Type must be Str, or an uninterpreted type.") + // introduce a new cell + val newArena = arena.appendCell(CellT.fromType1(utype)) + (newArena, newArena.topCell) + } + + /** + * The UninterpretedLiteralCache maintains that a cell cache for a value `idx` of type `tp` is distinct from all other + * values of type `tp` (defined so far). + * + * Whenever possible, try to use [[addAllConstraints]] instead of this method, for performance reasons instead: + * + * If we consider a naive implementation of `distinct(a1,..., an)` as `a1 != a2 /\ a1 != a3 /\ ... /\ a{n-1} != an`, a + * `distinct` with `n` elements is equivalent to `dn = n(n-1)/2` disequalities. Suppose we end up with a collection of + * `N` cache values (of a given type). If we called `addConstaintsForElem` after each addition, we'd end up with `d1 + + * d2 + ... + dN` disequalities, i.e. {{{\sum_{n=1}^N n(n-1)/2 = N(N^2 -1)/6}}} In contrast, `addAllConstraints` + * produces `dN = N(N-1)/2` disequalities, which is `O(N^2)`, instead of `O(N^3)`. + */ + override def addConstraintsForElem(ctx: SolverContext): (((TlaType1, String), ArenaCell)) => Unit = { + case ((utype, _), v) => + require(utype == StrT1 || utype.isInstanceOf[ConstT1], "Type must be Str, or an uninterpreted type.") + val others = values().withFilter { c => c.cellType == CellTFrom(utype) && c != v }.map(_.toBuilder).toSeq + // The cell should differ from the previously created cells. + // We use the SMT constraint (distinct ...). + ctx.assertGroundExpr(tla.distinct(v.toBuilder +: others: _*)) + } + + /** + * A more efficient implementation, compared to the default one, as it introduces exactly one SMT `distinct` for each + * uninterpreted type instead of one `distinct` per cell. + */ + override def addAllConstraints(ctx: SolverContext): Unit = { + val utypes = cache.keySet.map { _._1 } + + val initMap = utypes.map { _ -> Set.empty[ArenaCell] }.toMap + + val cellsByUtype = cache.foldLeft(initMap) { case (map, ((utype, _), (cell, _))) => + map + (utype -> (map(utype) + cell)) + } + + // For each utype, all cells of that type are distinct + cellsByUtype.foreach { case (_, cells) => + ctx.assertGroundExpr(tla.distinct(cells.toSeq.map { _.toBuilder }: _*)) + } + + } +} diff --git a/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/UninterpretedLiteralCacheTest.scala b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/UninterpretedLiteralCacheTest.scala new file mode 100644 index 0000000000..868cfa9e87 --- /dev/null +++ b/tla-bmcmt/src/test/scala/at/forsyte/apalache/tla/bmcmt/stratifiedRules/aux/UninterpretedLiteralCacheTest.scala @@ -0,0 +1,147 @@ +package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux + +import at.forsyte.apalache.tla.bmcmt.PureArena +import at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.caches.UninterpretedLiteralCache +import at.forsyte.apalache.tla.lir.{StrT1, TlaType1} +import at.forsyte.apalache.tla.types.{tla, ModelValueHandler} +import org.junit.runner.RunWith +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite +import org.scalatestplus.junit.JUnitRunner + +@RunWith(classOf[JUnitRunner]) +class UninterpretedLiteralCacheTest extends AnyFunSuite with BeforeAndAfterEach { + + var cache: UninterpretedLiteralCache = new UninterpretedLiteralCache + + def tpAndIdx(s: String): (TlaType1, String) = { + val (utype, idx) = ModelValueHandler.typeAndIndex(s).getOrElse((StrT1, s)) + (utype, idx) + } + + override def beforeEach(): Unit = { + cache = new UninterpretedLiteralCache + } + + test("Cache returns stored values after the first call to getOrCreate") { + val str: String = "idx" + + val utypeAndIdx = tpAndIdx(str) + + val arena = PureArena.empty + + // No cached value for the pair + assert(cache.get(utypeAndIdx).isEmpty) + + val (newArena, iCell) = cache.getOrCreate(arena, utypeAndIdx) + + // pair now cached, arena has changed + assert(cache.get(utypeAndIdx).nonEmpty && newArena != arena) + + val (newArena2, iCell2) = cache.getOrCreate(newArena, utypeAndIdx) + + // 2nd call returns the _same_ arena and the previously computed cell + assert(newArena == newArena2 && iCell == iCell2) + } + + test("Same index of different types is cached separately") { + val str1: String = "idx" + val str2: String = "idx_OF_A" + val str3: String = "idx_OF_B" + + val pa1 = tpAndIdx(str1) + val pa2 = tpAndIdx(str2) + val pa3 = tpAndIdx(str3) + + val arena = PureArena.empty + + val (newArena1, cell1) = cache.getOrCreate(arena, pa1) + + assert(arena != newArena1) + + val (newArena2, cell2) = cache.getOrCreate(newArena1, pa2) + + assert(newArena2 != newArena1 && cell2 != cell1) + + val (newArena3, cell3) = cache.getOrCreate(newArena2, pa3) + + assert(newArena3 != newArena2 && cell3 != cell2) + } + + test("Constraints are only added when addAllConstraints is explicitly called, and only once per value") { + val mockCtx: MockZ3SolverContext = new MockZ3SolverContext + + val str1: String = "1_OF_A" + val str2: String = "2_OF_A" + val str3: String = "3_OF_A" + + val pa1 = tpAndIdx(str1) + val pa2 = tpAndIdx(str2) + val pa3 = tpAndIdx(str3) + + val a0 = PureArena.empty + val (a1, c1) = cache.getOrCreate(a0, pa1) + // Some extra calls, which shouldn't affect constraint generation + cache.getOrCreate(a0, pa1) + cache.getOrCreate(a0, pa1) + val (a2, c2) = cache.getOrCreate(a1, pa2) + // Some extra calls, which shouldn't affect constraint generation + cache.getOrCreate(a1, pa2) + cache.getOrCreate(a1, pa2) + val (_, c3) = cache.getOrCreate(a2, pa3) + // Some extra calls, which shouldn't affect constraint generation + cache.getOrCreate(a2, pa3) + cache.getOrCreate(a2, pa3) + + assert(mockCtx.constraints.isEmpty) + + cache.addAllConstraints(mockCtx) + + // Due to the optimized `addAllConstraints` override, we only have 1 "distinct" + assert(mockCtx.constraints == Seq( + tla.distinct(c3.toBuilder, c2.toBuilder, c1.toBuilder).build + )) + } + + test("Constraints are only added when addConstraintsForElem is explicitly called, and only once per value") { + val mockCtx: MockZ3SolverContext = new MockZ3SolverContext + + val str1: String = "1_OF_A" + val str2: String = "2_OF_A" + val str3: String = "3_OF_A" + + val pa1 = tpAndIdx(str1) + val pa2 = tpAndIdx(str2) + val pa3 = tpAndIdx(str3) + + val a0 = PureArena.empty + val (a1, c1) = cache.getOrCreate(a0, pa1) + // Some extra calls, which shouldn't affect constraint generation + cache.getOrCreate(a0, pa1) + cache.getOrCreate(a0, pa1) + + cache.addConstraintsForElem(mockCtx)(pa1, c1) + + val (a2, c2) = cache.getOrCreate(a1, pa2) + // Some extra calls, which shouldn't affect constraint generation + cache.getOrCreate(a1, pa2) + cache.getOrCreate(a1, pa2) + + cache.addConstraintsForElem(mockCtx)(pa2, c2) + + val (_, c3) = cache.getOrCreate(a2, pa3) + // Some extra calls, which shouldn't affect constraint generation + cache.getOrCreate(a2, pa3) + cache.getOrCreate(a2, pa3) + + cache.addConstraintsForElem(mockCtx)(pa3, c3) + + // -ForElem creates 3 "distinct" constraints + assert(mockCtx.constraints == Seq( + tla.distinct(c1.toBuilder).build, + tla.distinct(c2.toBuilder, c1.toBuilder).build, + tla.distinct(c3.toBuilder, c1.toBuilder, c2.toBuilder).build, + )) + } + +}