diff --git a/crates/sol-macro-expander/src/expand/mod.rs b/crates/sol-macro-expander/src/expand/mod.rs index d4c4a0458..85ec5d0e4 100644 --- a/crates/sol-macro-expander/src/expand/mod.rs +++ b/crates/sol-macro-expander/src/expand/mod.rs @@ -6,8 +6,8 @@ use crate::{ }; use alloy_sol_macro_input::{ContainsSolAttrs, SolAttrs}; use ast::{ - EventParameter, File, Item, ItemError, ItemEvent, ItemFunction, Parameters, SolIdent, SolPath, - Spanned, Type, VariableDeclaration, Visit, + visit_mut, EventParameter, File, Item, ItemError, ItemEvent, ItemFunction, Parameters, + SolIdent, SolPath, Spanned, Type, VariableDeclaration, Visit, VisitMut, }; use indexmap::IndexMap; use proc_macro2::{Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree}; @@ -227,24 +227,44 @@ impl<'ast> ExpCtxt<'ast> { } fn resolve_custom_types(&mut self) { + /// Helper struct, recursively resolving types and keeping track of namespace which is + /// updated when entering a type from external contract. + struct Resolver<'a> { + map: &'a NamespacedMap, + cnt: usize, + namespace: Option, + } + impl VisitMut<'_> for Resolver<'_> { + fn visit_type(&mut self, ty: &mut Type) { + if self.cnt >= RESOLVE_LIMIT { + return; + } + let prev_namespace = self.namespace.clone(); + if let Type::Custom(name) = ty { + let Some(resolved) = self.map.resolve(name, &self.namespace) else { + return; + }; + // Update namespace if we're entering a new one + if name.len() == 2 { + self.namespace = Some(name.first().clone()); + } + ty.clone_from(resolved); + self.cnt += 1; + } + + visit_mut::visit_type(self, ty); + + self.namespace = prev_namespace; + } + } + self.mk_types_map(); let map = self.custom_types.clone(); for (namespace, custom_types) in &mut self.custom_types.0 { for ty in custom_types.values_mut() { - let mut i = 0; - ty.visit_mut(|ty| { - if i >= RESOLVE_LIMIT { - return; - } - let ty @ Type::Custom(_) = ty else { return }; - let Type::Custom(name) = &*ty else { unreachable!() }; - let Some(resolved) = map.resolve(name, namespace) else { - return; - }; - ty.clone_from(resolved); - i += 1; - }); - if i >= RESOLVE_LIMIT { + let mut resolver = Resolver { map: &map, cnt: 0, namespace: namespace.clone() }; + resolver.visit_type(ty); + if resolver.cnt >= RESOLVE_LIMIT { abort!( ty.span(), "failed to resolve types.\n\ @@ -517,18 +537,6 @@ impl<'ast> ExpCtxt<'ast> { self.all_items.resolve(name, &self.current_namespace).copied() } - /// Recursively resolves the given type by constructing a new one. - #[allow(dead_code)] - fn make_resolved_type(&self, ty: &Type) -> Type { - let mut ty = ty.clone(); - ty.visit_mut(|ty| { - if let Type::Custom(name) = ty { - *ty = self.custom_type(name).clone(); - } - }); - ty - } - fn custom_type(&self, name: &SolPath) -> &Type { match self.try_custom_type(name) { Some(item) => item, diff --git a/crates/sol-types/tests/macros/sol/mod.rs b/crates/sol-types/tests/macros/sol/mod.rs index 1302fae33..4bae1c3e9 100644 --- a/crates/sol-types/tests/macros/sol/mod.rs +++ b/crates/sol-types/tests/macros/sol/mod.rs @@ -1013,3 +1013,74 @@ fn normal_paths() { let _ = funcCall { stuff: I::S { x: U256::ZERO } }; } + +#[test] +fn regression_nested_namespaced_structs() { + mod inner { + super::sol! { + library LibA { + struct Simple { + uint256 x; + } + + struct Nested { + Simple simple; + LibB.Simple[] simpleB; + } + } + + library LibB { + struct Simple { + uint256 x; + uint256 y; + } + + struct Nested { + Simple simple; + LibA.Simple simpleA; + LibB.Simple simpleB; + } + } + + library LibC { + struct Nested1 { + LibA.Nested nestedA; + LibB.Nested nestedB; + } + + struct Nested2 { + LibA.Simple simpleA; + LibB.Simple simpleB; + LibA.Nested[] nestedA; + LibB.Nested nestedB; + Nested1[] nestedC1; + LibC.Nested1 nestedC2; + } + } + + contract C { + function libASimple(LibA.Simple memory simple) public returns(LibA.Simple memory); + function libBSimple(LibB.Simple memory simple) public returns(LibB.Simple memory); + function libANested(LibA.Nested memory nested) public returns(LibA.Nested memory); + function libBNested(LibB.Nested memory nested) public returns(LibB.Nested memory); + function libCNested1(LibC.Nested1 memory nested) public returns(LibC.Nested1 memory); + function libCNested2(LibC.Nested2 memory nested) public returns(LibC.Nested2 memory); + } + } + } + + let a_simple = "(uint256)"; + let b_simple = "(uint256,uint256)"; + let a_nested = format!("({a_simple},{b_simple}[])"); + let b_nested = format!("({b_simple},{a_simple},{b_simple})"); + let c_nested1 = format!("({a_nested},{b_nested})"); + let c_nested2 = + format!("({a_simple},{b_simple},{a_nested}[],{b_nested},{c_nested1}[],{c_nested1})"); + + assert_eq!(inner::C::libASimpleCall::SIGNATURE, format!("libASimple({a_simple})")); + assert_eq!(inner::C::libBSimpleCall::SIGNATURE, format!("libBSimple({b_simple})")); + assert_eq!(inner::C::libANestedCall::SIGNATURE, format!("libANested({a_nested})")); + assert_eq!(inner::C::libBNestedCall::SIGNATURE, format!("libBNested({b_nested})")); + assert_eq!(inner::C::libCNested1Call::SIGNATURE, format!("libCNested1({c_nested1})")); + assert_eq!(inner::C::libCNested2Call::SIGNATURE, format!("libCNested2({c_nested2})")); +}