Skip to content

Commit

Permalink
Support cycles in encoders (#47)
Browse files Browse the repository at this point in the history
* domain fields don't need a full encoding of their type yet

* type alias for do_encode_full result

* parameterise TaskEncoderDependencies by the owning encoder

* remove some dependency unwraps

* remove 'tcx lifetime, use 'vir

* check for cycles when requesting dependencies or emitting output ref

* add try operators for some emit output refs
  • Loading branch information
Aurel300 authored May 14, 2024
1 parent c3b21a1 commit c3b2a75
Show file tree
Hide file tree
Showing 34 changed files with 634 additions and 798 deletions.
16 changes: 8 additions & 8 deletions prusti-encoder/src/encoder_traits/function_enc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ pub trait FunctionEnc
/// this should be the identity substitution obtained from the DefId of the
/// function. For the monomorphic encoding, the substitutions at the call
/// site should be used.
fn get_substs<'tcx>(
vcx: &vir::VirCtxt<'tcx>,
substs_src: &Self::TaskKey<'tcx>,
) -> &'tcx GenericArgs<'tcx>;
fn get_substs<'vir>(
vcx: &vir::VirCtxt<'vir>,
substs_src: &Self::TaskKey<'vir>,
) -> &'vir GenericArgs<'vir>;
}

/// Implementation for polymorphic encoding
Expand All @@ -35,10 +35,10 @@ impl <T: 'static + for<'vir> TaskEncoder<TaskKey<'vir> = DefId>> FunctionEnc for
None
}

fn get_substs<'tcx>(
vcx: &vir::VirCtxt<'tcx>,
def_id: &Self::TaskKey<'tcx>,
) -> &'tcx GenericArgs<'tcx> {
fn get_substs<'vir>(
vcx: &vir::VirCtxt<'vir>,
def_id: &Self::TaskKey<'vir>,
) -> &'vir GenericArgs<'vir> {
GenericArgs::identity_for_item(vcx.tcx(), *def_id)
}

Expand Down
34 changes: 17 additions & 17 deletions prusti-encoder/src/encoder_traits/impure_function_enc.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use prusti_rustc_interface::middle::mir;
use task_encoder::{TaskEncoder, TaskEncoderDependencies};
use task_encoder::{EncodeFullError, TaskEncoder, TaskEncoderDependencies};
use vir::{MethodIdent, UnknownArity, ViperIdent};

