Skip to content

Commit

Permalink
Impure Generics (#45)
Browse files Browse the repository at this point in the history
* Impure generics

* Fixed bug due to type parameters not being encoded

* Remove fractional perm

* mk_bool doesn't need extra typarams
  • Loading branch information
zgrannan authored May 14, 2024
1 parent 39c876f commit c3b21a1
Show file tree
Hide file tree
Showing 41 changed files with 1,793 additions and 532 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions local-testing/generics/add.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use prusti_contracts::*;

#[pure]
fn id<T>(x: T) -> T {
x
}

#[pure]
fn main(){
let x = id(1);
let y = id(2);
let z = x + y;
}
20 changes: 20 additions & 0 deletions local-testing/generics/container.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use prusti_contracts::*;

struct Container<T> {
content: T,
}

#[requires(container.content == 123)]
#[ensures(result.content == 124)]
fn increment_container(container: Container<i32>) -> Container<i32> {
Container { content: container.content + 1 }
}

#[ensures(result > 123)]
fn client() -> i32 {
let num_container = Container { content: 123 };
let incremented_container = increment_container(num_container);
incremented_container.content
}

fn main(){}
7 changes: 7 additions & 0 deletions local-testing/generics/impure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
use prusti_contracts::*;

pub fn test<T>(x: i8, y: T) -> T{ y }

fn main(){
let a = test(-127, 11);
}
7 changes: 7 additions & 0 deletions local-testing/generics/nested2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
struct R(Option<Box<u32>>);
fn main() {
match R(None).0 {
Some(_) => (),
_ => (),
}
}
35 changes: 35 additions & 0 deletions local-testing/generics/pair2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use prusti_contracts::*;

struct Pair<T, U> {
first: T,
second: U,
}

#[requires(pair.second == true)]
#[ensures(result.second == true)]
fn copy_pair<T>(pair: Pair<T, bool>) -> Pair<T, bool> {
Pair {
first: pair.first,
second: pair.second
}
}

fn fst<T, U>(pair: Pair<T, U>) -> T {
let unused = pair.second;
pair.first
}

#[ensures(result == true)]
fn client() -> bool {
let initial_pair = Pair {
first: 42u32,
second: true
};
let copied_pair = copy_pair(initial_pair);
copied_pair.second
}

fn main() {
let pair = Pair { first: 1, second: 2 };
fst(pair);
}
26 changes: 26 additions & 0 deletions local-testing/generics/point.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use prusti_contracts::*;

struct Point<T> {
x: T,
y: T,
}

impl<T: Copy> Point<T> {
#[pure]
fn new(x: T, y: T) -> Point<T> {
Point { x, y }
}

#[pure]
fn x(self) -> T {
self.x
}

#[pure]
fn y(self) -> T {
self.y
}
}

#[ensures(Point::new(1, 2).x() == 1)]
fn main() {}
3 changes: 2 additions & 1 deletion prusti-encoder/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ edition = "2021"
doctest = false # we have no doc tests

[dependencies]
cfg-if = "1.0.0"
prusti-rustc-interface = { path = "../prusti-rustc-interface" }
prusti-interface = { path = "../prusti-interface" }
mir-ssa-analysis = { path = "../mir-ssa-analysis" }
Expand All @@ -21,6 +22,6 @@ tracing = { path = "../tracing" }
rustc_private = true

[features]
default = ["mono_function_encoding"]
# default = ["mono_function_encoding"]
vir_debug = ["vir/vir_debug"]
mono_function_encoding = []
45 changes: 45 additions & 0 deletions prusti-encoder/src/encoder_traits/function_enc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use task_encoder::TaskEncoder;
use prusti_rustc_interface::{
middle::ty::GenericArgs,
span::def_id::DefId
};

/// Task encoders for Rust functions should implement this trait.
pub trait FunctionEnc
where
Self: 'static + Sized + TaskEncoder
{
/// Obtains the function's [`DefId`] from the task key
fn get_def_id(task_key: &Self::TaskKey<'_>) -> DefId;

/// Obtains the caller's [`DefId`] from the task key, if possible
fn get_caller_def_id(task_key: &Self::TaskKey<'_>) -> Option<DefId>;

/// Obtains type substitutions for the function. For polymorphic encoding,
/// this should be the identity substitution obtained from the DefId of the
/// function. For the monomorphic encoding, the substitutions at the call
/// site should be used.
fn get_substs<'tcx>(
vcx: &vir::VirCtxt<'tcx>,
substs_src: &Self::TaskKey<'tcx>,
) -> &'tcx GenericArgs<'tcx>;
}

/// Implementation for polymorphic encoding
impl <T: 'static + for<'vir> TaskEncoder<TaskKey<'vir> = DefId>> FunctionEnc for T {
fn get_def_id(task_key: &Self::TaskKey<'_>) -> DefId {
*task_key
}

fn get_caller_def_id(_: &Self::TaskKey<'_>) -> Option<DefId> {
None
}

fn get_substs<'tcx>(
vcx: &vir::VirCtxt<'tcx>,
def_id: &Self::TaskKey<'tcx>,
) -> &'tcx GenericArgs<'tcx> {
GenericArgs::identity_for_item(vcx.tcx(), *def_id)
}

}
192 changes: 192 additions & 0 deletions prusti-encoder/src/encoder_traits/impure_function_enc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
use prusti_rustc_interface::middle::mir;
use task_encoder::{TaskEncoder, TaskEncoderDependencies};
use vir::{MethodIdent, UnknownArity, ViperIdent};

