Skip to content

Commit

Permalink
add support for resolution of parameters not bounded by trait or self…
Browse files Browse the repository at this point in the history
… type
  • Loading branch information
mversic committed Jan 8, 2024
1 parent 9f7f6ce commit 631c023
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 119 deletions.
236 changes: 139 additions & 97 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,15 @@ impl quote::ToTokens for TraitBound<'_> {
first_args.colon2_token.to_tokens(tokens);

quote!(<).to_tokens(tokens);
first_args.args.iter().for_each(|args| match args {
syn::GenericArgument::AssocType(_) => {}
_ => args.to_tokens(tokens),
});
first_args
.args
.iter()
.filter_map(|arg| match arg {
syn::GenericArgument::AssocType(_) => None,
_ => Some(arg),
})
.collect::<syn::punctuated::Punctuated<_, syn::Token![,]>>()
.to_tokens(tokens);
quote!(>).to_tokens(tokens);
}
_ => first_elem.arguments.to_tokens(tokens),
Expand Down Expand Up @@ -507,26 +512,65 @@ mod param {
use quote::format_ident;
use syn::{visit::Visit, visit_mut::VisitMut};

struct NonPredicateParamResolver {
params: FxHashMap<syn::Ident, usize>,
}

/// Indexer for params used in traits, impl trait or self type, but not predicates.
///
/// For `impl<U, T: IntoIterator<Item = V>, V> Trait<T> for U` resolved indices would be:
/// `T` = 0,
/// `U` = 1,
/// `V` = undetermined
struct NonPredicateParamIndexer<'ast> {
params: FxHashMap<&'ast syn::Ident, Option<usize>>,
struct NonPredicateParamIndexer {
indexed_params: FxHashMap<syn::Ident, (usize, syn::GenericParam)>,
unindexed_params: FxHashMap<syn::Ident, syn::GenericParam>,
curr_param_pos_idx: usize,
}

struct NonPredicateParamResolver {
params: FxHashMap<syn::Ident, usize>,
}

pub fn resolve_non_predicate_params(item_impl: &mut syn::ItemImpl) {
let mut non_predicate_param_indexer = NonPredicateParamIndexer::new(&item_impl.generics);
let item_impl_generics = item_impl.generics.params.iter().cloned();

let mut non_predicate_param_indexer = NonPredicateParamIndexer::new(
item_impl_generics
.map(|param| (get_param_ident(&param).clone(), param))
.collect(),
0,
);

non_predicate_param_indexer.visit_item_impl(item_impl);
let mut param_resolver = NonPredicateParamResolver::new(non_predicate_param_indexer);
param_resolver.visit_item_impl_mut(item_impl);

let mut prev_unindexed_params_count = usize::MAX;
let mut indexed_params = non_predicate_param_indexer.indexed_params;
let mut curr_unindexed_params_count = non_predicate_param_indexer.unindexed_params.len();

while !non_predicate_param_indexer.unindexed_params.is_empty()
// NOTE: This discards parameters only used in where clause
&& prev_unindexed_params_count != curr_unindexed_params_count
{
non_predicate_param_indexer = NonPredicateParamIndexer::new(
non_predicate_param_indexer.unindexed_params,
non_predicate_param_indexer.curr_param_pos_idx,
);

non_predicate_param_indexer.visit_indexed_params(
indexed_params
.iter()
.map(|(_, (idx, param))| (*idx, param))
.collect(),
);

prev_unindexed_params_count = curr_unindexed_params_count;
indexed_params.extend(non_predicate_param_indexer.indexed_params);
curr_unindexed_params_count = non_predicate_param_indexer.unindexed_params.len();
}

NonPredicateParamResolver::new(
indexed_params
.into_iter()
.map(|(ident, (idx, _))| (ident, idx)),
)
.visit_item_impl_mut(item_impl);

// TODO: Add unnamed lifetimes (&u32) or elided lifetimes (&'_ u32)
// TODO: Remove unused lifetimes. Example where 'b is unused:
Expand All @@ -551,21 +595,44 @@ mod param {
format_ident!("_{idx}")
}

impl<'ast> NonPredicateParamIndexer<'ast> {
fn new(generics: &'ast syn::Generics) -> Self {
let params = get_param_idents(generics.params.iter())
.map(|param| (param, None))
.collect();
impl NonPredicateParamIndexer {
fn new(
unindexed_params: FxHashMap<syn::Ident, syn::GenericParam>,
curr_param_pos_idx: usize,
) -> Self {
let indexed_params = FxHashMap::default();

Self {
params,
curr_param_pos_idx: 0,
indexed_params,
unindexed_params,
curr_param_pos_idx,
}
}

fn visit_param_ident(&mut self, param_ident: &syn::Ident) -> bool {
if let Some(removed) = self.unindexed_params.remove(param_ident) {
self.indexed_params
.insert(param_ident.clone(), (self.curr_param_pos_idx, removed));
self.curr_param_pos_idx = self.curr_param_pos_idx.checked_add(1).unwrap();

return true;
}

false
}

fn visit_indexed_params(&mut self, node: FxHashMap<usize, &syn::GenericParam>) {
let mut indexed_params = node.into_iter().collect::<Vec<_>>();
indexed_params.sort_by_key(|(k, _)| *k);

for (_, param) in indexed_params {
self.visit_generic_param(param);
}
}
}

impl<'ast> Visit<'ast> for NonPredicateParamIndexer<'ast> {
fn visit_item_impl(&mut self, node: &'ast syn::ItemImpl) {
impl Visit<'_> for NonPredicateParamIndexer {
fn visit_item_impl(&mut self, node: &syn::ItemImpl) {
if let Some((_, trait_, _)) = &node.trait_ {
// NOTE: Calling `visit_path` on a trait would conflict
// with resolving params on `TypePath` so it's not done
Expand All @@ -578,59 +645,71 @@ mod param {
//
// had `Visit::visit_path` been used on `T<T>` to resolve
// trait generics it would also rename the trait ident itself
let path = &trait_.segments.last().unwrap();
let path = trait_.segments.last().unwrap();
self.visit_path_arguments(&path.arguments);
}

self.visit_type(&node.self_ty);
}

fn visit_lifetime(&mut self, node: &'ast syn::Lifetime) {
if let Some(lifetime) = self.params.get_mut(&node.ident) {
*lifetime = Some(self.curr_param_pos_idx);
}

self.curr_param_pos_idx = self.curr_param_pos_idx.checked_add(1).unwrap();
fn visit_lifetime(&mut self, node: &syn::Lifetime) {
self.visit_param_ident(&node.ident);
}

fn visit_path(&mut self, node: &'ast syn::Path) {
let path = node.segments.first().unwrap();

if let Some(param_idx) = self.params.get_mut(&path.ident) {
if param_idx.is_none() {
*param_idx = Some(self.curr_param_pos_idx);
}

self.curr_param_pos_idx = self.curr_param_pos_idx.checked_add(1).unwrap();
} else {
fn visit_path(&mut self, node: &syn::Path) {
if !self.visit_param_ident(&node.segments.first().unwrap().ident) {
syn::visit::visit_path(self, node);
}
}

fn visit_expr(&mut self, node: &'ast syn::Expr) {
if let syn::Expr::Path(path) = node {
self.visit_expr_path(path);
} else {
self.curr_param_pos_idx = self.curr_param_pos_idx.checked_add(1).unwrap();
}
}

fn visit_where_clause(&mut self, _node: &'ast syn::WhereClause) {}
// TODO: Is this required? I don't think it is anymore
//fn visit_expr(&mut self, node: &syn::Expr) {
// if let syn::Expr::Path(path) = node {
// self.visit_expr_path(path);
// } else {
// self.curr_param_pos_idx = self.curr_param_pos_idx.checked_add(1).unwrap();
// }
//}
}

impl NonPredicateParamResolver {
fn new(indexer: NonPredicateParamIndexer) -> Self {
fn new(params: impl IntoIterator<Item = (syn::Ident, usize)>) -> Self {
Self {
params: indexer
.params
.into_iter()
.filter_map(|(param, idx)| idx.map(|idx| (param.clone(), idx)))
.collect(),
params: params.into_iter().collect(),
}
}
}

impl VisitMut for NonPredicateParamResolver {
fn visit_item_impl_mut(&mut self, node: &mut syn::ItemImpl) {
for attr in &mut node.attrs {
self.visit_attribute_mut(attr);
}

self.visit_generics_mut(&mut node.generics);
if let Some((_, trait_, _)) = &mut node.trait_ {
// NOTE: Calling `visit_path` on a trait would conflict
// with resolving params on `TypePath` so it's not done
//
// # Example
//
// ```
// trait T<T> {}
// ```
//
// had `Visit::visit_path` been used on `T<T>` to resolve
// trait generics it would also rename the trait ident itself
let path = trait_.segments.last_mut().unwrap();
self.visit_path_arguments_mut(&mut path.arguments);
}

self.visit_type_mut(&mut node.self_ty);

for item in &mut node.items {
self.visit_impl_item_mut(item);
}
}

fn visit_lifetime_mut(&mut self, node: &mut syn::Lifetime) {
if let Some(&idx) = self.params.get(&node.ident) {
node.ident = gen_indexed_param_name(idx);
Expand Down Expand Up @@ -666,48 +745,11 @@ mod param {
}
}

//struct PredicateIndexer<'ast> {
// type_params: FxHashMap<&'ast syn::Ident, Option<usize>>,
// curr_pos_idx: usize,
//}
//impl<'ast> PredicateIndexer<'ast> {
// fn new(type_params: FxHashMap<&'ast syn::Ident, Option<usize>>) -> Self {
// let curr_pos_idx: usize = type_params
// .values()
// .filter_map(|x| *x)
// .reduce(|acc, x| x.max(acc))
// .unwrap_or(0);
//
// Self {
// type_params,
// curr_pos_idx,
// }
// }
//}
//impl<'ast> Visit<'ast> for PredicateIndexer<'ast> {
// fn visit_item_impl(&mut self, node: &'ast syn::ItemImpl) {
// self.visit_generics(&node.generics);
// }
//
// fn visit_path_segment(&mut self, node: &'ast syn::PathSegment) {
// self.type_params.entry(&node.ident).and_modify(|param_idx| {
// if param_idx.is_none() {
// // Param encountered for the first time
// *param_idx = Some(self.curr_pos_idx);
// }
// });
//
// self.curr_pos_idx = self.curr_pos_idx.checked_add(1).unwrap();
// }
//}

pub fn get_param_idents<'a>(
generic_params: impl Iterator<Item = &'a syn::GenericParam>,
) -> impl Iterator<Item = &'a syn::Ident> {
generic_params.into_iter().map(|param| match param {
syn::GenericParam::Lifetime(lifetime_param) => &lifetime_param.lifetime.ident,
syn::GenericParam::Type(type_param) => &type_param.ident,
syn::GenericParam::Const(const_param) => &const_param.ident,
})
pub fn get_param_ident(generic_param: &syn::GenericParam) -> &syn::Ident {
match generic_param {
syn::GenericParam::Lifetime(syn::LifetimeParam { lifetime, .. }) => &lifetime.ident,
syn::GenericParam::Type(syn::TypeParam { ident, .. }) => ident,
syn::GenericParam::Const(syn::ConstParam { ident, .. }) => ident,
}
}
}
2 changes: 1 addition & 1 deletion src/main_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ impl GenericsResolver {
let trait_bounds = trait_bounds.into_iter();
quote! { #param_ident: #(#trait_bounds)+* }
});

let where_clause_predicates = assoc_bound_predicates
.chain(core::iter::once_with(|| {
let helper_trait_bound =
Expand All @@ -205,7 +206,6 @@ impl GenericsResolver {
quote! { Self: #helper_trait_bound }
}))
.collect();

Self {
assoc_bound_type_params,
where_clause_predicates,
Expand Down
Loading

0 comments on commit 631c023

Please sign in to comment.