diff --git a/prusti-encoder/src/encoders/type/lifted/cast.rs b/prusti-encoder/src/encoders/type/lifted/cast.rs index 81d5e0943a0..f9c5a5c87bc 100644 --- a/prusti-encoder/src/encoders/type/lifted/cast.rs +++ b/prusti-encoder/src/encoders/type/lifted/cast.rs @@ -192,7 +192,7 @@ where .require_local::>(task_key.actual) .unwrap(); if let CastersEncOutputRef::Casters { make_generic, .. } = generic_cast.cast { - GenericCastOutputRef::Cast(Cast::new(T::to_generic_applicator(make_generic), &[])) + GenericCastOutputRef::Cast(Cast::new(T::to_generic_applicator(make_generic), generic_cast.ty_args)) } else { unreachable!() } diff --git a/prusti-encoder/src/encoders/type/lifted/casters.rs b/prusti-encoder/src/encoders/type/lifted/casters.rs index a57b36e36ba..fceee121ee9 100644 --- a/prusti-encoder/src/encoders/type/lifted/casters.rs +++ b/prusti-encoder/src/encoders/type/lifted/casters.rs @@ -20,10 +20,16 @@ impl CastTypePure { casters: &Casters<'vir, Self>, vcx: &'vir vir::VirCtxt<'_>, snap: vir::ExprGen<'vir, Curr, Next>, + ty_args: &'vir [LiftedTy<'vir, LiftedGeneric<'vir>>], ) -> vir::ExprGen<'vir, Curr, Next> { match casters { CastFunctionsOutputRef::AlreadyGeneric => snap, - CastFunctionsOutputRef::Casters { make_generic, .. } => make_generic.apply(vcx, [snap]), + CastFunctionsOutputRef::Casters { make_generic, .. } => make_generic.apply( + vcx, + &std::iter::once(snap) + .chain(ty_args.iter().map(|t| t.expr(vcx))) + .collect::>(), + ), } } } @@ -194,7 +200,7 @@ impl CastersEncOutputRef { } } -pub type MakeGenericCastFunction<'vir> = FunctionIdent<'vir, UnaryArity<'vir>>; +pub type MakeGenericCastFunction<'vir> = FunctionIdent<'vir, UnknownArity<'vir>>; pub type MakeConcreteCastFunction<'vir> = FunctionIdent<'vir, UnknownArity<'vir>>; /// Takes as input the most generic version (c.f. [`MostGenericTy`]) of a Rust @@ -239,21 +245,23 @@ impl TaskEncoder for CastersEnc { .unwrap() .ty_constructor; - let make_generic_arg_tys = [self_ty]; - let make_generic_ident = FunctionIdent::new( - vir::vir_format_identifier!(vcx, "make_generic_s_{base_name}"), - UnaryArity::new(vcx.alloc(make_generic_arg_tys)), - generic_ref.param_snapshot, - ); - - let make_concrete_ty_params = ty + let ty_params = ty .generics() .into_iter() .map(|g| deps.require_ref::(*g).unwrap()) .collect::>(); + let make_generic_arg_tys = std::iter::once(self_ty) + .chain(ty_params.iter().map(|t| t.ty())) + .collect::>(); + let make_generic_ident = FunctionIdent::new( + vir::vir_format_identifier!(vcx, "make_generic_s_{base_name}"), + UnknownArity::new(vcx.alloc(make_generic_arg_tys)), + generic_ref.param_snapshot, + ); + let make_concrete_arg_tys = std::iter::once(generic_ref.param_snapshot) - .chain(make_concrete_ty_params.iter().map(|t| t.ty())) + .chain(ty_params.iter().map(|t| t.ty())) .collect::>(); let make_concrete_ident = FunctionIdent::new( @@ -272,9 +280,12 @@ impl TaskEncoder for CastersEnc { let make_generic_arg = vcx.mk_local_decl("self", self_ty); let make_generic_expr = vcx.mk_local_ex(make_generic_arg.name, make_generic_arg.ty); - let make_generic_arg_decls = vcx.alloc_slice(&[make_generic_arg]); + let make_generic_arg_decls = vcx.alloc_slice(&std::iter::once(make_generic_arg) + .chain(ty_params.iter().map(|t| t.decl())) + .collect::>() + ); - let make_concrete_ty_param_exprs = make_concrete_ty_params + let make_concrete_ty_param_exprs = ty_params .iter() .map(|t| t.expr(vcx)) .collect::>(); @@ -319,7 +330,7 @@ impl TaskEncoder for CastersEnc { let make_concrete_snap_arg_decl = vcx.mk_local_decl("snap", generic_ref.param_snapshot); let make_concrete_arg_decls = vcx.alloc_slice( &std::iter::once(make_concrete_snap_arg_decl) - .chain(make_concrete_ty_params.iter().map(|t| t.decl())) + .chain(ty_params.iter().map(|t| t.decl())) .collect::>(), ); @@ -331,8 +342,15 @@ impl TaskEncoder for CastersEnc { &make_concrete_ty_param_exprs, ); + let arg_ty_exprs = ty_params + .iter() + .map(|t| vcx.mk_local_ex(t.decl().name, t.decl().ty)) + .collect::>(); + let make_generic_args = std::iter::once(vcx.mk_result(self_ty)) + .chain(arg_ty_exprs) + .collect::>(); let make_concrete_post = vcx.mk_eq_expr( - make_generic_ident.apply(vcx, [vcx.mk_result(self_ty)]), + make_generic_ident.apply(vcx, &make_generic_args), vcx.mk_local_ex( make_concrete_snap_arg_decl.name, make_concrete_snap_arg_decl.ty, @@ -454,14 +472,18 @@ impl TaskEncoder for CastersEnc { let generic_predicate = vcx.mk_predicate_app_expr(generic_predicate); + let make_generic_pure_arg_exprs = std::iter::once(concrete_snap) + .chain(arg_ty_exprs.into_iter()) + .collect::>(); + let make_generic_same_snap = vcx.mk_eq_expr( - vcx.mk_old_expr(make_generic_pure.apply(vcx, [concrete_snap])), + vcx.mk_old_expr(make_generic_pure.apply(vcx, &make_generic_pure_arg_exprs)), generic_snap, ); let make_concrete_same_snap = vcx.mk_eq_expr( vcx.mk_old_expr(generic_snap), - make_generic_pure.apply(vcx, [concrete_snap]), + make_generic_pure.apply(vcx, &make_generic_pure_arg_exprs), ); let make_generic = vcx.mk_method( diff --git a/prusti-encoder/src/encoders/type/lifted/rust_ty_cast.rs b/prusti-encoder/src/encoders/type/lifted/rust_ty_cast.rs index 7164279ba67..3a14cfcd40f 100644 --- a/prusti-encoder/src/encoders/type/lifted/rust_ty_cast.rs +++ b/prusti-encoder/src/encoders/type/lifted/rust_ty_cast.rs @@ -61,7 +61,7 @@ impl<'vir> RustTyGenericCastEncOutput<'vir, CastFunctionsOutputRef<'vir>> { vcx: &'vir vir::VirCtxt<'tcx>, snap: vir::ExprGen<'vir, Curr, Next>, ) -> vir::ExprGen<'vir, Curr, Next> { - CastTypePure::cast_to_generic_if_necessary(&self.cast, vcx, snap) + CastTypePure::cast_to_generic_if_necessary(&self.cast, vcx, snap, self.ty_args) } }