diff --git a/prusti-viper/src/encoder/definition_collector.rs b/prusti-viper/src/encoder/definition_collector.rs index d746a2c95cb..bb4d9ed4baf 100644 --- a/prusti-viper/src/encoder/definition_collector.rs +++ b/prusti-viper/src/encoder/definition_collector.rs @@ -115,21 +115,24 @@ impl<'p, 'v: 'p, 'tcx: 'v> Collector<'p, 'v, 'tcx> { let leak_checked_methods = methods .into_iter() .map(|mut method| -> SpannedEncodingResult { - let ret_index = method.basic_blocks_labels - .iter() - .position(|label| label == "return"); - if ret_index.is_none() { - return Err(SpannedEncodingError::internal("encoded method does not contain a `return` label; cannot add leak checks", self.error_span)); - } - let ret_index = method.block_index(ret_index.unwrap()); - for identifier in &self.used_obligations { - let leak_check = self.encoder.get_obligation_leak_check(identifier)?; - let leak_check = (*leak_check).clone(); - method.add_stmt( - ret_index, - leak_check - ); - } + method = method.patch_statements(|stmt| -> SpannedEncodingResult::<_> { + match stmt { + vir::Stmt::LeakCheck(vir::LeakCheck { scope_id }) => { + let mut check_body = vir::Expr::Const(vir::ConstExpr { value: vir::Const::Bool(true), position: vir::Position::default() }); + for identifier in &self.used_obligations { + let current_check = self.encoder.get_obligation_leak_check(identifier, scope_id)?; + check_body = vir::Expr::BinOp(vir::BinOp { + op_kind: vir::BinaryOpKind::And, + left: Box::new(check_body), + right: Box::new(current_check), + position: vir::Position::default(), + }) + } + Ok(vir::Stmt::Assert(vir::Assert { expr: check_body, position: vir::Position::default() })) + }, + _ => { Ok(stmt) } + } + }).unwrap(); Ok(method) }).collect::>>()?; Ok(vir::Program { diff --git a/prusti-viper/src/encoder/encoder.rs b/prusti-viper/src/encoder/encoder.rs index f588ebbce7a..12c736c94b4 100644 --- a/prusti-viper/src/encoder/encoder.rs +++ b/prusti-viper/src/encoder/encoder.rs @@ -71,7 +71,7 @@ pub struct Encoder<'v, 'tcx: 'v> { /// A map containing all functions: identifier → function definition. functions: RefCell>>, obligations: RefCell>>, - obligation_checks: RefCell>>, + obligation_checks: RefCell>>, builtin_domains: RefCell>, builtin_domains_in_progress: RefCell>, builtin_methods: RefCell>, @@ -350,7 +350,7 @@ impl<'v, 'tcx> Encoder<'v, 'tcx> { let mut args = vec![vir::LocalVar::new("scope_id", vir::Type::Int)]; let mut check_args = vec![]; let mut concrete_args = vec![vir::Expr::Const(vir::ConstExpr { - value: vir::Const::Int(-1), + value: vir::Const::Int(-2), position: vir::Position::default(), })]; for local_idx in 1..sig.skip_binder().inputs().len() { @@ -374,18 +374,15 @@ impl<'v, 'tcx> Encoder<'v, 'tcx> { args: concrete_args, formal_arguments: args.clone(), }; - let check = vir::Stmt::Assert(vir::Assert { - expr: vir::Expr::ForPerm(vir::ForPerm { - variables: check_args, - access: obligation_access, - body: Box::new(vir::Expr::Const(vir::ConstExpr { - value: vir::Const::Bool(false), - position: vir::Position::default(), - })), + let check = vir::ForPerm { + variables: check_args, + access: obligation_access, + body: Box::new(vir::Expr::Const(vir::ConstExpr { + value: vir::Const::Bool(false), position: vir::Position::default(), - }), + })), position: vir::Position::default(), - }); + }; self.obligation_checks.borrow_mut().insert(ident.clone()/*obligation_name.clone().into()*/, Rc::new(check)); } Ok(()) @@ -401,11 +398,41 @@ impl<'v, 'tcx> Encoder<'v, 'tcx> { } } - pub(super) fn get_obligation_leak_check(&self, identifier: &vir::FunctionIdentifier) -> SpannedEncodingResult> { + pub(super) fn get_obligation_leak_check(&self, identifier: &vir::FunctionIdentifier, scope_id: isize) -> SpannedEncodingResult { self.ensure_obligation_encoded(identifier)?; if self.obligation_checks.borrow().contains_key(identifier) { let map = self.obligation_checks.borrow(); - Ok(map[identifier].clone()) + let check = map[identifier].clone(); + Ok(match (*check).clone() { + vir::ForPerm { + variables, + access: vir::ObligationAccess { + name, + args, + formal_arguments, + }, + body, + position, + } => vir::Expr::ForPerm(vir::ForPerm { + variables, + access: vir::ObligationAccess { + name, + args: args.into_iter().enumerate().map(|(i, a)| { + if i == 0 { + vir::Expr::Const(vir::ConstExpr { + value: vir::Const::Int(scope_id.try_into().unwrap()), + position: vir::Position::default(), + }) + } else { + a + } + }).collect(), + formal_arguments, + }, + body, + position, + }) + }) } else { unreachable!("Not found obligation check: {:?}", identifier); } diff --git a/prusti-viper/src/encoder/foldunfold/requirements.rs b/prusti-viper/src/encoder/foldunfold/requirements.rs index 1e3223eba2d..64c8dcc8da7 100644 --- a/prusti-viper/src/encoder/foldunfold/requirements.rs +++ b/prusti-viper/src/encoder/foldunfold/requirements.rs @@ -214,6 +214,10 @@ impl RequiredStmtPermissionsGetter for vir::Stmt { base.get_required_stmt_permissions(predicates) } + &vir::Stmt::LeakCheck(..) => { + FxHashSet::default() + } + ref x => unimplemented!("{}", x), } } diff --git a/prusti-viper/src/encoder/foldunfold/semantics.rs b/prusti-viper/src/encoder/foldunfold/semantics.rs index a35c6cba526..5b4dbc53209 100644 --- a/prusti-viper/src/encoder/foldunfold/semantics.rs +++ b/prusti-viper/src/encoder/foldunfold/semantics.rs @@ -511,6 +511,8 @@ impl ApplyOnState for vir::Stmt { } } + &vir::Stmt::LeakCheck(..) => {} + ref x => unimplemented!("{}", x), } Ok(()) diff --git a/prusti-viper/src/encoder/procedure_encoder.rs b/prusti-viper/src/encoder/procedure_encoder.rs index 24460019665..17ebe118a78 100644 --- a/prusti-viper/src/encoder/procedure_encoder.rs +++ b/prusti-viper/src/encoder/procedure_encoder.rs @@ -1054,6 +1054,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { self.cfg_method.add_stmts(inv_post_block_perms, stmts); } + // TODO: also check for obligations in postconditions here let mid_groups = if preconds.is_empty() { // Encode the mid G group (start - G - B1 - invariant_perm - *G* - B1 - invariant_fnspec - B2 - G - B1 - end) let mid_g = self.encode_blocks_group( @@ -1173,6 +1174,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { &scope_ids, )?; self.cfg_method.add_stmts(end_body_block, stmts); + self.cfg_method.add_stmt(end_body_block, vir::Stmt::LeakCheck(vir::LeakCheck { scope_id: loop_head.index() as isize })); } self.cfg_method.add_stmt( end_body_block, @@ -5178,6 +5180,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> ProcedureEncoder<'p, 'v, 'tcx> { position: func_pos, }), ); + self.cfg_method.add_stmt(return_cfg_block, vir::Stmt::LeakCheck(vir::LeakCheck { scope_id: -1 })); // Assert type invariants self.cfg_method.add_stmt( diff --git a/vir/defs/polymorphic/ast/stmt.rs b/vir/defs/polymorphic/ast/stmt.rs index 39de9ff181b..a51563ffff8 100644 --- a/vir/defs/polymorphic/ast/stmt.rs +++ b/vir/defs/polymorphic/ast/stmt.rs @@ -64,6 +64,7 @@ pub enum Stmt { /// * place to the enumeration instance /// * field that encodes the variant Downcast(Downcast), + LeakCheck(LeakCheck), } impl fmt::Display for Stmt { @@ -88,6 +89,7 @@ impl fmt::Display for Stmt { Stmt::ExpireBorrows(expire_borrows) => expire_borrows.fmt(f), Stmt::If(if_stmt) => if_stmt.fmt(f), Stmt::Downcast(downcast) => downcast.fmt(f), + Stmt::LeakCheck(leak_check) => leak_check.fmt(f), } } } @@ -450,6 +452,17 @@ impl fmt::Display for Downcast { } } +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct LeakCheck { + pub scope_id: isize, +} + +impl fmt::Display for LeakCheck { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "leak for scope_id = {}", self.scope_id) + } +} + impl Stmt { pub fn is_comment(&self) -> bool { matches!(self, Stmt::Comment(_)) @@ -576,6 +589,7 @@ pub trait StmtFolder { Stmt::ExpireBorrows(expire_borrows) => self.fold_expire_borrows(expire_borrows), Stmt::If(if_stmt) => self.fold_if(if_stmt), Stmt::Downcast(downcast) => self.fold_downcast(downcast), + Stmt::LeakCheck(leak_check) => self.fold_leak_check(leak_check), } } @@ -761,6 +775,10 @@ pub trait StmtFolder { field, }) } + + fn fold_leak_check(&mut self, statement: LeakCheck) -> Stmt { + Stmt::LeakCheck(statement) + } } pub trait FallibleStmtFolder { @@ -793,6 +811,7 @@ pub trait FallibleStmtFolder { } Stmt::If(if_stmt) => self.fallible_fold_if(if_stmt), Stmt::Downcast(downcast) => self.fallible_fold_downcast(downcast), + Stmt::LeakCheck(leak_check) => self.fallible_fold_leak_check(leak_check), } } @@ -1009,6 +1028,10 @@ pub trait FallibleStmtFolder { field, })) } + + fn fallible_fold_leak_check(&mut self, statement: LeakCheck) -> Result { + Ok(Stmt::LeakCheck(statement)) + } } pub trait StmtWalker { @@ -1035,6 +1058,7 @@ pub trait StmtWalker { Stmt::ExpireBorrows(expire_borrows) => self.walk_expire_borrows(expire_borrows), Stmt::If(if_stmt) => self.walk_if(if_stmt), Stmt::Downcast(downcast) => self.walk_downcast(downcast), + Stmt::LeakCheck(leak_check) => self.walk_leak_check(leak_check), } } @@ -1163,6 +1187,8 @@ pub trait StmtWalker { let Downcast { base, .. } = statement; self.walk_expr(base); } + + fn walk_leak_check(&mut self, _statement: &LeakCheck) {} } pub trait FallibleStmtWalker { @@ -1195,6 +1221,7 @@ pub trait FallibleStmtWalker { } Stmt::If(if_stmt) => self.fallible_walk_if(if_stmt), Stmt::Downcast(downcast) => self.fallible_walk_downcast(downcast), + Stmt::LeakCheck(leak_check) => self.fallible_walk_leak_check(leak_check), } } @@ -1369,6 +1396,10 @@ pub trait FallibleStmtWalker { self.fallible_walk_expr(base)?; Ok(()) } + + fn fallible_walk_leak_check(&mut self, _statement: &LeakCheck) -> Result<(), Self::Error> { + Ok(()) + } } pub fn stmts_to_str(stmts: &[Stmt]) -> String { diff --git a/vir/src/converter/polymorphic_to_legacy.rs b/vir/src/converter/polymorphic_to_legacy.rs index e978dd5603e..228c0eed0b1 100644 --- a/vir/src/converter/polymorphic_to_legacy.rs +++ b/vir/src/converter/polymorphic_to_legacy.rs @@ -756,7 +756,10 @@ impl From for legacy::Stmt { ), polymorphic::Stmt::Downcast(downcast) => { legacy::Stmt::Downcast(downcast.base.into(), downcast.field.into()) - } + }, + polymorphic::Stmt::LeakCheck(_) => { + panic!("all leak check markers needs to removed before convering polymorphic VIR to legacy!"); + }, } } } diff --git a/vir/src/converter/type_substitution.rs b/vir/src/converter/type_substitution.rs index ee04db0a7de..c8d12395295 100644 --- a/vir/src/converter/type_substitution.rs +++ b/vir/src/converter/type_substitution.rs @@ -592,6 +592,7 @@ impl Generic for Stmt { } Stmt::If(if_stmt) => Stmt::If(if_stmt.substitute(map)), Stmt::Downcast(downcast) => Stmt::Downcast(downcast.substitute(map)), + Stmt::LeakCheck(leak_check) => Stmt::LeakCheck(leak_check.substitute(map)), } } } @@ -799,6 +800,12 @@ impl Generic for Downcast { } } +impl Generic for LeakCheck { + fn substitute(self, _map: &FxHashMap) -> Self { + self + } +} + // method impl Generic for CfgMethod { fn substitute(self, map: &FxHashMap) -> Self {