From 2acb05748fdef975c42d6974139dac233cee73aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marin=20Ver=C5=A1i=C4=87?= Date: Sun, 26 Nov 2023 22:41:40 +0300 Subject: [PATCH] [feature] #2: Keep original trait param identifiers --- CHANGELOG.md | 1 + src/lib.rs | 297 +++++++++++++++++++--------------------------- src/main_trait.rs | 20 ++-- 3 files changed, 133 insertions(+), 185 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f5a320..ecce4b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support bounds for main trait parameters - Support const parameters in main trait +- Keep original trait param identifiers ## [0.5.0] - 2023-11-23 diff --git a/src/lib.rs b/src/lib.rs index 5c6c4b0..d4fc6bd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,6 @@ mod disjoint; mod main_trait; mod validate; -use param::ParamResolver; use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use proc_macro_error::{abort, proc_macro_error, OptionExt}; @@ -362,16 +361,12 @@ pub fn disjoint_impls(input: TokenStream) -> TokenStream { impl Parse for ItemImpls { fn parse(input: ParseStream) -> syn::parse::Result { - let mut trait_ = input.parse::().ok(); - - trait_ - .as_mut() - .map(ParamResolver::resolve_non_predicate_params); + let main_trait = input.parse::().ok(); let mut item_impls = FxHashMap::default(); while let Ok(mut item) = input.parse::() { // TODO: Resolve predicate param idents - item.resolve_non_predicate_params(); + param::resolve_non_predicate_params(&mut item, main_trait.as_ref()); item_impls .entry((*item.self_ty).clone()) @@ -379,7 +374,7 @@ impl Parse for ItemImpls { .push(item); } - Ok(ItemImpls::new(trait_, item_impls)) + Ok(ItemImpls::new(main_trait, item_impls)) } } @@ -412,7 +407,7 @@ mod helper_trait { idx: usize, ) -> Option { let assoc_type_param_count = AssocBounds::find(impls).type_param_idents.len(); - let type_param_idents = (0..assoc_type_param_count).map(param::gen_indexed_type_param_name); + let type_param_idents = (0..assoc_type_param_count).map(param::gen_indexed_param_name); if let Some(mut helper_trait) = main_trait.cloned() { helper_trait.vis = syn::Visibility::Public(syn::parse_quote!(pub)); @@ -420,7 +415,7 @@ mod helper_trait { let start_idx = helper_trait.generics.type_params().count(); (start_idx..(assoc_type_param_count + start_idx)) - .map(param::gen_indexed_type_param_name) + .map(param::gen_indexed_param_name) .for_each(|type_param_ident| { helper_trait .generics @@ -506,28 +501,11 @@ mod param { use rustc_hash::FxHashMap; use quote::format_ident; - use syn::{ - visit::{visit_path, Visit}, - visit_mut::{visit_const_param_mut, visit_path_mut, visit_type_param_mut, VisitMut}, - ItemImpl, - }; - - /// Resolve lifetimes, type params and const params into position based identifiers - pub trait ParamResolver { - /// Replaces all param identifiers with a position based identifier. - /// This makes easier to compare params across different impls. - /// - /// For: - /// `impl, V> Trait for U` - /// resolved impl signature would be: - /// `impl<_T2, _T1: IntoIterator, V> Trait for _T2` - fn resolve_non_predicate_params(&mut self); - } + use syn::{visit::Visit, visit_mut::VisitMut}; - struct NonPredicateParamResolver { - lifetimes: FxHashMap, - type_params: FxHashMap, - const_params: FxHashMap, + struct NonPredicateParamResolver<'ast> { + main_trait_generics: Vec<&'ast syn::Ident>, + params: FxHashMap, } /// Indexer for params used in traits, impl trait or self type, but not predicates. @@ -537,96 +515,65 @@ mod param { /// `U` = 1, /// `V` = undetermined struct NonPredicateParamIndexer<'ast> { - lifetime_params: FxHashMap<&'ast syn::Ident, Option>, - type_params: FxHashMap<&'ast syn::Ident, Option>, - const_params: FxHashMap<&'ast syn::Ident, Option>, - - curr_lifetime_param_pos_idx: usize, - curr_type_param_pos_idx: usize, - curr_const_param_pos_idx: usize, - } - - impl ParamResolver for ItemImpl { - fn resolve_non_predicate_params(&mut self) { - let mut non_predicate_param_indexer = NonPredicateParamIndexer::new(&self.generics); - non_predicate_param_indexer.visit_item_impl(self); - let mut param_resolver = NonPredicateParamResolver::new(non_predicate_param_indexer); - param_resolver.visit_item_impl_mut(self); - - // TODO: Add unnamed lifetimes (&u32) or elided lifetimes (&'_ u32) - // TODO: Remove unused lifetimes. Example where 'b is unused: - // impl<'a: 'b, 'b: 'a, T: 'b > Kara<'a, T> for &'a T { - // - //self.generics.params = self - // .generics - // .params - // .into_iter() - // .filter(|param| match param { - // syn::GenericParam::Lifetime(lifetime) - // if param_resolver.0.get(&lifetime.lifetime.ident).1 => - // { - // syn::GenericParam::Lifetime(lifetime) - // } - // param => param, - // }) - // .collect(); - } - } - - impl ParamResolver for syn::ItemTrait { - fn resolve_non_predicate_params(&mut self) { - let mut non_predicate_param_indexer = NonPredicateParamIndexer::new(&self.generics); - non_predicate_param_indexer.visit_item_trait(self); - let mut param_resolver = NonPredicateParamResolver::new(non_predicate_param_indexer); - param_resolver.visit_item_trait_mut(self); - } - } - - fn gen_indexed_lifetime_param_name(idx: usize) -> syn::Ident { - format_ident!("_LŠČ{idx}") + params: FxHashMap<&'ast syn::Ident, Option>, + curr_param_pos_idx: usize, } - pub(super) fn gen_indexed_type_param_name(idx: usize) -> syn::Ident { - format_ident!("_TŠČ{idx}") + pub fn resolve_non_predicate_params( + item_impl: &mut syn::ItemImpl, + main_trait: Option<&syn::ItemTrait>, + ) { + let mut non_predicate_param_indexer = NonPredicateParamIndexer::new(&item_impl.generics); + non_predicate_param_indexer.visit_item_impl(item_impl); + let mut param_resolver = + NonPredicateParamResolver::new(non_predicate_param_indexer, main_trait); + param_resolver.visit_item_impl_mut(item_impl); + + // TODO: Add unnamed lifetimes (&u32) or elided lifetimes (&'_ u32) + // TODO: Remove unused lifetimes. Example where 'b is unused: + // impl<'a: 'b, 'b: 'a, T: 'b > Kara<'a, T> for &'a T { + // + //self.generics.params = self + // .generics + // .params + // .into_iter() + // .filter(|param| match param { + // syn::GenericParam::Lifetime(lifetime) + // if param_resolver.0.get(&lifetime.lifetime.ident).1 => + // { + // syn::GenericParam::Lifetime(lifetime) + // } + // param => param, + // }) + // .collect(); } - fn gen_indexed_const_param_name(idx: usize) -> syn::Ident { - format_ident!("_CŠČ{idx}") + pub(super) fn gen_indexed_param_name(idx: usize) -> syn::Ident { + format_ident!("_{idx}") } impl<'ast> NonPredicateParamIndexer<'ast> { fn new(generics: &'ast syn::Generics) -> Self { - let lifetime_params = generics - .lifetimes() - .map(|param| (¶m.lifetime.ident, None)) - .collect(); - let type_params = generics - .type_params() - .map(|param| (¶m.ident, None)) - .collect(); - let const_params = generics - .const_params() - .map(|param| (¶m.ident, None)) + let params = generics + .params + .iter() + .map(|param| match param { + syn::GenericParam::Lifetime(lifetime_param) => &lifetime_param.lifetime.ident, + syn::GenericParam::Type(type_param) => &type_param.ident, + syn::GenericParam::Const(const_param) => &const_param.ident, + }) + .map(|param| (param, None)) .collect(); Self { - lifetime_params, - type_params, - const_params, - - curr_lifetime_param_pos_idx: 0, - curr_type_param_pos_idx: 0, - curr_const_param_pos_idx: 0, + params, + curr_param_pos_idx: 0, } } } impl<'ast> Visit<'ast> for NonPredicateParamIndexer<'ast> { - fn visit_item_trait(&mut self, node: &'ast syn::ItemTrait) { - self.visit_generics(&node.generics); - } - - fn visit_item_impl(&mut self, node: &'ast ItemImpl) { + fn visit_item_impl(&mut self, node: &'ast syn::ItemImpl) { if let Some((_, trait_, _)) = &node.trait_ { // NOTE: Calling `visit_path` on a trait would conflict // with resolving params on `TypePath` so it's not done @@ -638,86 +585,74 @@ mod param { self.visit_type(&node.self_ty); } - // Called only for a trait definition, never for impl block - fn visit_lifetime_param(&mut self, node: &'ast syn::LifetimeParam) { - self.visit_lifetime(&node.lifetime); - } - - // Called only for a trait definition, never for impl block - fn visit_type_param(&mut self, node: &'ast syn::TypeParam) { - *self.type_params.get_mut(&node.ident).unwrap() = Some(self.curr_type_param_pos_idx); - - if let Some(curr_pos_idx) = self.curr_type_param_pos_idx.checked_add(1) { - self.curr_type_param_pos_idx = curr_pos_idx; - } - } - - // Called only for a trait definition, never for impl block - fn visit_const_param(&mut self, node: &'ast syn::ConstParam) { - *self.const_params.get_mut(&node.ident).unwrap() = Some(self.curr_const_param_pos_idx); - - if let Some(curr_pos_idx) = self.curr_const_param_pos_idx.checked_add(1) { - self.curr_const_param_pos_idx = curr_pos_idx; - } - } - fn visit_lifetime(&mut self, node: &'ast syn::Lifetime) { - *self.lifetime_params.get_mut(&node.ident).unwrap() = - Some(self.curr_lifetime_param_pos_idx); + *self.params.get_mut(&node.ident).unwrap() = Some(self.curr_param_pos_idx); - if let Some(curr_pos_idx) = self.curr_lifetime_param_pos_idx.checked_add(1) { - self.curr_lifetime_param_pos_idx = curr_pos_idx; + if let Some(curr_pos_idx) = self.curr_param_pos_idx.checked_add(1) { + self.curr_param_pos_idx = curr_pos_idx; } } fn visit_path(&mut self, node: &'ast syn::Path) { - if let Some(param_idx) = node.get_ident().and_then(|i| self.const_params.get_mut(&i)) { - *param_idx = Some(self.curr_const_param_pos_idx); + if let Some(param_idx) = node.get_ident().and_then(|i| self.params.get_mut(&i)) { + *param_idx = Some(self.curr_param_pos_idx); - if let Some(curr_pos_idx) = self.curr_const_param_pos_idx.checked_add(1) { - self.curr_const_param_pos_idx = curr_pos_idx; + if let Some(curr_pos_idx) = self.curr_param_pos_idx.checked_add(1) { + self.curr_param_pos_idx = curr_pos_idx; } } else if let Some(first_segment) = node.segments.first() { - self.type_params + self.params .entry(&first_segment.ident) .and_modify(|param_idx| { if param_idx.is_none() { - *param_idx = Some(self.curr_type_param_pos_idx); + *param_idx = Some(self.curr_param_pos_idx); } }); - if let Some(pos_idx) = self.curr_type_param_pos_idx.checked_add(1) { - self.curr_type_param_pos_idx = pos_idx; + if let Some(pos_idx) = self.curr_param_pos_idx.checked_add(1) { + self.curr_param_pos_idx = pos_idx; } } - visit_path(self, node); + syn::visit::visit_path(self, node); } fn visit_expr(&mut self, _: &'ast syn::Expr) { - if let Some(curr_pos_idx) = self.curr_const_param_pos_idx.checked_add(1) { - self.curr_const_param_pos_idx = curr_pos_idx; + if let Some(curr_pos_idx) = self.curr_param_pos_idx.checked_add(1) { + self.curr_param_pos_idx = curr_pos_idx; } } fn visit_where_clause(&mut self, _node: &'ast syn::WhereClause) {} } - impl NonPredicateParamResolver { - fn new(indexer: NonPredicateParamIndexer) -> Self { + impl<'ast> NonPredicateParamResolver<'ast> { + fn new( + indexer: NonPredicateParamIndexer, + main_trait: Option<&'ast syn::ItemTrait>, + ) -> Self { + let main_trait_generics: Vec<_> = main_trait + .map(|main_trait| { + main_trait + .generics + .params + .iter() + .map(|param| match param { + syn::GenericParam::Lifetime(lifetime_param) => { + &lifetime_param.lifetime.ident + } + syn::GenericParam::Type(type_param) => &type_param.ident, + syn::GenericParam::Const(const_param) => &const_param.ident, + }) + .collect() + }) + .unwrap_or_default(); + Self { - lifetimes: indexer - .lifetime_params - .into_iter() - .filter_map(|(param, idx)| idx.map(|idx| (param.clone(), idx))) - .collect(), - type_params: indexer - .type_params - .into_iter() - .filter_map(|(param, idx)| idx.map(|idx| (param.clone(), idx))) - .collect(), - const_params: indexer - .const_params + main_trait_generics, + + params: indexer + .params .into_iter() .filter_map(|(param, idx)| idx.map(|idx| (param.clone(), idx))) .collect(), @@ -725,39 +660,55 @@ mod param { } } - impl VisitMut for NonPredicateParamResolver { + impl VisitMut for NonPredicateParamResolver<'_> { fn visit_lifetime_mut(&mut self, node: &mut syn::Lifetime) { - if let Some(&idx) = self.lifetimes.get(&node.ident) { - node.ident = gen_indexed_lifetime_param_name(idx); + if let Some(&idx) = self.params.get(&node.ident) { + node.ident = if let Some(&main_trait_generic) = self.main_trait_generics.get(idx) { + main_trait_generic.clone() + } else { + gen_indexed_param_name(idx) + }; } + + syn::visit_mut::visit_lifetime_mut(self, node); } fn visit_type_param_mut(&mut self, node: &mut syn::TypeParam) { - if let Some(&idx) = self.type_params.get(&node.ident) { - node.ident = gen_indexed_type_param_name(idx); + if let Some(&idx) = self.params.get(&node.ident) { + node.ident = if let Some(&main_trait_generic) = self.main_trait_generics.get(idx) { + main_trait_generic.clone() + } else { + gen_indexed_param_name(idx) + }; } - visit_type_param_mut(self, node); + syn::visit_mut::visit_type_param_mut(self, node); } fn visit_const_param_mut(&mut self, node: &mut syn::ConstParam) { - if let Some(&idx) = self.const_params.get(&node.ident) { - node.ident = gen_indexed_const_param_name(idx); + if let Some(&idx) = self.params.get(&node.ident) { + node.ident = if let Some(&main_trait_generic) = self.main_trait_generics.get(idx) { + main_trait_generic.clone() + } else { + gen_indexed_param_name(idx) + }; } - visit_const_param_mut(self, node); + syn::visit_mut::visit_const_param_mut(self, node); } fn visit_path_mut(&mut self, node: &mut syn::Path) { - if let Some(first_segment) = node.segments.first_mut() { - if let Some(&idx) = self.type_params.get(&first_segment.ident) { - first_segment.ident = gen_indexed_type_param_name(idx); - } else if let Some(&idx) = self.const_params.get(&first_segment.ident) { - first_segment.ident = gen_indexed_const_param_name(idx); + if let Some(path) = node.segments.first_mut() { + if let Some(&idx) = self.params.get(&path.ident) { + path.ident = if let Some(&trait_param) = self.main_trait_generics.get(idx) { + trait_param.clone() + } else { + gen_indexed_param_name(idx) + }; } } - visit_path_mut(self, node); + syn::visit_mut::visit_path_mut(self, node); } } @@ -780,7 +731,7 @@ mod param { // } //} //impl<'ast> Visit<'ast> for PredicateIndexer<'ast> { - // fn visit_item_impl(&mut self, node: &'ast ItemImpl) { + // fn visit_item_impl(&mut self, node: &'ast syn::ItemImpl) { // self.visit_generics(&node.generics); // } // diff --git a/src/main_trait.rs b/src/main_trait.rs index 44b0411..448fd50 100644 --- a/src/main_trait.rs +++ b/src/main_trait.rs @@ -252,21 +252,17 @@ fn gen_helper_trait_bound( ) -> TokenStream2 { let assoc_bounds = gen_assoc_bounds(type_param_idents); - let main_trait_ty_generics = main_trait.map(|main_trait| { - let main_trait_params = main_trait.generics.params.iter().map(|param| match param { - syn::GenericParam::Lifetime(lifetime_param) => &lifetime_param.lifetime.ident, - syn::GenericParam::Type(type_param) => &type_param.ident, - syn::GenericParam::Const(const_param) => &const_param.ident, + if let Some(main_trait) = main_trait { + let main_trait_ty_generics = main_trait.generics.params.iter().map(|param| match param { + syn::GenericParam::Lifetime(syn::LifetimeParam { lifetime, .. }) => quote! {#lifetime}, + syn::GenericParam::Type(syn::TypeParam { ident, .. }) => quote! {#ident}, + syn::GenericParam::Const(syn::ConstParam { ident, .. }) => quote! {#ident}, }); - quote! { - #(#main_trait_params,)* - } - }); - - quote! { - #helper_trait_ident<#main_trait_ty_generics #(#assoc_bounds),*> + return quote! { #helper_trait_ident<#(#main_trait_ty_generics,)* #(#assoc_bounds),*> }; } + + quote! { #helper_trait_ident<#(#assoc_bounds),*> } } impl ImplItemResolver {