From 4e3924fb82b7299d56d3619aa5d7b9863f581e0a Mon Sep 17 00:00:00 2001 From: glihm Date: Mon, 23 Sep 2024 09:48:13 -0400 Subject: [PATCH] fix: add a derive for generate contract (#58) * feat: add derives specific to contract * fix: add examples for contract derives * docs: add example for derives in README --- crates/rs-macro/README.md | 10 ++++++++++ crates/rs-macro/src/lib.rs | 2 ++ crates/rs-macro/src/macro_inputs.rs | 12 +++++++++++ crates/rs-macro/src/macro_inputs_legacy.rs | 12 +++++++++++ crates/rs/src/expand/contract.rs | 12 ++++++++--- crates/rs/src/lib.rs | 23 +++++++++++++++++++++- examples/abigen_generate.rs | 4 +++- examples/simple_get_set.rs | 4 +++- src/bin/cli/args.rs | 5 +++++ src/bin/cli/main.rs | 1 + src/bin/cli/plugins/builtins/rust.rs | 1 + src/bin/cli/plugins/mod.rs | 1 + 12 files changed, 81 insertions(+), 6 deletions(-) diff --git a/crates/rs-macro/README.md b/crates/rs-macro/README.md index 93e5206..2dde826 100644 --- a/crates/rs-macro/README.md +++ b/crates/rs-macro/README.md @@ -46,6 +46,8 @@ The `abigen!` macro takes 2 or 3 inputs: 3. Optional parameters: - `output_path`: if provided, the content will be generated in the given file instead of being expanded at the location of the macro invocation. - `type_aliases`: to avoid type name conflicts between components / contracts, you can rename some type by providing an alias for the full type path. It is important to give the **full** type path to ensure aliases are applied correctly. + - `derive`: to specify the derive for the generated structs/enums. + - `contract_derives`: to specify the derive for the generated contract type. ```rust use cainome::rs::abigen; @@ -66,6 +68,14 @@ abigen!( }, ); +// Example with custom derives: +abigen!( + MyContract, + "./contracts/abi/components.abi.json", + derive(Debug, Clone), + contract_derives(Debug, Clone) +); + fn main() { // ... use the generated types here, which all of them // implement CairoSerde trait. diff --git a/crates/rs-macro/src/lib.rs b/crates/rs-macro/src/lib.rs index dac8a80..32d31ee 100644 --- a/crates/rs-macro/src/lib.rs +++ b/crates/rs-macro/src/lib.rs @@ -37,6 +37,7 @@ fn abigen_internal(input: TokenStream) -> TokenStream { &abi_tokens, contract_abi.execution_version, &contract_abi.derives, + &contract_abi.contract_derives, ); if let Some(out_path) = contract_abi.output_path { @@ -66,6 +67,7 @@ fn abigen_internal_legacy(input: TokenStream) -> TokenStream { &abi_tokens, cainome_rs::ExecutionVersion::V1, &contract_abi.derives, + &contract_abi.contract_derives, ); if let Some(out_path) = contract_abi.output_path { diff --git a/crates/rs-macro/src/macro_inputs.rs b/crates/rs-macro/src/macro_inputs.rs index 212a112..7098615 100644 --- a/crates/rs-macro/src/macro_inputs.rs +++ b/crates/rs-macro/src/macro_inputs.rs @@ -39,6 +39,7 @@ pub(crate) struct ContractAbi { pub type_aliases: HashMap, pub execution_version: ExecutionVersion, pub derives: Vec, + pub contract_derives: Vec, } impl Parse for ContractAbi { @@ -92,6 +93,7 @@ impl Parse for ContractAbi { let mut execution_version = ExecutionVersion::V1; let mut type_aliases = HashMap::new(); let mut derives = Vec::new(); + let mut contract_derives = Vec::new(); loop { if input.parse::().is_err() { @@ -153,6 +155,15 @@ impl Parse for ContractAbi { derives.push(derive.to_token_stream().to_string()); } } + "contract_derives" => { + let content; + parenthesized!(content in input); + let parsed = content.parse_terminated(Spanned::::parse, Token![,])?; + + for derive in parsed { + contract_derives.push(derive.to_token_stream().to_string()); + } + } _ => emit_error!(name.span(), format!("unexpected named parameter `{name}`")), } } @@ -164,6 +175,7 @@ impl Parse for ContractAbi { type_aliases, execution_version, derives, + contract_derives, }) } } diff --git a/crates/rs-macro/src/macro_inputs_legacy.rs b/crates/rs-macro/src/macro_inputs_legacy.rs index 43cd6d1..2641ead 100644 --- a/crates/rs-macro/src/macro_inputs_legacy.rs +++ b/crates/rs-macro/src/macro_inputs_legacy.rs @@ -36,6 +36,7 @@ pub(crate) struct ContractAbiLegacy { pub output_path: Option, pub type_aliases: HashMap, pub derives: Vec, + pub contract_derives: Vec, } impl Parse for ContractAbiLegacy { @@ -89,6 +90,7 @@ impl Parse for ContractAbiLegacy { let mut output_path: Option = None; let mut type_aliases = HashMap::new(); let mut derives = Vec::new(); + let mut contract_derives = Vec::new(); loop { if input.parse::().is_err() { @@ -142,6 +144,15 @@ impl Parse for ContractAbiLegacy { derives.push(derive.to_token_stream().to_string()); } } + "contract_derives" => { + let content; + parenthesized!(content in input); + let parsed = content.parse_terminated(Spanned::::parse, Token![,])?; + + for derive in parsed { + contract_derives.push(derive.to_token_stream().to_string()); + } + } _ => emit_error!(name.span(), format!("unexpected named parameter `{name}`")), } } @@ -152,6 +163,7 @@ impl Parse for ContractAbiLegacy { output_path, type_aliases, derives, + contract_derives, }) } } diff --git a/crates/rs/src/expand/contract.rs b/crates/rs/src/expand/contract.rs index 8405bc8..ad00de7 100644 --- a/crates/rs/src/expand/contract.rs +++ b/crates/rs/src/expand/contract.rs @@ -7,16 +7,22 @@ use super::utils; pub struct CairoContract; impl CairoContract { - pub fn expand(contract_name: Ident) -> TokenStream2 { + pub fn expand(contract_name: Ident, contract_derives: &[String]) -> TokenStream2 { let reader = utils::str_to_ident(format!("{}Reader", contract_name).as_str()); let snrs_types = utils::snrs_types(); let snrs_accounts = utils::snrs_accounts(); let snrs_providers = utils::snrs_providers(); + let mut internal_derives = vec![]; + + for d in contract_derives { + internal_derives.push(utils::str_to_type(d)); + } + let q = quote! { - #[derive(Debug)] + #[derive(#(#internal_derives,)*)] pub struct #contract_name { pub address: #snrs_types::Felt, pub account: A, @@ -45,7 +51,7 @@ impl CairoContract { } } - #[derive(Debug)] + #[derive(#(#internal_derives,)*)] pub struct #reader { pub address: #snrs_types::Felt, pub provider: P, diff --git a/crates/rs/src/lib.rs b/crates/rs/src/lib.rs index 1f9f288..8755589 100644 --- a/crates/rs/src/lib.rs +++ b/crates/rs/src/lib.rs @@ -73,6 +73,8 @@ pub struct Abigen { pub execution_version: ExecutionVersion, /// Derives to be added to the generated types. pub derives: Vec, + /// Derives to be added to the generated contract. + pub contract_derives: Vec, } impl Abigen { @@ -90,6 +92,7 @@ impl Abigen { types_aliases: HashMap::new(), execution_version: ExecutionVersion::V1, derives: vec![], + contract_derives: vec![], } } @@ -123,6 +126,16 @@ impl Abigen { self } + /// Sets the derives to be added to the generated contract. + /// + /// # Arguments + /// + /// * `derives` - Derives to be added to the generated contract. + pub fn with_contract_derives(mut self, derives: Vec) -> Self { + self.contract_derives = derives; + self + } + /// Generates the contract bindings. pub fn generate(&self) -> Result { let file_content = std::fs::read_to_string(&self.abi_source)?; @@ -134,6 +147,7 @@ impl Abigen { &tokens, self.execution_version, &self.derives, + &self.contract_derives, ); Ok(ContractBindings { @@ -157,17 +171,24 @@ impl Abigen { /// /// * `contract_name` - Name of the contract. /// * `abi_tokens` - Tokenized ABI. +/// * `execution_version` - The version of transaction to be executed. +/// * `derives` - Derives to be added to the generated types. +/// * `contract_derives` - Derives to be added to the generated contract. pub fn abi_to_tokenstream( contract_name: &str, abi_tokens: &TokenizedAbi, execution_version: ExecutionVersion, derives: &[String], + contract_derives: &[String], ) -> TokenStream2 { let contract_name = utils::str_to_ident(contract_name); let mut tokens: Vec = vec![]; - tokens.push(CairoContract::expand(contract_name.clone())); + tokens.push(CairoContract::expand( + contract_name.clone(), + contract_derives, + )); let mut sorted_structs = abi_tokens.structs.clone(); sorted_structs.sort_by(|a, b| { diff --git a/examples/abigen_generate.rs b/examples/abigen_generate.rs index ae250e2..d294609 100644 --- a/examples/abigen_generate.rs +++ b/examples/abigen_generate.rs @@ -10,7 +10,9 @@ async fn main() { "MyContract", "./contracts/target/dev/contracts_simple_get_set.contract_class.json", ) - .with_types_aliases(aliases); + .with_types_aliases(aliases) + .with_derives(vec!["Debug".to_string(), "PartialEq".to_string()]) + .with_contract_derives(vec!["Debug".to_string(), "Clone".to_string()]); abigen .generate() diff --git a/examples/simple_get_set.rs b/examples/simple_get_set.rs index e6a8aa8..f7ebcaa 100644 --- a/examples/simple_get_set.rs +++ b/examples/simple_get_set.rs @@ -20,7 +20,9 @@ const KATANA_CHAIN_ID: &str = "0x4b4154414e41"; // Or you can use the extracted abi entries with jq in contracts/abi/. abigen!( MyContract, - "./contracts/target/dev/contracts_simple_get_set.contract_class.json" + "./contracts/target/dev/contracts_simple_get_set.contract_class.json", + derives(Debug, PartialEq), + contract_derives(Debug, Clone) ); //abigen!(MyContract, "./contracts/abi/simple_get_set.abi.json"); diff --git a/src/bin/cli/args.rs b/src/bin/cli/args.rs index adb4b25..086d3d7 100644 --- a/src/bin/cli/args.rs +++ b/src/bin/cli/args.rs @@ -67,6 +67,11 @@ pub struct CainomeArgs { #[arg(value_name = "DERIVES")] #[arg(help = "Derives to be added to the generated types.")] pub derives: Option>, + + #[arg(long)] + #[arg(value_name = "CONTRACT_DERIVES")] + #[arg(help = "Derives to be added to the generated contract.")] + pub contract_derives: Option>, } #[derive(Debug, Args, Clone)] diff --git a/src/bin/cli/main.rs b/src/bin/cli/main.rs index 7786c6d..8b77e2f 100644 --- a/src/bin/cli/main.rs +++ b/src/bin/cli/main.rs @@ -53,6 +53,7 @@ async fn main() -> CainomeCliResult<()> { contracts, execution_version: args.execution_version, derives: args.derives.unwrap_or_default(), + contract_derives: args.contract_derives.unwrap_or_default(), }) .await?; diff --git a/src/bin/cli/plugins/builtins/rust.rs b/src/bin/cli/plugins/builtins/rust.rs index ddbb0cd..f21a949 100644 --- a/src/bin/cli/plugins/builtins/rust.rs +++ b/src/bin/cli/plugins/builtins/rust.rs @@ -37,6 +37,7 @@ impl BuiltinPlugin for RustPlugin { &contract.tokens, input.execution_version, &input.derives, + &input.contract_derives, ); let filename = format!( "{}.rs", diff --git a/src/bin/cli/plugins/mod.rs b/src/bin/cli/plugins/mod.rs index 3c79e43..41e5a11 100644 --- a/src/bin/cli/plugins/mod.rs +++ b/src/bin/cli/plugins/mod.rs @@ -14,6 +14,7 @@ pub struct PluginInput { pub contracts: Vec, pub execution_version: ExecutionVersion, pub derives: Vec, + pub contract_derives: Vec, } #[derive(Debug)]