From 631c023111844c479ddc7ed87dfc5b0ebe951fd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marin=20Ver=C5=A1i=C4=87?= Date: Mon, 8 Jan 2024 20:17:36 +0300 Subject: [PATCH] add support for resolution of parameters not bounded by trait or self type --- src/lib.rs | 236 ++++++++++++-------- src/main_trait.rs | 2 +- tests/dispatch_with_same_param.rs | 55 +++-- tests/overlapping_trait_and_param_idents.rs | 65 ++++++ 4 files changed, 239 insertions(+), 119 deletions(-) create mode 100644 tests/overlapping_trait_and_param_idents.rs diff --git a/src/lib.rs b/src/lib.rs index 74a17d7..a703229 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -135,10 +135,15 @@ impl quote::ToTokens for TraitBound<'_> { first_args.colon2_token.to_tokens(tokens); quote!(<).to_tokens(tokens); - first_args.args.iter().for_each(|args| match args { - syn::GenericArgument::AssocType(_) => {} - _ => args.to_tokens(tokens), - }); + first_args + .args + .iter() + .filter_map(|arg| match arg { + syn::GenericArgument::AssocType(_) => None, + _ => Some(arg), + }) + .collect::>() + .to_tokens(tokens); quote!(>).to_tokens(tokens); } _ => first_elem.arguments.to_tokens(tokens), @@ -507,26 +512,65 @@ mod param { use quote::format_ident; use syn::{visit::Visit, visit_mut::VisitMut}; - struct NonPredicateParamResolver { - params: FxHashMap, - } - /// Indexer for params used in traits, impl trait or self type, but not predicates. /// /// For `impl, V> Trait for U` resolved indices would be: /// `T` = 0, /// `U` = 1, /// `V` = undetermined - struct NonPredicateParamIndexer<'ast> { - params: FxHashMap<&'ast syn::Ident, Option>, + struct NonPredicateParamIndexer { + indexed_params: FxHashMap, + unindexed_params: FxHashMap, curr_param_pos_idx: usize, } + struct NonPredicateParamResolver { + params: FxHashMap, + } + pub fn resolve_non_predicate_params(item_impl: &mut syn::ItemImpl) { - let mut non_predicate_param_indexer = NonPredicateParamIndexer::new(&item_impl.generics); + let item_impl_generics = item_impl.generics.params.iter().cloned(); + + let mut non_predicate_param_indexer = NonPredicateParamIndexer::new( + item_impl_generics + .map(|param| (get_param_ident(¶m).clone(), param)) + .collect(), + 0, + ); + non_predicate_param_indexer.visit_item_impl(item_impl); - let mut param_resolver = NonPredicateParamResolver::new(non_predicate_param_indexer); - param_resolver.visit_item_impl_mut(item_impl); + + let mut prev_unindexed_params_count = usize::MAX; + let mut indexed_params = non_predicate_param_indexer.indexed_params; + let mut curr_unindexed_params_count = non_predicate_param_indexer.unindexed_params.len(); + + while !non_predicate_param_indexer.unindexed_params.is_empty() + // NOTE: This discards parameters only used in where clause + && prev_unindexed_params_count != curr_unindexed_params_count + { + non_predicate_param_indexer = NonPredicateParamIndexer::new( + non_predicate_param_indexer.unindexed_params, + non_predicate_param_indexer.curr_param_pos_idx, + ); + + non_predicate_param_indexer.visit_indexed_params( + indexed_params + .iter() + .map(|(_, (idx, param))| (*idx, param)) + .collect(), + ); + + prev_unindexed_params_count = curr_unindexed_params_count; + indexed_params.extend(non_predicate_param_indexer.indexed_params); + curr_unindexed_params_count = non_predicate_param_indexer.unindexed_params.len(); + } + + NonPredicateParamResolver::new( + indexed_params + .into_iter() + .map(|(ident, (idx, _))| (ident, idx)), + ) + .visit_item_impl_mut(item_impl); // TODO: Add unnamed lifetimes (&u32) or elided lifetimes (&'_ u32) // TODO: Remove unused lifetimes. Example where 'b is unused: @@ -551,21 +595,44 @@ mod param { format_ident!("_{idx}") } - impl<'ast> NonPredicateParamIndexer<'ast> { - fn new(generics: &'ast syn::Generics) -> Self { - let params = get_param_idents(generics.params.iter()) - .map(|param| (param, None)) - .collect(); + impl NonPredicateParamIndexer { + fn new( + unindexed_params: FxHashMap, + curr_param_pos_idx: usize, + ) -> Self { + let indexed_params = FxHashMap::default(); Self { - params, - curr_param_pos_idx: 0, + indexed_params, + unindexed_params, + curr_param_pos_idx, + } + } + + fn visit_param_ident(&mut self, param_ident: &syn::Ident) -> bool { + if let Some(removed) = self.unindexed_params.remove(param_ident) { + self.indexed_params + .insert(param_ident.clone(), (self.curr_param_pos_idx, removed)); + self.curr_param_pos_idx = self.curr_param_pos_idx.checked_add(1).unwrap(); + + return true; + } + + false + } + + fn visit_indexed_params(&mut self, node: FxHashMap) { + let mut indexed_params = node.into_iter().collect::>(); + indexed_params.sort_by_key(|(k, _)| *k); + + for (_, param) in indexed_params { + self.visit_generic_param(param); } } } - impl<'ast> Visit<'ast> for NonPredicateParamIndexer<'ast> { - fn visit_item_impl(&mut self, node: &'ast syn::ItemImpl) { + impl Visit<'_> for NonPredicateParamIndexer { + fn visit_item_impl(&mut self, node: &syn::ItemImpl) { if let Some((_, trait_, _)) = &node.trait_ { // NOTE: Calling `visit_path` on a trait would conflict // with resolving params on `TypePath` so it's not done @@ -578,59 +645,71 @@ mod param { // // had `Visit::visit_path` been used on `T` to resolve // trait generics it would also rename the trait ident itself - let path = &trait_.segments.last().unwrap(); + let path = trait_.segments.last().unwrap(); self.visit_path_arguments(&path.arguments); } self.visit_type(&node.self_ty); } - fn visit_lifetime(&mut self, node: &'ast syn::Lifetime) { - if let Some(lifetime) = self.params.get_mut(&node.ident) { - *lifetime = Some(self.curr_param_pos_idx); - } - - self.curr_param_pos_idx = self.curr_param_pos_idx.checked_add(1).unwrap(); + fn visit_lifetime(&mut self, node: &syn::Lifetime) { + self.visit_param_ident(&node.ident); } - fn visit_path(&mut self, node: &'ast syn::Path) { - let path = node.segments.first().unwrap(); - - if let Some(param_idx) = self.params.get_mut(&path.ident) { - if param_idx.is_none() { - *param_idx = Some(self.curr_param_pos_idx); - } - - self.curr_param_pos_idx = self.curr_param_pos_idx.checked_add(1).unwrap(); - } else { + fn visit_path(&mut self, node: &syn::Path) { + if !self.visit_param_ident(&node.segments.first().unwrap().ident) { syn::visit::visit_path(self, node); } } - fn visit_expr(&mut self, node: &'ast syn::Expr) { - if let syn::Expr::Path(path) = node { - self.visit_expr_path(path); - } else { - self.curr_param_pos_idx = self.curr_param_pos_idx.checked_add(1).unwrap(); - } - } - - fn visit_where_clause(&mut self, _node: &'ast syn::WhereClause) {} + // TODO: Is this required? I don't think it is anymore + //fn visit_expr(&mut self, node: &syn::Expr) { + // if let syn::Expr::Path(path) = node { + // self.visit_expr_path(path); + // } else { + // self.curr_param_pos_idx = self.curr_param_pos_idx.checked_add(1).unwrap(); + // } + //} } impl NonPredicateParamResolver { - fn new(indexer: NonPredicateParamIndexer) -> Self { + fn new(params: impl IntoIterator) -> Self { Self { - params: indexer - .params - .into_iter() - .filter_map(|(param, idx)| idx.map(|idx| (param.clone(), idx))) - .collect(), + params: params.into_iter().collect(), } } } impl VisitMut for NonPredicateParamResolver { + fn visit_item_impl_mut(&mut self, node: &mut syn::ItemImpl) { + for attr in &mut node.attrs { + self.visit_attribute_mut(attr); + } + + self.visit_generics_mut(&mut node.generics); + if let Some((_, trait_, _)) = &mut node.trait_ { + // NOTE: Calling `visit_path` on a trait would conflict + // with resolving params on `TypePath` so it's not done + // + // # Example + // + // ``` + // trait T {} + // ``` + // + // had `Visit::visit_path` been used on `T` to resolve + // trait generics it would also rename the trait ident itself + let path = trait_.segments.last_mut().unwrap(); + self.visit_path_arguments_mut(&mut path.arguments); + } + + self.visit_type_mut(&mut node.self_ty); + + for item in &mut node.items { + self.visit_impl_item_mut(item); + } + } + fn visit_lifetime_mut(&mut self, node: &mut syn::Lifetime) { if let Some(&idx) = self.params.get(&node.ident) { node.ident = gen_indexed_param_name(idx); @@ -666,48 +745,11 @@ mod param { } } - //struct PredicateIndexer<'ast> { - // type_params: FxHashMap<&'ast syn::Ident, Option>, - // curr_pos_idx: usize, - //} - //impl<'ast> PredicateIndexer<'ast> { - // fn new(type_params: FxHashMap<&'ast syn::Ident, Option>) -> Self { - // let curr_pos_idx: usize = type_params - // .values() - // .filter_map(|x| *x) - // .reduce(|acc, x| x.max(acc)) - // .unwrap_or(0); - // - // Self { - // type_params, - // curr_pos_idx, - // } - // } - //} - //impl<'ast> Visit<'ast> for PredicateIndexer<'ast> { - // fn visit_item_impl(&mut self, node: &'ast syn::ItemImpl) { - // self.visit_generics(&node.generics); - // } - // - // fn visit_path_segment(&mut self, node: &'ast syn::PathSegment) { - // self.type_params.entry(&node.ident).and_modify(|param_idx| { - // if param_idx.is_none() { - // // Param encountered for the first time - // *param_idx = Some(self.curr_pos_idx); - // } - // }); - // - // self.curr_pos_idx = self.curr_pos_idx.checked_add(1).unwrap(); - // } - //} - - pub fn get_param_idents<'a>( - generic_params: impl Iterator, - ) -> impl Iterator { - generic_params.into_iter().map(|param| match param { - syn::GenericParam::Lifetime(lifetime_param) => &lifetime_param.lifetime.ident, - syn::GenericParam::Type(type_param) => &type_param.ident, - syn::GenericParam::Const(const_param) => &const_param.ident, - }) + pub fn get_param_ident(generic_param: &syn::GenericParam) -> &syn::Ident { + match generic_param { + syn::GenericParam::Lifetime(syn::LifetimeParam { lifetime, .. }) => &lifetime.ident, + syn::GenericParam::Type(syn::TypeParam { ident, .. }) => ident, + syn::GenericParam::Const(syn::ConstParam { ident, .. }) => ident, + } } } diff --git a/src/main_trait.rs b/src/main_trait.rs index f080aeb..f2a32b8 100644 --- a/src/main_trait.rs +++ b/src/main_trait.rs @@ -197,6 +197,7 @@ impl GenericsResolver { let trait_bounds = trait_bounds.into_iter(); quote! { #param_ident: #(#trait_bounds)+* } }); + let where_clause_predicates = assoc_bound_predicates .chain(core::iter::once_with(|| { let helper_trait_bound = @@ -205,7 +206,6 @@ impl GenericsResolver { quote! { Self: #helper_trait_bound } })) .collect(); - Self { assoc_bound_type_params, where_clause_predicates, diff --git a/tests/dispatch_with_same_param.rs b/tests/dispatch_with_same_param.rs index ca582bf..d70c3c1 100644 --- a/tests/dispatch_with_same_param.rs +++ b/tests/dispatch_with_same_param.rs @@ -1,22 +1,22 @@ use disjoint_impls::disjoint_impls; -pub trait Dispatch { +pub trait Dispatch<'a, T> { type Group; } pub enum GroupA {} -impl Dispatch<()> for String { +impl Dispatch<'_, ()> for String { type Group = GroupA; } -impl Dispatch<()> for Vec { +impl Dispatch<'_, ()> for Vec { type Group = GroupA; } pub enum GroupB {} -impl Dispatch<()> for i32 { +impl Dispatch<'_, ()> for i32 { type Group = GroupB; } -impl Dispatch<()> for u32 { +impl Dispatch<'_, ()> for u32 { type Group = GroupB; } @@ -25,28 +25,41 @@ disjoint_impls! { const NAME: &'static str; } - impl> Kita for T { - const NAME: &'static str = "1st Blanket A"; + // NOTE: Dispatch trait parameters must be the same + impl<'b, 'a, T: Dispatch<'b, (), Group = GroupA>> Kita for &'a T { + const NAME: &'static str = "Blanket A"; } - impl> Kita for T { - const NAME: &'static str = "1st Blanket B"; + impl<'a, 'c, T: Dispatch<'a, (), Group = GroupB>> Kita for &'c T { + const NAME: &'static str = "Blanket B"; + } +} + +/* +pub trait Kita { + const NAME: &'static str; +} + +const _: () = { + pub trait _Kita0<_0: ?Sized> { + const NAME: &'static str; } - impl> Kita for T { - const NAME: &'static str = "2nd Blanket A"; + impl<'_2, '_0, _1: Dispatch<'_2, (), Group = GroupA>> _Kita0 for &'_0 _1 { + const NAME: &'static str = "Blanket A"; } - impl> Kita for T { - const NAME: &'static str = "2nd Blanket B"; + impl<'_2, '_0, _1: Dispatch<'_2, (), Group = GroupB>> _Kita0 for &'_0 _1 { + const NAME: &'static str = "Blanket B"; } -} -/* + impl<'_2, '_0, _1> Kita for &'_0 _1 where _1: Dispatch<'_2, ()>, Self: _Kita0<<_1 as Dispatch<'_2, ()>>::Group> { + const NAME: &'static str = >::Group>>::NAME; + } +}; */ -#[test] -fn dispatch_with_same_param() { - assert_eq!("Blanket A", String::NAME); - assert_eq!("Blanket A", Vec::::NAME); - assert_eq!("Blanket B", u32::NAME); - assert_eq!("Blanket B", i32::NAME); +fn main() { + assert_eq!("Blanket A", <&String>::NAME); + assert_eq!("Blanket A", <&Vec::>::NAME); + assert_eq!("Blanket B", <&u32>::NAME); + assert_eq!("Blanket B", <&i32>::NAME); } diff --git a/tests/overlapping_trait_and_param_idents.rs b/tests/overlapping_trait_and_param_idents.rs new file mode 100644 index 0000000..74f920e --- /dev/null +++ b/tests/overlapping_trait_and_param_idents.rs @@ -0,0 +1,65 @@ +use disjoint_impls::disjoint_impls; + +pub trait Dispatch { + type Group; +} + +pub enum GroupA {} +impl Dispatch for String { + type Group = GroupA; +} +impl Dispatch for Vec { + type Group = GroupA; +} + +pub enum GroupB {} +impl Dispatch for i32 { + type Group = GroupB; +} +impl Dispatch for u32 { + type Group = GroupB; +} + +disjoint_impls! { + pub trait U where U: From { + const NAME: &'static str; + } + + impl> U for T where U: From { + const NAME: &'static str = "Blanket A"; + } + impl> U for T where U: From { + const NAME: &'static str = "Blanket B"; + } +} + +/* +pub trait U where U: From + From { + const NAME: &'static str; +} +const _: () = { + pub trait _U0<_0: ?Sized, U> where U: From + From { + const NAME: &'static str; + } + + impl<_0, _1: Dispatch> _U0 for _1 where _0: From + From { + const NAME: &'static str = "Blanket A"; + } + impl<_0, _1: Dispatch> _U0 for _1 where _0: From + From { + const NAME: &'static str = "Blanket B"; + } + + impl<_0, _1> U<_0> for _1 where _0: From + From, _1: Dispatch, Self: _U0<<_1 as Dispatch>::Group, _0> { + const NAME: &'static str = ::Group, _0>>::NAME; + } +}; + +*/ + +#[test] +fn overlapping_trait_and_param_idents() { + assert_eq!("Blanket A", >::NAME); + assert_eq!("Blanket A", as U>::NAME); + assert_eq!("Blanket B", >::NAME); + assert_eq!("Blanket B", >::NAME); +}