Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jcp19 committed Sep 21, 2023
1 parent 002ef76 commit 93798d6
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 29 deletions.
1 change: 1 addition & 0 deletions src/main/scala/viper/gobra/frontend/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4304,6 +4304,7 @@ object Desugar extends LazyLogging {
} yield underlyingType(dright.typ) match {
case _: in.SequenceT | _: in.SetT => in.Contains(dleft, dright)(src)
case _: in.MultisetT => in.LessCmp(in.IntLit(0)(src), in.Contains(dleft, dright)(src))(src)
case _: in.MapT => in.Contains(dleft, dright)(src)
case t => violation(s"expected a sequence or (multi)set type, but got $t")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ trait GhostExprTyping extends BaseTyping { this: TypeInfoImpl =>
case PIn(left, right) => isExpr(left).out ++ isExpr(right).out ++ {
underlyingType(exprType(right)) match {
case t : GhostCollectionType => ghostComparableTypes.errors(exprType(left), t.elem)(expr)
case t : MapT => ghostComparableTypes.errors(exprType(left), t.key)(expr)
case _ : AdtT => noMessages
case t => error(right, s"expected a ghost collection, but got $t")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ trait Context {

def expression(x: in.Expr): CodeWriter[vpr.Exp] = typeEncoding.finalExpression(this)(x)

def triggerExpr(x: in.TriggerExpr): CodeWriter[vpr.Exp] = typeEncoding.finalTriggerExpr(this)(x)
def triggerExpr(x: in.TriggerExpr): CodeWriter[vpr.Exp] = typeEncoding.triggerExpr(this)(x)

def assertion(x: in.Assertion): CodeWriter[vpr.Exp] = typeEncoding.finalAssertion(this)(x)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class FinalTypeEncoding(te: TypeEncoding) extends TypeEncoding {
override def equal(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = te.equal(ctx) orElse expectedMatch("equal")
override def goEqual(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = te.goEqual(ctx) orElse expectedMatch("equal")
override def expression(ctx: Context): in.Expr ==> CodeWriter[vpr.Exp] = te.expression(ctx) orElse expectedMatch("expression")
override def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = te.triggerExpr(ctx) orElse {
case e: in.Expr => te.expression(ctx)(e)
}
override def assertion(ctx: Context): in.Assertion ==> CodeWriter[vpr.Exp] = te.assertion(ctx) orElse expectedMatch("assertion")
override def reference(ctx: Context): in.Location ==> CodeWriter[vpr.Exp] = te.reference(ctx) orElse expectedMatch("reference")
override def addressFootprint(ctx: Context): (in.Location, in.Expr) ==> CodeWriter[vpr.Exp] = te.addressFootprint(ctx) orElse expectedMatch("addressFootprint")
Expand All @@ -61,7 +64,6 @@ class FinalTypeEncoding(te: TypeEncoding) extends TypeEncoding {
override def extendFunction(ctx: Context): in.Member ==> Extension[MemberWriter[vpr.Function]] = te.extendFunction(ctx) orElse { _ => identity }
override def extendPredicate(ctx: Context): in.Member ==> Extension[MemberWriter[vpr.Predicate]] = te.extendPredicate(ctx) orElse { _ => identity }
override def extendExpression(ctx: Context): in.Expr ==> Extension[CodeWriter[vpr.Exp]] = te.extendExpression(ctx) orElse { _ => identity }
override def extendTriggerExpr(ctx: Context): in.TriggerExpr ==> Extension[CodeWriter[vpr.Exp]] = te.extendTriggerExpr(ctx) orElse { _ => identity }
override def extendAssertion(ctx: Context): in.Assertion ==> Extension[CodeWriter[vpr.Exp]] = te.extendAssertion(ctx) orElse { _ => identity }
override def extendStatement(ctx: Context): in.Stmt ==> Extension[CodeWriter[vpr.Stmt]] = te.extendStatement(ctx) orElse { _ => identity }
}
Original file line number Diff line number Diff line change
Expand Up @@ -234,20 +234,18 @@ trait TypeEncoding extends Generator {
}

// TODO: doc
// TODO: optimize assert2(true, ...)
// TODO: key in map, instead of key in domain(map)
// TODO: enable consistency checks on triggers when --checkConsistency
def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = PartialFunction.empty

/*
{
def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = {
// use predicate access encoding but then take just the predicate access, i.e. remove `acc` and the permission amount:
case in.Accessible.Predicate(op) =>
for {
v <- ctx.assertion(in.Access(in.Accessible.Predicate(op), in.FullPerm(op.info))(op.info))
pap = v.asInstanceOf[vpr.PredicateAccessPredicate]
} yield pap.loc
case e: in.Expr => ctx.expression(e)
// case e: in.Expr => ctx.expression(e)
}
*/

/**
* Encodes assertions.
Expand Down Expand Up @@ -422,18 +420,6 @@ trait TypeEncoding extends Generator {
val f = expression(ctx); { case n@f(v) => extendExpression(ctx).lift(n).fold(v)(_(v)) }
}

/** Adds to the encoding of [[triggerExpr]]. The extension is applied to the result of the final trigger expression
* encoding.
*/
def extendTriggerExpr(@unused ctx: Context): in.TriggerExpr ==> Extension[CodeWriter[vpr.Exp]] = PartialFunction.empty

final def finalTriggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = {
val f = triggerExpr(ctx);
{
case n@f(v) => extendTriggerExpr(ctx).lift(n).fold(v)(_(v))
}
}

/** Adds to the encoding of [[assertion]]. The extension is applied to the result of the final assertion encoding. */
def extendAssertion(@unused ctx: Context): in.Assertion ==> Extension[CodeWriter[vpr.Exp]] = PartialFunction.empty
final def finalAssertion(ctx: Context): in.Assertion ==> CodeWriter[vpr.Exp] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ abstract class TypeEncodingCombiner(encodings: Vector[TypeEncoding], defaults: V
override def equal(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = combiner(_.equal(ctx))
override def goEqual(ctx: Context): (in.Expr, in.Expr, in.Node) ==> CodeWriter[vpr.Exp] = combiner(_.goEqual(ctx))
override def expression(ctx: Context): in.Expr ==> CodeWriter[vpr.Exp] = combiner(_.expression(ctx))
override def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = combiner(_.triggerExpr(ctx))
override def assertion(ctx: Context): in.Assertion ==> CodeWriter[vpr.Exp] = combiner(_.assertion(ctx))
override def reference(ctx: Context): in.Location ==> CodeWriter[vpr.Exp] = combiner(_.reference(ctx))
override def addressFootprint(ctx: Context): (in.Location, in.Expr) ==> CodeWriter[vpr.Exp] = combiner(_.addressFootprint(ctx))
Expand All @@ -64,7 +65,6 @@ abstract class TypeEncodingCombiner(encodings: Vector[TypeEncoding], defaults: V
override def extendFunction(ctx: Context): in.Member ==> Extension[MemberWriter[vpr.Function]] = extender(_.extendFunction(ctx))
override def extendPredicate(ctx: Context): in.Member ==> Extension[MemberWriter[vpr.Predicate]] = extender(_.extendPredicate(ctx))
override def extendExpression(ctx: Context): in.Expr ==> Extension[CodeWriter[vpr.Exp]] = extender(_.extendExpression(ctx))
override def extendTriggerExpr(ctx: Context): in.TriggerExpr ==> Extension[CodeWriter[vpr.Exp]] = extender(_.extendTriggerExpr(ctx))
override def extendAssertion(ctx: Context): in.Assertion ==> Extension[CodeWriter[vpr.Exp]] = extender(_.extendAssertion(ctx))
override def extendStatement(ctx: Context): in.Stmt ==> Extension[CodeWriter[vpr.Stmt]] = extender(_.extendStatement(ctx))
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class MapEncoding extends LeafTypeEncoding {
* R[ (e: map[K]V)[idx] ] -> [e] == null? [ dflt(V) ] : goMapLookup(e[idx])
* R[ keySet(e: map[K]V) ] -> [e] == null? 0 : MapDomain(getCorrespondingMap(e))
* R[ valueSet(e: map[K]V) ] -> [e] == null? 0 : MapRange(getCorrespondingMap(e))
* TODO: doc contains
*/
override def expression(ctx: Context): in.Expr ==> CodeWriter[vpr.Exp] = {
def goE(x: in.Expr): CodeWriter[vpr.Exp] = ctx.expression(x)
Expand Down Expand Up @@ -88,6 +89,15 @@ class MapEncoding extends LeafTypeEncoding {

case l@in.IndexedExp(_ :: ctx.Map(_, _), _, _) => for {(res, _) <- goMapLookup(l)(ctx)} yield res

case l@in.Contains(key, exp :: ctx.Map(keys, values)) =>
for {
keyVpr <- goE(key)
isComp <- MapEncoding.checkKeyComparability(key)(ctx)
correspondingMap <- getCorrespondingMap(exp, keys, values)(ctx)
containsExp = withSrc(vpr.MapContains(keyVpr, correspondingMap), l)
checkCompAndContains <- assert(isComp, containsExp, comparabilityErrorT)(ctx)
} yield checkCompAndContains

case k@in.MapKeys(mapExp :: ctx.Map(keys, values), _) =>
for {
vprMap <- goE(mapExp)
Expand Down Expand Up @@ -116,7 +126,35 @@ class MapEncoding extends LeafTypeEncoding {

override def triggerExpr(ctx: Context): in.TriggerExpr ==> CodeWriter[vpr.Exp] = {
default(super.triggerExpr(ctx)) {
case in.IndexedExp(_ :: ctx.Map(_, _), _, _) => unit(vpr.TrueLit()())
case l@in.IndexedExp(m :: ctx.Map(keys, values), idx, _) =>
for {
vIdx <- ctx.expression(idx)
correspondingMap <- getCorrespondingMap(m, keys, values)(ctx)
lookupRes = withSrc(vpr.MapLookup(correspondingMap, vIdx), l)
} yield lookupRes

case l@in.Contains(key, m :: ctx.Map(keys, values)) =>
for {
vKey <- ctx.expression(key)
correspondingMap <- getCorrespondingMap(m, keys, values)(ctx)
contains = withSrc(vpr.MapContains(correspondingMap, vKey), l)
} yield contains

case l@in.Contains(key, in.MapKeys(m :: ctx.Map(keys, values), _)) =>
for {
vKey <- ctx.expression(key)
correspondingMap <- getCorrespondingMap(m, keys, values)(ctx)
vDomainMap = withSrc(vpr.MapDomain(correspondingMap), l)
contains = withSrc(vpr.AnySetContains(vKey, vDomainMap), l)
} yield contains

case l@in.Contains(key, in.MapValues(m :: ctx.Map(keys, values), _)) =>
for {
vKey <- ctx.expression(key)
correspondingMap <- getCorrespondingMap(m, keys, values)(ctx)
vRangeMap = withSrc(vpr.MapRange(correspondingMap), l)
contains = withSrc(vpr.AnySetContains(vKey, vRangeMap), l)
} yield contains
}
}

Expand Down
8 changes: 6 additions & 2 deletions src/main/scala/viper/gobra/translator/util/ViperWriter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,12 @@ object ViperWriter {
/* Can be used in expressions. */
def assert(cond: vpr.Exp, exp: vpr.Exp, reasonT: (Source.Verifier.Info, ErrorReason) => VerificationError)(ctx: Context): Writer[vpr.Exp] = {
// In the future, this might do something more sophisticated
val (res, errT) = ctx.condition.assert(cond, exp, reasonT)
errorT(errT).map(_ => res)
if (cond.isInstanceOf[vpr.TrueLit]) {
unit(exp)
} else {
val (res, errT) = ctx.condition.assert(cond, exp, reasonT)
errorT(errT).map(_ => res)
}
}

/* Emits Viper statements. */
Expand Down
17 changes: 12 additions & 5 deletions src/test/resources/regressions/features/maps/maps-simple1.gobra
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func test6() {
m[3] = 10
v3, ok3 := m[3]
assert ok3 && v3 == 10

// check if key exists in the map
assert 3 in m
}

type T struct {
Expand Down Expand Up @@ -110,24 +113,24 @@ func test11() {
requires acc(m, _)
requires "key" in domain(m)
func test12(m map[string]string) (r string){
return m["key"]
return m["key"]
}

requires acc(m, _)
requires "value" in range(m)
func test13(m map[string]string) {
assert exists k string :: m[k] == "value"
assert exists k string :: m[k] == "value"
}

func test14() (res map[int]int) {
x := 1
y := 2
m := map[int]int{x: y, y: x}
m := map[int]int{x: y, y: x}
return m
}

func test15() (res map[int]int) {
m := map[int]int{C1: C2, C2: C1}
m := map[int]int{C1: C2, C2: C1}
return m
}

Expand All @@ -137,4 +140,8 @@ func test16() {
assert x == 0
x, contained := m[2]
assert x == 0 && !contained
}
}

requires m != nil ==> acc(m)
requires forall s string :: { s in domain(m) } s in domain(m) ==> acc(m[s])
func test17(m map[string]*int) {}

0 comments on commit 93798d6

Please sign in to comment.