Skip to content

Commit

Permalink
Add type parameters to pure make_generic method for consistency (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
erdmannc authored Jul 15, 2024
1 parent bffc118 commit 07ddee2
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 19 deletions.
2 changes: 1 addition & 1 deletion prusti-encoder/src/encoders/type/lifted/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ where
.require_local::<RustTyCastersEnc<T>>(task_key.actual)
.unwrap();
if let CastersEncOutputRef::Casters { make_generic, .. } = generic_cast.cast {
GenericCastOutputRef::Cast(Cast::new(T::to_generic_applicator(make_generic), &[]))
GenericCastOutputRef::Cast(Cast::new(T::to_generic_applicator(make_generic), generic_cast.ty_args))
} else {
unreachable!()
}
Expand Down
56 changes: 39 additions & 17 deletions prusti-encoder/src/encoders/type/lifted/casters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,16 @@ impl CastTypePure {
casters: &Casters<'vir, Self>,
vcx: &'vir vir::VirCtxt<'_>,
snap: vir::ExprGen<'vir, Curr, Next>,
ty_args: &'vir [LiftedTy<'vir, LiftedGeneric<'vir>>],
) -> vir::ExprGen<'vir, Curr, Next> {
match casters {
CastFunctionsOutputRef::AlreadyGeneric => snap,
CastFunctionsOutputRef::Casters { make_generic, .. } => make_generic.apply(vcx, [snap]),
CastFunctionsOutputRef::Casters { make_generic, .. } => make_generic.apply(
vcx,
&std::iter::once(snap)
.chain(ty_args.iter().map(|t| t.expr(vcx)))
.collect::<Vec<_>>(),
),
}
}
}
Expand Down Expand Up @@ -194,7 +200,7 @@ impl<G: Copy, C> CastersEncOutputRef<G, C> {
}
}

pub type MakeGenericCastFunction<'vir> = FunctionIdent<'vir, UnaryArity<'vir>>;
pub type MakeGenericCastFunction<'vir> = FunctionIdent<'vir, UnknownArity<'vir>>;
pub type MakeConcreteCastFunction<'vir> = FunctionIdent<'vir, UnknownArity<'vir>>;

/// Takes as input the most generic version (c.f. [`MostGenericTy`]) of a Rust
Expand Down Expand Up @@ -239,21 +245,23 @@ impl TaskEncoder for CastersEnc<CastTypePure> {
.unwrap()
.ty_constructor;

let make_generic_arg_tys = [self_ty];
let make_generic_ident = FunctionIdent::new(
vir::vir_format_identifier!(vcx, "make_generic_s_{base_name}"),
UnaryArity::new(vcx.alloc(make_generic_arg_tys)),
generic_ref.param_snapshot,
);

let make_concrete_ty_params = ty
let ty_params = ty
.generics()
.into_iter()
.map(|g| deps.require_ref::<LiftedGenericEnc>(*g).unwrap())
.collect::<Vec<_>>();

let make_generic_arg_tys = std::iter::once(self_ty)
.chain(ty_params.iter().map(|t| t.ty()))
.collect::<Vec<_>>();
let make_generic_ident = FunctionIdent::new(
vir::vir_format_identifier!(vcx, "make_generic_s_{base_name}"),
UnknownArity::new(vcx.alloc(make_generic_arg_tys)),
generic_ref.param_snapshot,
);

let make_concrete_arg_tys = std::iter::once(generic_ref.param_snapshot)
.chain(make_concrete_ty_params.iter().map(|t| t.ty()))
.chain(ty_params.iter().map(|t| t.ty()))
.collect::<Vec<_>>();

let make_concrete_ident = FunctionIdent::new(
Expand All @@ -272,9 +280,12 @@ impl TaskEncoder for CastersEnc<CastTypePure> {
let make_generic_arg = vcx.mk_local_decl("self", self_ty);
let make_generic_expr = vcx.mk_local_ex(make_generic_arg.name, make_generic_arg.ty);

let make_generic_arg_decls = vcx.alloc_slice(&[make_generic_arg]);
let make_generic_arg_decls = vcx.alloc_slice(&std::iter::once(make_generic_arg)
.chain(ty_params.iter().map(|t| t.decl()))
.collect::<Vec<_>>()
);

let make_concrete_ty_param_exprs = make_concrete_ty_params
let make_concrete_ty_param_exprs = ty_params
.iter()
.map(|t| t.expr(vcx))
.collect::<Vec<_>>();
Expand Down Expand Up @@ -319,7 +330,7 @@ impl TaskEncoder for CastersEnc<CastTypePure> {
let make_concrete_snap_arg_decl = vcx.mk_local_decl("snap", generic_ref.param_snapshot);
let make_concrete_arg_decls = vcx.alloc_slice(
&std::iter::once(make_concrete_snap_arg_decl)
.chain(make_concrete_ty_params.iter().map(|t| t.decl()))
.chain(ty_params.iter().map(|t| t.decl()))
.collect::<Vec<_>>(),
);

Expand All @@ -331,8 +342,15 @@ impl TaskEncoder for CastersEnc<CastTypePure> {
&make_concrete_ty_param_exprs,
);

let arg_ty_exprs = ty_params
.iter()
.map(|t| vcx.mk_local_ex(t.decl().name, t.decl().ty))
.collect::<Vec<_>>();
let make_generic_args = std::iter::once(vcx.mk_result(self_ty))
.chain(arg_ty_exprs)
.collect::<Vec<_>>();
let make_concrete_post = vcx.mk_eq_expr(
make_generic_ident.apply(vcx, [vcx.mk_result(self_ty)]),
make_generic_ident.apply(vcx, &make_generic_args),
vcx.mk_local_ex(
make_concrete_snap_arg_decl.name,
make_concrete_snap_arg_decl.ty,
Expand Down Expand Up @@ -454,14 +472,18 @@ impl TaskEncoder for CastersEnc<CastTypeImpure> {

let generic_predicate = vcx.mk_predicate_app_expr(generic_predicate);

let make_generic_pure_arg_exprs = std::iter::once(concrete_snap)
.chain(arg_ty_exprs.into_iter())
.collect::<Vec<_>>();

let make_generic_same_snap = vcx.mk_eq_expr(
vcx.mk_old_expr(make_generic_pure.apply(vcx, [concrete_snap])),
vcx.mk_old_expr(make_generic_pure.apply(vcx, &make_generic_pure_arg_exprs)),
generic_snap,
);

let make_concrete_same_snap = vcx.mk_eq_expr(
vcx.mk_old_expr(generic_snap),
make_generic_pure.apply(vcx, [concrete_snap]),
make_generic_pure.apply(vcx, &make_generic_pure_arg_exprs),
);

let make_generic = vcx.mk_method(
Expand Down
2 changes: 1 addition & 1 deletion prusti-encoder/src/encoders/type/lifted/rust_ty_cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl<'vir> RustTyGenericCastEncOutput<'vir, CastFunctionsOutputRef<'vir>> {
vcx: &'vir vir::VirCtxt<'tcx>,
snap: vir::ExprGen<'vir, Curr, Next>,
) -> vir::ExprGen<'vir, Curr, Next> {
CastTypePure::cast_to_generic_if_necessary(&self.cast, vcx, snap)
CastTypePure::cast_to_generic_if_necessary(&self.cast, vcx, snap, self.ty_args)
}
}

Expand Down

0 comments on commit 07ddee2

Please sign in to comment.