diff --git a/src/main/scala/viper/silver/ast/utility/GenericTriggerGenerator.scala b/src/main/scala/viper/silver/ast/utility/GenericTriggerGenerator.scala index e6e37edda..b890b48df 100644 --- a/src/main/scala/viper/silver/ast/utility/GenericTriggerGenerator.scala +++ b/src/main/scala/viper/silver/ast/utility/GenericTriggerGenerator.scala @@ -8,6 +8,7 @@ package viper.silver.ast.utility import java.util.concurrent.atomic.AtomicInteger import reflect.ClassTag +import viper.silver.ast object GenericTriggerGenerator { case class TriggerSet[E](exps: Seq[E]) @@ -195,10 +196,19 @@ abstract class GenericTriggerGenerator[Node <: AnyRef, else results.flatten + case e if modifyPossibleTriggers.isDefinedAt(e) => modifyPossibleTriggers.apply(e)(results) + case _ => results.flatten }) } + /* + * Hook for clients to add more cases to getFunctionAppsContaining to modify the found possible triggers. + * Used e.g. to wrap trigger expressions inferred from inside old-expression into old() + */ + def modifyPossibleTriggers: PartialFunction[Node, Seq[Seq[(PossibleTrigger, Seq[Var], Seq[Var])]] => + Seq[(PossibleTrigger, Seq[Var], Seq[Var])]] = PartialFunction.empty + /* Precondition: if vars is non-empty then every (f,vs) pair in functs * satisfies the property that vars and vs are not disjoint. * diff --git a/src/main/scala/viper/silver/ast/utility/Triggers.scala b/src/main/scala/viper/silver/ast/utility/Triggers.scala index f561512ef..93b89e609 100644 --- a/src/main/scala/viper/silver/ast/utility/Triggers.scala +++ b/src/main/scala/viper/silver/ast/utility/Triggers.scala @@ -57,6 +57,19 @@ object Triggers { case LabelledOld(pt: PossibleTrigger, _) => pt.getArgs case _ => sys.error(s"Unexpected expression $e") } + + override def modifyPossibleTriggers = { + case ast.Old(_) => results => + results.flatten.map(t => { + val exp = t._1 + (ast.Old(exp)(exp.pos, exp.info, exp.errT), t._2, t._3) + }) + case ast.LabelledOld(_, l) => results => + results.flatten.map(t => { + val exp = t._1 + (ast.LabelledOld(exp, l)(exp.pos, exp.info, exp.errT), t._2, t._3) + }) + } } /**