use crate::encoders::{
Expand Down Expand Up @@ -33,15 +33,18 @@ where
{
/// Generates the identifier for the method; for a monomorphic encoding,
/// this should be a name including (mangled) type arguments
fn mk_method_ident<'vir, 'tcx>(
vcx: &'vir vir::VirCtxt<'tcx>,
task_key: &Self::TaskKey<'tcx>,
fn mk_method_ident<'vir>(
vcx: &'vir vir::VirCtxt<'vir>,
task_key: &Self::TaskKey<'vir>,
) -> ViperIdent<'vir>;

fn encode<'vir, 'tcx: 'vir>(
task_key: Self::TaskKey<'tcx>,
deps: &mut TaskEncoderDependencies<'vir>,
) -> ImpureFunctionEncOutput<'vir> {
fn encode<'vir>(
task_key: Self::TaskKey<'vir>,
deps: &mut TaskEncoderDependencies<'vir, Self>,
) -> Result<
ImpureFunctionEncOutput<'vir>,
EncodeFullError<'vir, Self>,
> {
let def_id = Self::get_def_id(&task_key);
let caller_def_id = Self::get_caller_def_id(&task_key);
let trusted = crate::encoders::with_proc_spec(def_id, |def_spec| {
Expand All @@ -52,8 +55,7 @@ where
use mir::visit::Visitor;
let substs = Self::get_substs(vcx, &task_key);
let local_defs = deps
.require_local::<MirLocalDefEnc>((def_id, substs, caller_def_id))
.unwrap();
.require_local::<MirLocalDefEnc>((def_id, substs, caller_def_id))?;

// Argument count for the Viper method:
// - one (`Ref`) for the return place;
Expand All @@ -70,15 +72,14 @@ where
let method_name = Self::mk_method_ident(vcx, &task_key);
let mut args = vec![&vir::TypeData::Ref; arg_count];
let param_ty_decls = deps
.require_local::<LiftedTyParamsEnc>(substs)
.unwrap()
.require_local::<LiftedTyParamsEnc>(substs)?
.iter()
.map(|g| g.decl())
.collect::<Vec<_>>();
args.extend(param_ty_decls.iter().map(|decl| decl.ty));
let args = UnknownArity::new(vcx.alloc_slice(&args));
let method_ref = MethodIdent::new(method_name, args);
deps.emit_output_ref::<Self>(task_key, ImpureFunctionEncOutputRef { method_ref });
deps.emit_output_ref(task_key, ImpureFunctionEncOutputRef { method_ref })?;

// Do not encode the method body if it is external, trusted or just
// a call stub.
Expand Down Expand Up @@ -157,8 +158,7 @@ where
};

let spec = deps
.require_local::<MirSpecEnc>((def_id, substs, None, false))
.unwrap();
.require_local::<MirSpecEnc>((def_id, substs, None, false))?;
let (spec_pres, spec_posts) = (spec.pres, spec.posts);

let mut pres = Vec::with_capacity(arg_count - 1);
Expand All @@ -177,7 +177,7 @@ where
posts.push(local_defs.locals[mir::RETURN_PLACE].impure_pred);
posts.extend(spec_posts);

ImpureFunctionEncOutput {
Ok(ImpureFunctionEncOutput {
method: vcx.mk_method(
method_ref,
vcx.alloc_slice(&args),
Expand All @@ -186,7 +186,7 @@ where
vcx.alloc_slice(&posts),
blocks,
),
}
})
})
}
}
30 changes: 15 additions & 15 deletions prusti-encoder/src/encoder_traits/pure_func_app_enc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use prusti_rustc_interface::{
},
span::def_id::DefId,
};
use task_encoder::TaskEncoderDependencies;
use task_encoder::{TaskEncoder, TaskEncoderDependencies};

