Skip to content

Commit

Permalink
fix generation of main trait generics
Browse files Browse the repository at this point in the history
  • Loading branch information
mversic committed Jan 8, 2024
1 parent 5443687 commit 9f7f6ce
Show file tree
Hide file tree
Showing 15 changed files with 321 additions and 264 deletions.
45 changes: 21 additions & 24 deletions src/disjoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,12 @@ pub fn gen(mut impls: Vec<ItemImpl>, idx: usize) -> Vec<ItemImpl> {

impls.iter_mut().for_each(|impl_| {
if let Some((_, trait_, _)) = impl_.trait_.as_mut() {
if let Some(last_seg) = trait_.segments.last_mut() {
last_seg.ident = helper_trait::gen_ident(&last_seg.ident, idx);
}
let path = trait_.segments.last_mut().unwrap();
path.ident = helper_trait::gen_ident(&path.ident, idx);
} else if let syn::Type::Path(type_path) = &*impl_.self_ty {
if let Some(last_seg) = type_path.path.segments.last() {
let helper_trait_ident = helper_trait::gen_ident(&last_seg.ident, idx);
impl_.trait_ = Some((None, parse_quote!(#helper_trait_ident), parse_quote![for]));
}
let path = type_path.path.segments.last().unwrap();
let helper_trait_ident = helper_trait::gen_ident(&path.ident, idx);
impl_.trait_ = Some((None, parse_quote!(#helper_trait_ident), parse_quote![for]));
}
});

Expand All @@ -58,23 +56,22 @@ pub fn gen(mut impls: Vec<ItemImpl>, idx: usize) -> Vec<ItemImpl> {
let params = update_disjoint_impl_generics(impl_, params);

if let Some((_, trait_, _)) = impl_.trait_.as_mut() {
if let Some(last_seg) = trait_.segments.last_mut() {
match &mut last_seg.arguments {
syn::PathArguments::None => {
last_seg.arguments = syn::PathArguments::AngleBracketed(
syn::parse_quote!(<#(#params),*>),
)
}
syn::PathArguments::AngleBracketed(bracketed) => {
bracketed.args = params
.into_iter()
.map::<syn::GenericArgument, _>(|param| syn::parse_quote!(#param))
.chain(core::mem::take(&mut bracketed.args))
.collect();
}
syn::PathArguments::Parenthesized(_) => {
unreachable!("Not a valid trait name")
}
let path = trait_.segments.last_mut().unwrap();

match &mut path.arguments {
syn::PathArguments::None => {
path.arguments =
syn::PathArguments::AngleBracketed(syn::parse_quote!(<#(#params),*>))
}
syn::PathArguments::AngleBracketed(bracketed) => {
bracketed.args = params
.into_iter()
.map::<syn::GenericArgument, _>(|param| syn::parse_quote!(#param))
.chain(core::mem::take(&mut bracketed.args))
.collect();
}
syn::PathArguments::Parenthesized(_) => {
unreachable!("Not a valid trait name")
}
}
}
Expand Down
120 changes: 40 additions & 80 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ impl Parse for ItemImpls {
let mut item_impls = FxHashMap::default();
while let Ok(mut item) = input.parse::<ItemImpl>() {
// TODO: Resolve predicate param idents
param::resolve_non_predicate_params(&mut item, main_trait.as_ref());
param::resolve_non_predicate_params(&mut item);

item_impls
.entry((
Expand Down Expand Up @@ -421,10 +421,8 @@ mod helper_trait {
if let Some(mut helper_trait) = main_trait.cloned() {
helper_trait.vis = syn::Visibility::Public(syn::parse_quote!(pub));
helper_trait.ident = gen_ident(&helper_trait.ident, idx);
let start_idx = helper_trait.generics.type_params().count();

helper_trait.generics.params = (start_idx..(start_idx + assoc_type_param_count))
.map(param::gen_indexed_param_name)
helper_trait.generics.params = type_param_idents
.map(|type_param_ident| syn::parse_quote!(#type_param_ident: ?Sized))
.chain(helper_trait.generics.params)
.collect();
Expand All @@ -435,15 +433,14 @@ mod helper_trait {

let self_ty = &impl_group_id.1;
if let syn::Type::Path(type_path) = self_ty {
if let Some(last_seg) = type_path.path.segments.last() {
let helper_trait_ident = gen_ident(&last_seg.ident, idx);

return Some(quote! {
pub trait #helper_trait_ident<#(#type_param_idents: ?Sized),*> {
#(#items)*
}
});
}
let path = type_path.path.segments.last().unwrap();
let helper_trait_ident = gen_ident(&path.ident, idx);

return Some(quote! {
pub trait #helper_trait_ident<#(#type_param_idents: ?Sized),*> {
#(#items)*
}
});
}
}

Expand Down Expand Up @@ -510,8 +507,7 @@ mod param {
use quote::format_ident;
use syn::{visit::Visit, visit_mut::VisitMut};

struct NonPredicateParamResolver<'ast> {
main_trait_generics: Vec<&'ast syn::Ident>,
struct NonPredicateParamResolver {
params: FxHashMap<syn::Ident, usize>,
}

Expand All @@ -526,14 +522,10 @@ mod param {
curr_param_pos_idx: usize,
}

pub fn resolve_non_predicate_params(
item_impl: &mut syn::ItemImpl,
main_trait: Option<&syn::ItemTrait>,
) {
pub fn resolve_non_predicate_params(item_impl: &mut syn::ItemImpl) {
let mut non_predicate_param_indexer = NonPredicateParamIndexer::new(&item_impl.generics);
non_predicate_param_indexer.visit_item_impl(item_impl);
let mut param_resolver =
NonPredicateParamResolver::new(non_predicate_param_indexer, main_trait);
let mut param_resolver = NonPredicateParamResolver::new(non_predicate_param_indexer);
param_resolver.visit_item_impl_mut(item_impl);

// TODO: Add unnamed lifetimes (&u32) or elided lifetimes (&'_ u32)
Expand Down Expand Up @@ -586,9 +578,8 @@ mod param {
//
// had `Visit::visit_path` been used on `T<T>` to resolve
// trait generics it would also rename the trait ident itself
if let Some(last_seg) = &trait_.segments.last() {
self.visit_path_arguments(&last_seg.arguments);
}
let path = &trait_.segments.last().unwrap();
self.visit_path_arguments(&path.arguments);
}

self.visit_type(&node.self_ty);
Expand All @@ -599,50 +590,37 @@ mod param {
*lifetime = Some(self.curr_param_pos_idx);
}

if let Some(curr_pos_idx) = self.curr_param_pos_idx.checked_add(1) {
self.curr_param_pos_idx = curr_pos_idx;
}
self.curr_param_pos_idx = self.curr_param_pos_idx.checked_add(1).unwrap();
}

fn visit_path(&mut self, node: &'ast syn::Path) {
if let Some(path) = node.segments.first() {
if let Some(param_idx) = self.params.get_mut(&path.ident) {
if param_idx.is_none() {
*param_idx = Some(self.curr_param_pos_idx);
}
}
let path = node.segments.first().unwrap();

if let Some(pos_idx) = self.curr_param_pos_idx.checked_add(1) {
self.curr_param_pos_idx = pos_idx;
if let Some(param_idx) = self.params.get_mut(&path.ident) {
if param_idx.is_none() {
*param_idx = Some(self.curr_param_pos_idx);
}
}

syn::visit::visit_path(self, node);
self.curr_param_pos_idx = self.curr_param_pos_idx.checked_add(1).unwrap();
} else {
syn::visit::visit_path(self, node);
}
}

fn visit_expr(&mut self, node: &'ast syn::Expr) {
if let syn::Expr::Path(path) = node {
syn::visit::visit_expr_path(self, path);
} else if let Some(curr_pos_idx) = self.curr_param_pos_idx.checked_add(1) {
self.curr_param_pos_idx = curr_pos_idx;
self.visit_expr_path(path);
} else {
self.curr_param_pos_idx = self.curr_param_pos_idx.checked_add(1).unwrap();
}
}

fn visit_where_clause(&mut self, _node: &'ast syn::WhereClause) {}
}

impl<'ast> NonPredicateParamResolver<'ast> {
fn new(
indexer: NonPredicateParamIndexer,
main_trait: Option<&'ast syn::ItemTrait>,
) -> Self {
let main_trait_generics: Vec<_> = main_trait
.map(|main_trait| get_param_idents(main_trait.generics.params.iter()).collect())
.unwrap_or_default();

impl NonPredicateParamResolver {
fn new(indexer: NonPredicateParamIndexer) -> Self {
Self {
main_trait_generics,

params: indexer
.params
.into_iter()
Expand All @@ -652,55 +630,39 @@ mod param {
}
}

impl VisitMut for NonPredicateParamResolver<'_> {
impl VisitMut for NonPredicateParamResolver {
fn visit_lifetime_mut(&mut self, node: &mut syn::Lifetime) {
if let Some(&idx) = self.params.get(&node.ident) {
node.ident = if let Some(&main_trait_generic) = self.main_trait_generics.get(idx) {
main_trait_generic.clone()
} else {
gen_indexed_param_name(idx)
};
node.ident = gen_indexed_param_name(idx);
}

syn::visit_mut::visit_lifetime_mut(self, node);
}

fn visit_type_param_mut(&mut self, node: &mut syn::TypeParam) {
if let Some(&idx) = self.params.get(&node.ident) {
node.ident = if let Some(&main_trait_generic) = self.main_trait_generics.get(idx) {
main_trait_generic.clone()
} else {
gen_indexed_param_name(idx)
};
node.ident = gen_indexed_param_name(idx);
}

syn::visit_mut::visit_type_param_mut(self, node);
}

fn visit_const_param_mut(&mut self, node: &mut syn::ConstParam) {
if let Some(&idx) = self.params.get(&node.ident) {
node.ident = if let Some(&main_trait_generic) = self.main_trait_generics.get(idx) {
main_trait_generic.clone()
} else {
gen_indexed_param_name(idx)
};
node.ident = gen_indexed_param_name(idx);
}

syn::visit_mut::visit_const_param_mut(self, node);
}

fn visit_path_mut(&mut self, node: &mut syn::Path) {
if let Some(path) = node.segments.first_mut() {
if let Some(&idx) = self.params.get(&path.ident) {
path.ident = if let Some(&trait_param) = self.main_trait_generics.get(idx) {
trait_param.clone()
} else {
gen_indexed_param_name(idx)
};
}
}
let path = node.segments.first_mut().unwrap();

syn::visit_mut::visit_path_mut(self, node);
if let Some(&idx) = self.params.get(&path.ident) {
path.ident = gen_indexed_param_name(idx);
} else {
syn::visit_mut::visit_path_mut(self, node);
}
}
}

Expand Down Expand Up @@ -735,9 +697,7 @@ mod param {
// }
// });
//
// if let Some(curr_pos_idx) = self.curr_pos_idx.checked_add(1) {
// self.curr_pos_idx = curr_pos_idx;
// }
// self.curr_pos_idx = self.curr_pos_idx.checked_add(1).unwrap();
// }
//}

Expand Down
Loading

0 comments on commit 9f7f6ce

Please sign in to comment.