use crate::encoders::{
lifted::func_def_ty_params::LiftedTyParamsEnc, ImpureEncVisitor, MirImpureEnc, MirLocalDefEnc, MirSpecEnc
};

use super::function_enc::FunctionEnc;

#[derive(Clone, Debug)]
pub struct ImpureFunctionEncError;

#[derive(Clone, Debug)]
pub struct ImpureFunctionEncOutputRef<'vir> {
pub method_ref: MethodIdent<'vir, UnknownArity<'vir>>,
}
impl<'vir> task_encoder::OutputRefAny for ImpureFunctionEncOutputRef<'vir> {}

#[derive(Clone, Debug)]
pub struct ImpureFunctionEncOutput<'vir> {
pub method: vir::Method<'vir>,
}

const ENCODE_REACH_BB: bool = false;

pub trait ImpureFunctionEnc
where
Self: 'static
+ Sized
+ FunctionEnc
+ for<'vir> TaskEncoder<OutputRef<'vir> = ImpureFunctionEncOutputRef<'vir>>,
{
/// Generates the identifier for the method; for a monomorphic encoding,
/// this should be a name including (mangled) type arguments
fn mk_method_ident<'vir, 'tcx>(
vcx: &'vir vir::VirCtxt<'tcx>,
task_key: &Self::TaskKey<'tcx>,
) -> ViperIdent<'vir>;

fn encode<'vir, 'tcx: 'vir>(
task_key: Self::TaskKey<'tcx>,
deps: &mut TaskEncoderDependencies<'vir>,
) -> ImpureFunctionEncOutput<'vir> {
let def_id = Self::get_def_id(&task_key);
let caller_def_id = Self::get_caller_def_id(&task_key);
let trusted = crate::encoders::with_proc_spec(def_id, |def_spec| {
def_spec.trusted.extract_inherit().unwrap_or_default()
})
.unwrap_or_default();
vir::with_vcx(|vcx| {
use mir::visit::Visitor;
let substs = Self::get_substs(vcx, &task_key);
let local_defs = deps
.require_local::<MirLocalDefEnc>((def_id, substs, caller_def_id))
.unwrap();

// Argument count for the Viper method:
// - one (`Ref`) for the return place;
// - one (`Ref`) for each MIR argument.
//
// Note that the return place is modelled as an argument of the
// Viper method. This corresponds to an execution model where the
// method can return data to the caller without a copy--it directly
// modifies a place provided by the caller.
//
// TODO: type parameters
let arg_count = local_defs.arg_count + 1;

let method_name = Self::mk_method_ident(vcx, &task_key);
let mut args = vec![&vir::TypeData::Ref; arg_count];
let param_ty_decls = deps
.require_local::<LiftedTyParamsEnc>(substs)
.unwrap()
.iter()
.map(|g| g.decl())
.collect::<Vec<_>>();
args.extend(param_ty_decls.iter().map(|decl| decl.ty));
let args = UnknownArity::new(vcx.alloc_slice(&args));
let method_ref = MethodIdent::new(method_name, args);
deps.emit_output_ref::<Self>(task_key, ImpureFunctionEncOutputRef { method_ref });

// Do not encode the method body if it is external, trusted or just
// a call stub.
let local_def_id = def_id.as_local().filter(|_| !trusted);
let blocks = if let Some(local_def_id) = local_def_id {
let body = vcx
.body_mut()
.get_impure_fn_body(local_def_id, substs, caller_def_id);
// let body = vcx.tcx().mir_promoted(local_def_id).0.borrow();

let fpcs_analysis = mir_state_analysis::run_free_pcs(&body, vcx.tcx());

//let ssa_analysis = SsaAnalysis::analyse(&body);

let block_count = body.basic_blocks.len();

// Local count for the Viper method:
// - one for each basic block;
// - one (`Ref`) for each non-argument, non-return local.
let _local_count = block_count + 1 * (body.local_decls.len() - arg_count);

let mut encoded_blocks = Vec::with_capacity(
// extra blocks: Start, End
2 + block_count,
);
let mut start_stmts = Vec::new();
for local in (arg_count..body.local_decls.len()).map(mir::Local::from) {
let name_p = local_defs.locals[local].local.name;
start_stmts.push(
vcx.mk_local_decl_stmt(vir::vir_local_decl! { vcx; [name_p] : Ref }, None),
)
}
if ENCODE_REACH_BB {
start_stmts.extend((0..block_count).map(|block| {
let name = vir::vir_format!(vcx, "_reach_bb{block}");
vcx.mk_local_decl_stmt(
vir::vir_local_decl! { vcx; [name] : Bool },
Some(vcx.mk_todo_expr("false")),
)
}));
}
encoded_blocks.push(vcx.mk_cfg_block(
vcx.alloc(vir::CfgBlockLabelData::Start),
vcx.alloc_slice(&start_stmts),
vcx.mk_goto_stmt(vcx.alloc(vir::CfgBlockLabelData::BasicBlock(0))),
));

let mut visitor = ImpureEncVisitor {
monomorphize: MirImpureEnc::monomorphize(),
vcx,
deps,
def_id,
local_decls: &body.local_decls,
//ssa_analysis,
fpcs_analysis,
local_defs,

tmp_ctr: 0,

current_fpcs: None,

current_stmts: None,
current_terminator: None,
encoded_blocks,
};
visitor.visit_body(&body);

visitor.encoded_blocks.push(vcx.mk_cfg_block(
vcx.alloc(vir::CfgBlockLabelData::End),
&[],
vcx.alloc(vir::TerminatorStmtData::Exit),
));
Some(vcx.alloc_slice(&visitor.encoded_blocks))
} else {
None
};

let spec = deps
.require_local::<MirSpecEnc>((def_id, substs, None, false))
.unwrap();
let (spec_pres, spec_posts) = (spec.pres, spec.posts);

let mut pres = Vec::with_capacity(arg_count - 1);
let mut args = Vec::with_capacity(arg_count + substs.len());
for arg_idx in 0..arg_count {
let name_p = local_defs.locals[arg_idx.into()].local.name;
args.push(vir::vir_local_decl! { vcx; [name_p] : Ref });
if arg_idx != 0 {
pres.push(local_defs.locals[arg_idx.into()].impure_pred);
}
}
args.extend(param_ty_decls.iter());
pres.extend(spec_pres);

let mut posts = Vec::with_capacity(spec_posts.len() + 1);
posts.push(local_defs.locals[mir::RETURN_PLACE].impure_pred);
posts.extend(spec_posts);

ImpureFunctionEncOutput {
method: vcx.mk_method(
method_ref,
vcx.alloc_slice(&args),
&[],
vcx.alloc_slice(&pres),
vcx.alloc_slice(&posts),
blocks,
),
}
})
}
}
2 changes: 2 additions & 0 deletions prusti-encoder/src/encoder_traits/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
pub mod pure_function_enc;
pub mod pure_func_app_enc;
pub mod function_enc;
pub mod impure_function_enc;
Loading

0 comments on commit c3b21a1

Please sign in to comment.