From 93798d6de672e0953972ef58754758448125056f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Pereira?= Date: Thu, 21 Sep 2023 10:55:21 +0200 Subject: [PATCH] cleanup --- .../scala/viper/gobra/frontend/Desugar.scala | 1 + .../typing/ghost/GhostExprTyping.scala | 1 + .../gobra/translator/context/Context.scala | 2 +- .../combinators/FinalTypeEncoding.scala | 4 +- .../encodings/combinators/TypeEncoding.scala | 22 ++-------- .../combinators/TypeEncodingCombiner.scala | 2 +- .../encodings/maps/MapEncoding.scala | 40 ++++++++++++++++++- .../gobra/translator/util/ViperWriter.scala | 8 +++- .../features/maps/maps-simple1.gobra | 17 +++++--- 9 files changed, 68 insertions(+), 29 deletions(-) diff --git a/src/main/scala/viper/gobra/frontend/Desugar.scala b/src/main/scala/viper/gobra/frontend/Desugar.scala index 9a0704d51..cf4cf6c0c 100644 --- a/src/main/scala/viper/gobra/frontend/Desugar.scala +++ b/src/main/scala/viper/gobra/frontend/Desugar.scala @@ -4304,6 +4304,7 @@ object Desugar extends LazyLogging { } yield underlyingType(dright.typ) match { case _: in.SequenceT | _: in.SetT => in.Contains(dleft, dright)(src) case _: in.MultisetT => in.LessCmp(in.IntLit(0)(src), in.Contains(dleft, dright)(src))(src) + case _: in.MapT => in.Contains(dleft, dright)(src) case t => violation(s"expected a sequence or (multi)set type, but got $t") } diff --git a/src/main/scala/viper/gobra/frontend/info/implementation/typing/ghost/GhostExprTyping.scala b/src/main/scala/viper/gobra/frontend/info/implementation/typing/ghost/GhostExprTyping.scala index a49c58c22..d52d4d17a 100644 --- a/src/main/scala/viper/gobra/frontend/info/implementation/typing/ghost/GhostExprTyping.scala +++ b/src/main/scala/viper/gobra/frontend/info/implementation/typing/ghost/GhostExprTyping.scala @@ -139,6 +139,7 @@ trait GhostExprTyping extends BaseTyping { this: TypeInfoImpl => case PIn(left, right) => isExpr(left).out ++ isExpr(right).out ++ { underlyingType(exprType(right)) match { case t : GhostCollectionType => ghostComparableTypes.errors(exprType(left), t.elem)(expr) + case t : MapT => ghostComparableTypes.errors(exprType(left), t.key)(expr) case _ : AdtT => noMessages case t => error(right, s"expected a ghost collection, but got $t") } diff --git a/src/main/scala/viper/gobra/translator/context/Context.scala b/src/main/scala/viper/gobra/translator/context/Context.scala index e6aab54e5..224a5c613 100644 --- a/src/main/scala/viper/gobra/translator/context/Context.scala +++ b/src/main/scala/viper/gobra/translator/context/Context.scala @@ -87,7 +87,7 @@ trait Context { def expression(x: in.Expr): CodeWriter[vpr.Exp] = typeEncoding.finalExpression(this)(x) - def triggerExpr(x: in.TriggerExpr): CodeWriter[vpr.Exp] = typeEncoding.finalTriggerExpr(this)(x) + def triggerExpr(x: in.TriggerExpr): CodeWriter[vpr.Exp] = typeEncoding.triggerExpr(this)(x) def assertion(x: in.Assertion): CodeWriter[vpr.Exp] = typeEncoding.finalAssertion(this)(x) diff --git a/src/main/scala/viper/gobra/translator/encodings/combinators/FinalTypeEncoding.scala b/src/main/scala/viper/gobra/translator/encodings/combinators/FinalTypeEncoding.scala index 6434533d9..21144f527 100644 --- a/src/main/scala/viper/gobra/translator/encodings/combinators/FinalTypeEncoding.scala +++ b/src/main/scala/viper/gobra/translator/encodings/combinators/FinalTypeEncoding.scala @@ -46,6 +46,9 @@ class FinalTypeEncoding(te: TypeEncoding) extends TypeEncoding { override def equal(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = te.equal(ctx) orElse expectedMatch("equal") override def goEqual(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = te.goEqual(ctx) orElse expectedMatch("equal") override def expression(ctx: Context): in.Expr ==> CodeWriter[vpr.Exp] = te.expression(ctx) orElse expectedMatch("expression") + override def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = te.triggerExpr(ctx) orElse { + case e: in.Expr => te.expression(ctx)(e) + } override def assertion(ctx: Context): in.Assertion ==> CodeWriter[vpr.Exp] = te.assertion(ctx) orElse expectedMatch("assertion") override def reference(ctx: Context): in.Location ==> CodeWriter[vpr.Exp] = te.reference(ctx) orElse expectedMatch("reference") override def addressFootprint(ctx: Context): (in.Location, in.Expr) ==> CodeWriter[vpr.Exp] = te.addressFootprint(ctx) orElse expectedMatch("addressFootprint") @@ -61,7 +64,6 @@ class FinalTypeEncoding(te: TypeEncoding) extends TypeEncoding { override def extendFunction(ctx: Context): in.Member ==> Extension[MemberWriter[vpr.Function]] = te.extendFunction(ctx) orElse { _ => identity } override def extendPredicate(ctx: Context): in.Member ==> Extension[MemberWriter[vpr.Predicate]] = te.extendPredicate(ctx) orElse { _ => identity } override def extendExpression(ctx: Context): in.Expr ==> Extension[CodeWriter[vpr.Exp]] = te.extendExpression(ctx) orElse { _ => identity } - override def extendTriggerExpr(ctx: Context): in.TriggerExpr ==> Extension[CodeWriter[vpr.Exp]] = te.extendTriggerExpr(ctx) orElse { _ => identity } override def extendAssertion(ctx: Context): in.Assertion ==> Extension[CodeWriter[vpr.Exp]] = te.extendAssertion(ctx) orElse { _ => identity } override def extendStatement(ctx: Context): in.Stmt ==> Extension[CodeWriter[vpr.Stmt]] = te.extendStatement(ctx) orElse { _ => identity } } diff --git a/src/main/scala/viper/gobra/translator/encodings/combinators/TypeEncoding.scala b/src/main/scala/viper/gobra/translator/encodings/combinators/TypeEncoding.scala index 973f22040..5d30ff2e2 100644 --- a/src/main/scala/viper/gobra/translator/encodings/combinators/TypeEncoding.scala +++ b/src/main/scala/viper/gobra/translator/encodings/combinators/TypeEncoding.scala @@ -234,20 +234,18 @@ trait TypeEncoding extends Generator { } // TODO: doc + // TODO: optimize assert2(true, ...) + // TODO: key in map, instead of key in domain(map) // TODO: enable consistency checks on triggers when --checkConsistency - def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = PartialFunction.empty - - /* - { + def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = { // use predicate access encoding but then take just the predicate access, i.e. remove `acc` and the permission amount: case in.Accessible.Predicate(op) => for { v <- ctx.assertion(in.Access(in.Accessible.Predicate(op), in.FullPerm(op.info))(op.info)) pap = v.asInstanceOf[vpr.PredicateAccessPredicate] } yield pap.loc - case e: in.Expr => ctx.expression(e) + // case e: in.Expr => ctx.expression(e) } - */ /** * Encodes assertions. @@ -422,18 +420,6 @@ trait TypeEncoding extends Generator { val f = expression(ctx); { case n@f(v) => extendExpression(ctx).lift(n).fold(v)(_(v)) } } - /** Adds to the encoding of [[triggerExpr]]. The extension is applied to the result of the final trigger expression - * encoding. - */ - def extendTriggerExpr(@unused ctx: Context): in.TriggerExpr ==> Extension[CodeWriter[vpr.Exp]] = PartialFunction.empty - - final def finalTriggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = { - val f = triggerExpr(ctx); - { - case n@f(v) => extendTriggerExpr(ctx).lift(n).fold(v)(_(v)) - } - } - /** Adds to the encoding of [[assertion]]. The extension is applied to the result of the final assertion encoding. */ def extendAssertion(@unused ctx: Context): in.Assertion ==> Extension[CodeWriter[vpr.Exp]] = PartialFunction.empty final def finalAssertion(ctx: Context): in.Assertion ==> CodeWriter[vpr.Exp] = { diff --git a/src/main/scala/viper/gobra/translator/encodings/combinators/TypeEncodingCombiner.scala b/src/main/scala/viper/gobra/translator/encodings/combinators/TypeEncodingCombiner.scala index d9d05577d..c4e80a245 100644 --- a/src/main/scala/viper/gobra/translator/encodings/combinators/TypeEncodingCombiner.scala +++ b/src/main/scala/viper/gobra/translator/encodings/combinators/TypeEncodingCombiner.scala @@ -49,6 +49,7 @@ abstract class TypeEncodingCombiner(encodings: Vector[TypeEncoding], defaults: V override def equal(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = combiner(_.equal(ctx)) override def goEqual(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = combiner(_.goEqual(ctx)) override def expression(ctx: Context): in.Expr ==> CodeWriter[vpr.Exp] = combiner(_.expression(ctx)) + override def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = combiner(_.triggerExpr(ctx)) override def assertion(ctx: Context): in.Assertion ==> CodeWriter[vpr.Exp] = combiner(_.assertion(ctx)) override def reference(ctx: Context): in.Location ==> CodeWriter[vpr.Exp] = combiner(_.reference(ctx)) override def addressFootprint(ctx: Context): (in.Location, in.Expr) ==> CodeWriter[vpr.Exp] = combiner(_.addressFootprint(ctx)) @@ -64,7 +65,6 @@ abstract class TypeEncodingCombiner(encodings: Vector[TypeEncoding], defaults: V override def extendFunction(ctx: Context): in.Member ==> Extension[MemberWriter[vpr.Function]] = extender(_.extendFunction(ctx)) override def extendPredicate(ctx: Context): in.Member ==> Extension[MemberWriter[vpr.Predicate]] = extender(_.extendPredicate(ctx)) override def extendExpression(ctx: Context): in.Expr ==> Extension[CodeWriter[vpr.Exp]] = extender(_.extendExpression(ctx)) - override def extendTriggerExpr(ctx: Context): in.TriggerExpr ==> Extension[CodeWriter[vpr.Exp]] = extender(_.extendTriggerExpr(ctx)) override def extendAssertion(ctx: Context): in.Assertion ==> Extension[CodeWriter[vpr.Exp]] = extender(_.extendAssertion(ctx)) override def extendStatement(ctx: Context): in.Stmt ==> Extension[CodeWriter[vpr.Stmt]] = extender(_.extendStatement(ctx)) } diff --git a/src/main/scala/viper/gobra/translator/encodings/maps/MapEncoding.scala b/src/main/scala/viper/gobra/translator/encodings/maps/MapEncoding.scala index 2667b7cfb..5867590f3 100644 --- a/src/main/scala/viper/gobra/translator/encodings/maps/MapEncoding.scala +++ b/src/main/scala/viper/gobra/translator/encodings/maps/MapEncoding.scala @@ -61,6 +61,7 @@ class MapEncoding extends LeafTypeEncoding { * R[ (e: map[K]V)[idx] ] -> [e] == null? [ dflt(V) ] : goMapLookup(e[idx]) * R[ keySet(e: map[K]V) ] -> [e] == null? 0 : MapDomain(getCorrespondingMap(e)) * R[ valueSet(e: map[K]V) ] -> [e] == null? 0 : MapRange(getCorrespondingMap(e)) + * TODO: doc contains */ override def expression(ctx: Context): in.Expr ==> CodeWriter[vpr.Exp] = { def goE(x: in.Expr): CodeWriter[vpr.Exp] = ctx.expression(x) @@ -88,6 +89,15 @@ class MapEncoding extends LeafTypeEncoding { case l@in.IndexedExp(_ :: ctx.Map(_, _), _, _) => for {(res, _) <- goMapLookup(l)(ctx)} yield res + case l@in.Contains(key, exp :: ctx.Map(keys, values)) => + for { + keyVpr <- goE(key) + isComp <- MapEncoding.checkKeyComparability(key)(ctx) + correspondingMap <- getCorrespondingMap(exp, keys, values)(ctx) + containsExp = withSrc(vpr.MapContains(keyVpr, correspondingMap), l) + checkCompAndContains <- assert(isComp, containsExp, comparabilityErrorT)(ctx) + } yield checkCompAndContains + case k@in.MapKeys(mapExp :: ctx.Map(keys, values), _) => for { vprMap <- goE(mapExp) @@ -116,7 +126,35 @@ class MapEncoding extends LeafTypeEncoding { override def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = { default(super.triggerExpr(ctx)) { - case in.IndexedExp(_ :: ctx.Map(_, _), _, _) => unit(vpr.TrueLit()()) + case l@in.IndexedExp(m :: ctx.Map(keys, values), idx, _) => + for { + vIdx <- ctx.expression(idx) + correspondingMap <- getCorrespondingMap(m, keys, values)(ctx) + lookupRes = withSrc(vpr.MapLookup(correspondingMap, vIdx), l) + } yield lookupRes + + case l@in.Contains(key, m :: ctx.Map(keys, values)) => + for { + vKey <- ctx.expression(key) + correspondingMap <- getCorrespondingMap(m, keys, values)(ctx) + contains = withSrc(vpr.MapContains(correspondingMap, vKey), l) + } yield contains + + case l@in.Contains(key, in.MapKeys(m :: ctx.Map(keys, values), _)) => + for { + vKey <- ctx.expression(key) + correspondingMap <- getCorrespondingMap(m, keys, values)(ctx) + vDomainMap = withSrc(vpr.MapDomain(correspondingMap), l) + contains = withSrc(vpr.AnySetContains(vKey, vDomainMap), l) + } yield contains + + case l@in.Contains(key, in.MapValues(m :: ctx.Map(keys, values), _)) => + for { + vKey <- ctx.expression(key) + correspondingMap <- getCorrespondingMap(m, keys, values)(ctx) + vRangeMap = withSrc(vpr.MapRange(correspondingMap), l) + contains = withSrc(vpr.AnySetContains(vKey, vRangeMap), l) + } yield contains } } diff --git a/src/main/scala/viper/gobra/translator/util/ViperWriter.scala b/src/main/scala/viper/gobra/translator/util/ViperWriter.scala index de0f8bc67..cbbf6ae82 100644 --- a/src/main/scala/viper/gobra/translator/util/ViperWriter.scala +++ b/src/main/scala/viper/gobra/translator/util/ViperWriter.scala @@ -399,8 +399,12 @@ object ViperWriter { /* Can be used in expressions. */ def assert(cond: vpr.Exp, exp: vpr.Exp, reasonT: (Source.Verifier.Info, ErrorReason) => VerificationError)(ctx: Context): Writer[vpr.Exp] = { // In the future, this might do something more sophisticated - val (res, errT) = ctx.condition.assert(cond, exp, reasonT) - errorT(errT).map(_ => res) + if (cond.isInstanceOf[vpr.TrueLit]) { + unit(exp) + } else { + val (res, errT) = ctx.condition.assert(cond, exp, reasonT) + errorT(errT).map(_ => res) + } } /* Emits Viper statements. */ diff --git a/src/test/resources/regressions/features/maps/maps-simple1.gobra b/src/test/resources/regressions/features/maps/maps-simple1.gobra index 8a53e22cd..e98f05d1a 100644 --- a/src/test/resources/regressions/features/maps/maps-simple1.gobra +++ b/src/test/resources/regressions/features/maps/maps-simple1.gobra @@ -52,6 +52,9 @@ func test6() { m[3] = 10 v3, ok3 := m[3] assert ok3 && v3 == 10 + + // check if key exists in the map + assert 3 in m } type T struct { @@ -110,24 +113,24 @@ func test11() { requires acc(m, _) requires "key" in domain(m) func test12(m map[string]string) (r string){ - return m["key"] + return m["key"] } requires acc(m, _) requires "value" in range(m) func test13(m map[string]string) { - assert exists k string :: m[k] == "value" + assert exists k string :: m[k] == "value" } func test14() (res map[int]int) { x := 1 y := 2 - m := map[int]int{x: y, y: x} + m := map[int]int{x: y, y: x} return m } func test15() (res map[int]int) { - m := map[int]int{C1: C2, C2: C1} + m := map[int]int{C1: C2, C2: C1} return m } @@ -137,4 +140,8 @@ func test16() { assert x == 0 x, contained := m[2] assert x == 0 && !contained -} \ No newline at end of file +} + +requires m != nil ==> acc(m) +requires forall s string :: { s in domain(m) } s in domain(m) ==> acc(m[s]) +func test17(m map[string]*int) {}