use crate::encoders::{
lifted::{
Expand All @@ -16,7 +16,7 @@ use crate::encoders::{
/// Encoders (such as [`crate::encoders::MirPureEnc`],
/// [`crate::encoders::MirImpureEnc`]) implement this trait to encode
/// applications of Rust functions annotated as pure.
pub trait PureFuncAppEnc<'tcx: 'vir, 'vir> {
pub trait PureFuncAppEnc<'vir, E: TaskEncoder + 'vir + ?Sized> {
/// Extra arguments required for the encoder to encode an argument to the
/// function (in mir this is an `Operand`)
type EncodeOperandArgs;
Expand All @@ -29,32 +29,32 @@ pub trait PureFuncAppEnc<'tcx: 'vir, 'vir> {

/// The type of the data source that can provide local declarations; this is used
/// when getting the type of the function.
type LocalDeclsSrc: ?Sized + HasLocalDecls<'tcx>;
type LocalDeclsSrc: ?Sized + HasLocalDecls<'vir>;

// Are we monomorphizing functions?
fn monomorphize(&self) -> bool;

/// Task encoder dependencies are required for encoding Viper casts between
/// generic and concrete types.
fn deps(&mut self) -> &mut TaskEncoderDependencies<'vir>;
fn deps(&mut self) -> &mut TaskEncoderDependencies<'vir, E>;

/// The data source that can provide local declarations, necesary for determining
/// the function type
fn local_decls_src(&self) -> &Self::LocalDeclsSrc;
fn vcx(&self) -> &'vir vir::VirCtxt<'tcx>;
fn vcx(&self) -> &'vir vir::VirCtxt<'vir>;

/// Encodes an operand (an argument to a function) as a pure Viper expression.
fn encode_operand(
&mut self,
args: &Self::EncodeOperandArgs,
operand: &mir::Operand<'tcx>,
operand: &mir::Operand<'vir>,
) -> vir::ExprGen<'vir, Self::Curr, Self::Next>;

/// Obtains the function's definition ID and the substitutions made at the callsite
fn get_def_id_and_caller_substs(
&self,
func: &mir::Operand<'tcx>,
) -> (DefId, &'tcx List<GenericArg<'tcx>>) {
func: &mir::Operand<'vir>,
) -> (DefId, &'vir List<GenericArg<'vir>>) {
let func_ty = func.ty(self.local_decls_src(), self.vcx().tcx());
match func_ty.kind() {
&ty::TyKind::FnDef(def_id, arg_tys) => (def_id, arg_tys),
Expand All @@ -67,9 +67,9 @@ pub trait PureFuncAppEnc<'tcx: 'vir, 'vir> {
/// are inserted to convert from/to generic and concrete arguments as necessary.
fn encode_fn_args(
&mut self,
sig: Binder<'tcx, FnSig<'tcx>>,
substs: &'tcx List<GenericArg<'tcx>>,
args: &[mir::Operand<'tcx>],
sig: Binder<'vir, FnSig<'vir>>,
substs: &'vir List<GenericArg<'vir>>,
args: &[mir::Operand<'vir>],
encode_operand_args: &Self::EncodeOperandArgs,
) -> Vec<vir::ExprGen<'vir, Self::Curr, Self::Next>> {
let mono = self.monomorphize();
Expand Down Expand Up @@ -118,10 +118,10 @@ pub trait PureFuncAppEnc<'tcx: 'vir, 'vir> {
fn encode_pure_func_app(
&mut self,
def_id: DefId,
sig: Binder<'tcx, FnSig<'tcx>>,
substs: &'tcx List<GenericArg<'tcx>>,
args: &Vec<mir::Operand<'tcx>>,
destination: &mir::Place<'tcx>,
sig: Binder<'vir, FnSig<'vir>>,
substs: &'vir List<GenericArg<'vir>>,
args: &Vec<mir::Operand<'vir>>,
destination: &mir::Place<'vir>,
caller_def_id: DefId,
encode_operand_args: &Self::EncodeOperandArgs,
) -> vir::ExprGen<'vir, Self::Curr, Self::Next> {
Expand Down
30 changes: 15 additions & 15 deletions prusti-encoder/src/encoder_traits/pure_function_enc.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use prusti_rustc_interface::
middle::{mir, ty::Ty}
;
use prusti_rustc_interface::{
middle::{mir, ty::{GenericArgs, Ty}},
span::def_id::DefId,
};
use task_encoder::{TaskEncoder, TaskEncoderDependencies};
use vir::{CallableIdent, ExprGen, FunctionIdent, Reify, UnknownArity, ViperIdent};

Expand Down Expand Up @@ -33,21 +34,20 @@ where

/// Generates the identifier for the function; for a monomorphic encoding,
/// this should be a name including (mangled) type arguments
fn mk_function_ident<'vir, 'tcx>(
vcx: &'vir vir::VirCtxt<'tcx>,
task_key: &Self::TaskKey<'tcx>,
fn mk_function_ident<'vir>(
vcx: &'vir vir::VirCtxt<'vir>,
task_key: &Self::TaskKey<'vir>,
) -> ViperIdent<'vir>;


/// Adds an assertion connecting the type of an argument (or return) of the
/// function with the appropriate type based on the param, e.g. in f<T,
/// U>(u: U) -> T, this would be called to require that the type of `u` be
/// `U`
fn mk_type_assertion<'vir, 'tcx: 'vir, Curr, Next>(
vcx: &'vir vir::VirCtxt<'tcx>,
deps: &mut TaskEncoderDependencies<'vir>,
fn mk_type_assertion<'vir, Curr, Next>(
vcx: &'vir vir::VirCtxt<'vir>,
deps: &mut TaskEncoderDependencies<'vir, Self>,
arg: ExprGen<'vir, Curr, Next>, // Snapshot encoded argument
ty: Ty<'tcx>,
ty: Ty<'vir>,
) -> Option<ExprGen<'vir, Curr, Next>> {
let lifted_ty = deps
.require_local::<LiftedTyEnc<EncodeGenericsAsLifted>>(ty)
Expand Down Expand Up @@ -77,9 +77,9 @@ where
}
}

fn encode<'vir, 'tcx: 'vir>(
task_key: Self::TaskKey<'tcx>,
deps: &mut TaskEncoderDependencies<'vir>,
fn encode<'vir>(
task_key: Self::TaskKey<'vir>,
deps: &mut TaskEncoderDependencies<'vir, Self>,
) -> MirFunctionEncOutput<'vir> {
let def_id = Self::get_def_id(&task_key);
let caller_def_id = Self::get_caller_def_id(&task_key);
Expand All @@ -106,7 +106,7 @@ where
let ident_args = UnknownArity::new(vcx.alloc_slice(&ident_args));
let return_type = local_defs.locals[mir::RETURN_PLACE].ty;
let function_ref = FunctionIdent::new(function_ident, ident_args, return_type.snapshot);
deps.emit_output_ref::<Self>(task_key, MirFunctionEncOutputRef { function_ref });
deps.emit_output_ref(task_key, MirFunctionEncOutputRef { function_ref });

let spec = deps
.require_local::<MirSpecEnc>((def_id, substs, None, true))
Expand Down
34 changes: 14 additions & 20 deletions prusti-encoder/src/encoders/const.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use rustc_middle::mir::interpret::{ConstValue, Scalar, GlobalAlloc};
use task_encoder::{
TaskEncoder,
TaskEncoderDependencies,
EncodeFullResult,
};
use vir::{CallableIdent, Arity};

Expand All @@ -25,8 +26,8 @@ use super::{lifted::{casters::CastTypePure, rust_ty_cast::RustTyCastersEnc}, rus
impl TaskEncoder for ConstEnc {
task_encoder::encoder_cache!(ConstEnc);

type TaskDescription<'tcx> = (
mir::ConstantKind<'tcx>,
type TaskDescription<'vir> = (
mir::ConstantKind<'vir>,
usize, // current encoding depth
DefId, // DefId of the current function
);
Expand All @@ -37,21 +38,15 @@ impl TaskEncoder for ConstEnc {
*task
}

fn do_encode_full<'tcx: 'vir, 'vir>(
task_key: &Self::TaskKey<'tcx>,
deps: &mut TaskEncoderDependencies<'vir>,
) -> Result<(
Self::OutputFullLocal<'vir>,
Self::OutputFullDependency<'vir>,
), (
Self::EncodingError,
Option<Self::OutputFullDependency<'vir>>,
)> {
deps.emit_output_ref::<Self>(*task_key, ());
fn do_encode_full<'vir>(
task_key: &Self::TaskKey<'vir>,
deps: &mut TaskEncoderDependencies<'vir, Self>,
) -> EncodeFullResult<'vir, Self> {
deps.emit_output_ref(*task_key, ())?;
let (const_, encoding_depth, def_id) = *task_key;
let res = match const_ {
mir::ConstantKind::Val(val, ty) => {
let kind = deps.require_local::<RustTySnapshotsEnc>(ty).unwrap().generic_snapshot.specifics;
let kind = deps.require_local::<RustTySnapshotsEnc>(ty)?.generic_snapshot.specifics;
match val {
ConstValue::Scalar(Scalar::Int(int)) => {
let prim = kind.expect_primitive();
Expand Down Expand Up @@ -84,12 +79,11 @@ impl TaskEncoder for ConstEnc {
let ref_ty = kind.expect_structlike();
let str_ty = ty.peel_refs();
let str_snap = deps
.require_local::<RustTySnapshotsEnc>(str_ty)
.unwrap()
.require_local::<RustTySnapshotsEnc>(str_ty)?
.generic_snapshot
.specifics
.expect_structlike();
let cast = deps.require_local::<RustTyCastersEnc<CastTypePure>>(str_ty).unwrap();
let cast = deps.require_local::<RustTyCastersEnc<CastTypePure>>(str_ty)?;
vir::with_vcx(|vcx| {
// first, we create a string snapshot
let snap = str_snap.field_snaps_to_snap.apply(vcx, &[]);
Expand All @@ -112,10 +106,10 @@ impl TaskEncoder for ConstEnc {
kind: PureKind::Constant(uneval.promoted.unwrap()),
caller_def_id: Some(def_id)
};
let expr = deps.require_local::<MirPureEnc>(task).unwrap().expr;
let expr = deps.require_local::<MirPureEnc>(task)?.expr;
use vir::Reify;
expr.reify(vcx, (uneval.def, &[]))
}),
Ok(expr.reify(vcx, (uneval.def, &[])))
})?,
mir::ConstantKind::Ty(_) => todo!("ConstantKind::Ty"),
};
Ok((res, ()))
Expand Down
25 changes: 8 additions & 17 deletions prusti-encoder/src/encoders/generic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use task_encoder::{TaskEncoder, TaskEncoderDependencies};
use task_encoder::{TaskEncoder, TaskEncoderDependencies, EncodeFullResult};
use vir::{
BinaryArity, CallableIdent, DomainIdent, DomainParamData, FunctionIdent,
KnownArityAny, NullaryArity, PredicateIdent, TypeData, UnaryArity, ViperIdent,
Expand Down Expand Up @@ -39,7 +39,7 @@ const SNAPSHOT_PARAM_DOMAIN: TypeData<'static> = TypeData::Domain("s_Param", &[]
impl TaskEncoder for GenericEnc {
task_encoder::encoder_cache!(GenericEnc);

type TaskDescription<'tcx> = (); // ?
type TaskDescription<'vir> = (); // ?

type OutputRef<'vir> = GenericEncOutputRef<'vir>;
type OutputFullLocal<'vir> = GenericEncOutput<'vir>;
Expand All @@ -51,19 +51,10 @@ impl TaskEncoder for GenericEnc {
}

#[allow(non_snake_case)]
fn do_encode_full<'tcx: 'vir, 'vir>(
task_key: &Self::TaskKey<'tcx>,
deps: &mut TaskEncoderDependencies<'vir>,
) -> Result<
(
Self::OutputFullLocal<'vir>,
Self::OutputFullDependency<'vir>,
),
(
Self::EncodingError,
Option<Self::OutputFullDependency<'vir>>,
),
> {
fn do_encode_full<'vir>(
task_key: &Self::TaskKey<'vir>,
deps: &mut TaskEncoderDependencies<'vir, Self>,
) -> EncodeFullResult<'vir, Self> {
let ref_to_pred =
PredicateIdent::new(ViperIdent::new("p_Param"), BinaryArity::new(&[&TypeData::Ref, &TYP_DOMAIN]));
let type_domain_ident = DomainIdent::nullary(ViperIdent::new("Type"));
Expand Down Expand Up @@ -98,10 +89,10 @@ impl TaskEncoder for GenericEnc {
};

#[allow(clippy::unit_arg)]
deps.emit_output_ref::<Self>(
deps.emit_output_ref(
*task_key,
output_ref
);
)?;

let typ = FunctionIdent::new(
ViperIdent::new("typ"),
Expand Down
Loading

0 comments on commit c3b2a75

Please sign in to comment.