Skip to content

Commit

Permalink
Raise an error when postconditions of pure functions contain old() ex…
Browse files Browse the repository at this point in the history
…pressions (#1474)

* Raise an error if old() appears in postcondition of pure functions

* rustfmt, commit more files

* Add a test

* Clippy

* Remove unnecessary debug

* Fix test

* More tests
  • Loading branch information
zgrannan authored Dec 7, 2023
1 parent 26c999b commit 202ca0e
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 66 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use prusti_contracts::*;

struct MyWrapper(u32);

impl MyWrapper {
#[pure]
#[ensures(old(self.0) == self.0)]
fn unwrap(&self) -> u32 { //~ ERROR old expressions should not appear in the postconditions of pure functions
self.0
}
}

fn test(x: &MyWrapper) -> u32 {
// Following error is due to stub encoding of invalid spec for function `unwrap()`
x.unwrap() //~ ERROR precondition of pure function call might not hold
}

fn main() { }
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use prusti_contracts::*;

#[extern_spec]
impl<T> std::option::Option<T> {
#[pure] // <=== Error triggered by this
#[requires(self.is_some())]
#[ensures(old(self) === Some(result))]
pub fn unwrap(self) -> T; //~ ERROR old expressions should not appear in the postconditions of pure functions

#[pure]
#[ensures(result == matches!(self, Some(_)))]
pub const fn is_some(&self) -> bool;
}

#[pure]
#[requires(x.is_some())]
fn test(x: Option<i32>) -> i32 {
// Following error is due to stub encoding of invalid external spec for function `unwrap()`
x.unwrap() //~ ERROR precondition of pure function call might not hold
}

fn main() { }
46 changes: 46 additions & 0 deletions prusti-viper/src/encoder/interface.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use crate::encoder::{
errors::{SpannedEncodingResult, WithSpan},
snapshot::interface::SnapshotEncoderInterface,
Encoder,
};

use prusti_rustc_interface::{
middle::{mir, ty, ty::Binder},
span::Span,
};

use vir_crate::polymorphic as vir_poly;

pub(crate) trait PureFunctionFormalArgsEncoderInterface<'p, 'v: 'p, 'tcx: 'v> {
fn encoder(&self) -> &'p Encoder<'v, 'tcx>;

fn check_type(
&self,
var_span: Span,
ty: Binder<'tcx, ty::Ty<'tcx>>,
) -> SpannedEncodingResult<()>;

fn get_span(&self, local: mir::Local) -> Span;

fn encode_formal_args(
&self,
sig: ty::PolyFnSig<'tcx>,
) -> SpannedEncodingResult<Vec<vir_poly::LocalVar>> {
let mut formal_args = vec![];
for local_idx in 0..sig.skip_binder().inputs().len() {
let local_ty = sig.input(local_idx);
let local = mir::Local::from_usize(local_idx + 1);
let var_name = format!("{local:?}");
let var_span = self.get_span(local);

self.check_type(var_span, local_ty)?;

let var_type = self
.encoder()
.encode_snapshot_type(local_ty.skip_binder())
.with_span(var_span)?;
formal_args.push(vir_poly::LocalVar::new(var_name, var_type))
}
Ok(formal_args)
}
}
75 changes: 43 additions & 32 deletions prusti-viper/src/encoder/mir/pure/pure_functions/encoder_poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use crate::encoder::{
errors::{ErrorCtxt, SpannedEncodingError, SpannedEncodingResult, WithSpan},
high::{generics::HighGenericsEncoderInterface, types::HighTypeEncoderInterface},
interface::PureFunctionFormalArgsEncoderInterface,
mir::{
contracts::{ContractsEncoderInterface, ProcedureContract},
pure::{
Expand Down Expand Up @@ -50,7 +51,7 @@ pub(super) struct PureFunctionEncoder<'p, 'v: 'p, 'tcx: 'v> {
/// Span of the function declaration.
span: Span,
/// Signature of the function to be encoded.
sig: ty::PolyFnSig<'tcx>,
pub(crate) sig: ty::PolyFnSig<'tcx>,
/// Spans of MIR locals, when encoding a local pure function.
local_spans: Option<Vec<Span>>,
}
Expand Down Expand Up @@ -137,6 +138,38 @@ fn encode_mir<'p, 'v: 'p, 'tcx: 'v>(
Ok(body_expr)
}

impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionFormalArgsEncoderInterface<'p, 'v, 'tcx>
for PureFunctionEncoder<'p, 'v, 'tcx>
{
fn encoder(&self) -> &'p Encoder<'v, 'tcx> {
self.encoder
}

fn check_type(
&self,
var_span: Span,
ty: ty::Binder<'tcx, ty::Ty<'tcx>>,
) -> SpannedEncodingResult<()> {
if !self
.encoder
.env()
.query
.type_is_copy(ty, self.parent_def_id)
{
Err(SpannedEncodingError::incorrect(
"pure function parameters must be Copy",
var_span,
))
} else {
Ok(())
}
}

fn get_span(&self, local: mir::Local) -> Span {
self.get_local_span(local)
}
}

impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> {
#[tracing::instrument(
name = "PureFunctionEncoder::new",
Expand Down Expand Up @@ -314,7 +347,7 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> {
let mut precondition = vec![type_precondition, func_precondition];
let mut postcondition = vec![self.encode_postcondition_expr(&contract)?];

let formal_args = self.encode_formal_args()?;
let formal_args = self.encode_formal_args(self.sig)?;
let return_type = self.encode_function_return_type()?;

let res_value_range_pos = self.encoder.error_manager().register_error(
Expand Down Expand Up @@ -545,6 +578,13 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> {
.replace_place(&encoded_return.into(), &pure_fn_return_variable.into())
.set_default_pos(postcondition_pos);

if post.has_old_expression() {
return Err(SpannedEncodingError::incorrect(
"old expressions should not appear in the postconditions of pure functions",
self.span,
));
}

Ok(post)
}

Expand Down Expand Up @@ -620,40 +660,11 @@ impl<'p, 'v: 'p, 'tcx: 'v> PureFunctionEncoder<'p, 'v, 'tcx> {
.with_span(self.span)
}

fn encode_formal_args(&self) -> SpannedEncodingResult<Vec<vir::LocalVar>> {
let mut formal_args = vec![];
for local_idx in 0..self.sig.skip_binder().inputs().len() {
let local_ty = self.sig.input(local_idx);
let local = prusti_rustc_interface::middle::mir::Local::from_usize(local_idx + 1);
let var_name = format!("{local:?}");
let var_span = self.get_local_span(local);

if !self
.encoder
.env()
.query
.type_is_copy(local_ty, self.parent_def_id)
{
return Err(SpannedEncodingError::incorrect(
"pure function parameters must be Copy",
var_span,
));
}

let var_type = self
.encoder
.encode_snapshot_type(local_ty.skip_binder())
.with_span(var_span)?;
formal_args.push(vir::LocalVar::new(var_name, var_type))
}
Ok(formal_args)
}

pub fn encode_function_call_info(&self) -> SpannedEncodingResult<FunctionCallInfo> {
Ok(FunctionCallInfo {
name: self.encode_function_name(),
type_arguments: self.encode_type_arguments()?,
formal_args: self.encode_formal_args()?,
formal_args: self.encode_formal_args(self.sig)?,
return_type: self.encode_function_return_type()?,
})
}
Expand Down
35 changes: 26 additions & 9 deletions prusti-viper/src/encoder/mir/pure/pure_functions/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,10 +331,11 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx>
substs,
);

let is_bodyless = self.is_trusted(proc_def_id, Some(substs))
|| !self.env().query.has_body(proc_def_id);

let maybe_identifier: SpannedEncodingResult<vir_poly::FunctionIdentifier> = (|| {
let proc_kind = self.get_proc_kind(proc_def_id, Some(substs));
let is_bodyless = self.is_trusted(proc_def_id, Some(substs))
|| !self.env().query.has_body(proc_def_id);
let mut function = if is_bodyless {
pure_function_encoder.encode_bodyless_function()?
} else {
Expand Down Expand Up @@ -393,13 +394,29 @@ impl<'v, 'tcx: 'v> PureFunctionEncoderInterface<'v, 'tcx>
Err(error) => {
self.register_encoding_error(error);
debug!("Error encoding pure function: {:?}", proc_def_id);
let body = self
.env()
.body
.get_pure_fn_body(proc_def_id, substs, parent_def_id);
// TODO(tymap): does stub encoder need substs?
let stub_encoder = StubFunctionEncoder::new(self, proc_def_id, &body, substs);
let function = stub_encoder.encode_function()?;
let function = if !is_bodyless {
let pure_fn_body =
self.env()
.body
.get_pure_fn_body(proc_def_id, substs, parent_def_id);
let encoder = StubFunctionEncoder::new(
self,
proc_def_id,
Some(&pure_fn_body),
substs,
pure_function_encoder.sig,
);
encoder.encode_function()?
} else {
let encoder = StubFunctionEncoder::new(
self,
proc_def_id,
None,
substs,
pure_function_encoder.sig,
);
encoder.encode_function()?
};
self.log_vir_program_before_viper(function.to_string());
let identifier = self.insert_function(function);
self.pure_function_encoder_state
Expand Down
1 change: 1 addition & 0 deletions prusti-viper/src/encoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod encoder;
mod errors;
mod foldunfold;
mod initialisation;
mod interface;
mod loop_encoder;
mod mir_encoder;
mod mir_successor;
Expand Down
67 changes: 42 additions & 25 deletions prusti-viper/src/encoder/stub_function_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,64 +7,82 @@
use crate::encoder::{
errors::{SpannedEncodingResult, WithSpan},
high::generics::HighGenericsEncoderInterface,
mir_encoder::{MirEncoder, PlaceEncoder},
interface::PureFunctionFormalArgsEncoderInterface,
snapshot::interface::SnapshotEncoderInterface,
Encoder,
};
use log::debug;
use prusti_rustc_interface::{
hir::def_id::DefId,
middle::{mir, ty::GenericArgsRef},
middle::{
mir, ty,
ty::{Binder, GenericArgsRef},
},
span::Span,
};
use vir_crate::polymorphic as vir;

use super::mir::specifications::SpecificationsInterface;

pub struct StubFunctionEncoder<'p, 'v: 'p, 'tcx: 'v> {
encoder: &'p Encoder<'v, 'tcx>,
mir: &'p mir::Body<'tcx>,
mir_encoder: MirEncoder<'p, 'v, 'tcx>,
mir: Option<&'p mir::Body<'tcx>>,
proc_def_id: DefId,
substs: GenericArgsRef<'tcx>,
sig: ty::PolyFnSig<'tcx>,
}

impl<'p, 'v, 'tcx> PureFunctionFormalArgsEncoderInterface<'p, 'v, 'tcx>
for StubFunctionEncoder<'p, 'v, 'tcx>
{
fn check_type(&self, _span: Span, _ty: Binder<ty::Ty<'tcx>>) -> SpannedEncodingResult<()> {
Ok(())
}

fn encoder(&self) -> &'p Encoder<'v, 'tcx> {
self.encoder
}

fn get_span(&self, _local: mir::Local) -> Span {
self.encoder.get_spec_span(self.proc_def_id)
}
}

impl<'p, 'v: 'p, 'tcx: 'v> StubFunctionEncoder<'p, 'v, 'tcx> {
#[tracing::instrument(name = "StubFunctionEncoder::new", level = "trace", skip(encoder, mir))]
pub fn new(
encoder: &'p Encoder<'v, 'tcx>,
proc_def_id: DefId,
mir: &'p mir::Body<'tcx>,
mir: Option<&'p mir::Body<'tcx>>,
substs: GenericArgsRef<'tcx>,
sig: ty::PolyFnSig<'tcx>,
) -> Self {
StubFunctionEncoder {
encoder,
mir,
mir_encoder: MirEncoder::new(encoder, mir, proc_def_id),
proc_def_id,
substs,
sig,
}
}

fn default_span(&self) -> Span {
self.mir
.map(|m| m.span)
.unwrap_or_else(|| self.encoder.get_spec_span(self.proc_def_id))
}

#[tracing::instrument(level = "debug", skip(self))]
pub fn encode_function(&self) -> SpannedEncodingResult<vir::Function> {
let function_name = self.encode_function_name();
debug!("Encode stub function {}", function_name);

let formal_args: Vec<_> = self
.mir
.args_iter()
.map(|local| {
let var_name = self.mir_encoder.encode_local_var_name(local);
let mir_type = self.mir_encoder.get_local_ty(local);
self.encoder
.encode_snapshot_type(mir_type)
.map(|var_type| vir::LocalVar::new(var_name, var_type))
})
.collect::<Result<_, _>>()
.with_span(self.mir.span)?;
let formal_args = self.encode_formal_args(self.sig)?;

let type_arguments = self
.encoder
.encode_generic_arguments(self.proc_def_id, self.substs)
.with_span(self.mir.span)?;
.with_span(self.default_span())?;

let return_type = self.encode_function_return_type()?;

Expand All @@ -74,8 +92,6 @@ impl<'p, 'v: 'p, 'tcx: 'v> StubFunctionEncoder<'p, 'v, 'tcx> {
formal_args,
return_type,
pres: vec![false.into()],
// Note: Silicon is currently unsound when declaring a function that ensures `false`
// See: https://github.com/viperproject/silicon/issues/376
posts: vec![],
body: None,
};
Expand All @@ -94,9 +110,10 @@ impl<'p, 'v: 'p, 'tcx: 'v> StubFunctionEncoder<'p, 'v, 'tcx> {
}

pub fn encode_function_return_type(&self) -> SpannedEncodingResult<vir::Type> {
let ty = self.mir.return_ty();
let return_local = mir::Place::return_place().as_local().unwrap();
let span = self.mir_encoder.get_local_span(return_local);
self.encoder.encode_snapshot_type(ty).with_span(span)
let ty = self.sig.output();

self.encoder
.encode_snapshot_type(ty.skip_binder())
.with_span(self.encoder.get_spec_span(self.proc_def_id))
}
}
Loading

0 comments on commit 202ca0e

Please sign in to comment.