Skip to content

Commit

Permalink
fix(sol-macro): namespaced custom type resolution (#731)
Browse files Browse the repository at this point in the history
  • Loading branch information
klkvr authored Sep 9, 2024
1 parent 83e70b7 commit 1ff30e2
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 28 deletions.
64 changes: 36 additions & 28 deletions crates/sol-macro-expander/src/expand/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::{
};
use alloy_sol_macro_input::{ContainsSolAttrs, SolAttrs};
use ast::{
EventParameter, File, Item, ItemError, ItemEvent, ItemFunction, Parameters, SolIdent, SolPath,
Spanned, Type, VariableDeclaration, Visit,
visit_mut, EventParameter, File, Item, ItemError, ItemEvent, ItemFunction, Parameters,
SolIdent, SolPath, Spanned, Type, VariableDeclaration, Visit, VisitMut,
};
use indexmap::IndexMap;
use proc_macro2::{Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree};
Expand Down Expand Up @@ -227,24 +227,44 @@ impl<'ast> ExpCtxt<'ast> {
}

fn resolve_custom_types(&mut self) {
/// Helper struct, recursively resolving types and keeping track of namespace which is
/// updated when entering a type from external contract.
struct Resolver<'a> {
map: &'a NamespacedMap<Type>,
cnt: usize,
namespace: Option<SolIdent>,
}
impl VisitMut<'_> for Resolver<'_> {
fn visit_type(&mut self, ty: &mut Type) {
if self.cnt >= RESOLVE_LIMIT {
return;
}
let prev_namespace = self.namespace.clone();
if let Type::Custom(name) = ty {
let Some(resolved) = self.map.resolve(name, &self.namespace) else {
return;
};
// Update namespace if we're entering a new one
if name.len() == 2 {
self.namespace = Some(name.first().clone());
}
ty.clone_from(resolved);
self.cnt += 1;
}

visit_mut::visit_type(self, ty);

self.namespace = prev_namespace;
}
}

self.mk_types_map();
let map = self.custom_types.clone();
for (namespace, custom_types) in &mut self.custom_types.0 {
for ty in custom_types.values_mut() {
let mut i = 0;
ty.visit_mut(|ty| {
if i >= RESOLVE_LIMIT {
return;
}
let ty @ Type::Custom(_) = ty else { return };
let Type::Custom(name) = &*ty else { unreachable!() };
let Some(resolved) = map.resolve(name, namespace) else {
return;
};
ty.clone_from(resolved);
i += 1;
});
if i >= RESOLVE_LIMIT {
let mut resolver = Resolver { map: &map, cnt: 0, namespace: namespace.clone() };
resolver.visit_type(ty);
if resolver.cnt >= RESOLVE_LIMIT {
abort!(
ty.span(),
"failed to resolve types.\n\
Expand Down Expand Up @@ -517,18 +537,6 @@ impl<'ast> ExpCtxt<'ast> {
self.all_items.resolve(name, &self.current_namespace).copied()
}

/// Recursively resolves the given type by constructing a new one.
#[allow(dead_code)]
fn make_resolved_type(&self, ty: &Type) -> Type {
let mut ty = ty.clone();
ty.visit_mut(|ty| {
if let Type::Custom(name) = ty {
*ty = self.custom_type(name).clone();
}
});
ty
}

fn custom_type(&self, name: &SolPath) -> &Type {
match self.try_custom_type(name) {
Some(item) => item,
Expand Down
71 changes: 71 additions & 0 deletions crates/sol-types/tests/macros/sol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1013,3 +1013,74 @@ fn normal_paths() {

let _ = funcCall { stuff: I::S { x: U256::ZERO } };
}

#[test]
fn regression_nested_namespaced_structs() {
mod inner {
super::sol! {
library LibA {
struct Simple {
uint256 x;
}

struct Nested {
Simple simple;
LibB.Simple[] simpleB;
}
}

library LibB {
struct Simple {
uint256 x;
uint256 y;
}

struct Nested {
Simple simple;
LibA.Simple simpleA;
LibB.Simple simpleB;
}
}

library LibC {
struct Nested1 {
LibA.Nested nestedA;
LibB.Nested nestedB;
}

struct Nested2 {
LibA.Simple simpleA;
LibB.Simple simpleB;
LibA.Nested[] nestedA;
LibB.Nested nestedB;
Nested1[] nestedC1;
LibC.Nested1 nestedC2;
}
}

contract C {
function libASimple(LibA.Simple memory simple) public returns(LibA.Simple memory);
function libBSimple(LibB.Simple memory simple) public returns(LibB.Simple memory);
function libANested(LibA.Nested memory nested) public returns(LibA.Nested memory);
function libBNested(LibB.Nested memory nested) public returns(LibB.Nested memory);
function libCNested1(LibC.Nested1 memory nested) public returns(LibC.Nested1 memory);
function libCNested2(LibC.Nested2 memory nested) public returns(LibC.Nested2 memory);
}
}
}

let a_simple = "(uint256)";
let b_simple = "(uint256,uint256)";
let a_nested = format!("({a_simple},{b_simple}[])");
let b_nested = format!("({b_simple},{a_simple},{b_simple})");
let c_nested1 = format!("({a_nested},{b_nested})");
let c_nested2 =
format!("({a_simple},{b_simple},{a_nested}[],{b_nested},{c_nested1}[],{c_nested1})");

assert_eq!(inner::C::libASimpleCall::SIGNATURE, format!("libASimple({a_simple})"));
assert_eq!(inner::C::libBSimpleCall::SIGNATURE, format!("libBSimple({b_simple})"));
assert_eq!(inner::C::libANestedCall::SIGNATURE, format!("libANested({a_nested})"));
assert_eq!(inner::C::libBNestedCall::SIGNATURE, format!("libBNested({b_nested})"));
assert_eq!(inner::C::libCNested1Call::SIGNATURE, format!("libCNested1({c_nested1})"));
assert_eq!(inner::C::libCNested2Call::SIGNATURE, format!("libCNested2({c_nested2})"));
}

0 comments on commit 1ff30e2

Please sign in to comment.