Skip to content

Commit

Permalink
Implement aux_data callback
Browse files Browse the repository at this point in the history
commit-id:5a9c140b
  • Loading branch information
maciektr committed Mar 1, 2024
1 parent 6010202 commit 92547c0
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 8 deletions.
22 changes: 22 additions & 0 deletions plugins/cairo-lang-macro-attributes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,25 @@ pub fn attribute_macro(_args: TokenStream, input: TokenStream) -> TokenStream {
};
TokenStream::from(expanded)
}

/// AuxData callback helper.
///
/// This macro hides the conversion to stable ABI structs from the user.
///
/// # Safety
/// Note that AuxData deserialization may fail.
#[proc_macro_attribute]
pub fn aux_data_callback(_args: TokenStream, input: TokenStream) -> TokenStream {
let item: ItemFn = parse_macro_input!(input as ItemFn);
let item_name = &item.sig.ident;
let expanded = quote! {
#item

#[no_mangle]
pub unsafe extern "C" fn aux_data_callback(aux_data: *mut cairo_lang_macro_stable::StableAuxData, aux_data_n: usize) {
let aux_data = cairo_lang_macro::AuxData::deallocate(aux_data, aux_data_n);
#item_name(aux_data);
}
};
TokenStream::from(expanded)
}
44 changes: 40 additions & 4 deletions plugins/cairo-lang-macro/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use anyhow::Result;
use std::ffi::{c_char, CString};
use std::fmt::Display;
use std::slice;
Expand Down Expand Up @@ -49,6 +50,12 @@ impl AuxData {
pub fn try_new<T: serde::Serialize>(value: T) -> Result<Self, serde_json::Error> {
Ok(Self(serde_json::to_string(&value)?))
}

pub fn from_aux_data<T: serde::de::DeserializeOwned>(
aux_data: Self,
) -> Result<T, serde_json::Error> {
serde_json::from_str(&aux_data.to_string())
}
}

impl Display for AuxData {
Expand Down Expand Up @@ -153,7 +160,7 @@ impl ProcMacroResult {
let diagnostics = Diagnostic::deallocate(diagnostics, diagnostics_n);
ProcMacroResult::Replace {
token_stream: TokenStream::from_stable(token_stream),
aux_data: AuxData::from_stable(aux_data).unwrap(),
aux_data: AuxData::from_stable(aux_data),
diagnostics,
}
}
Expand Down Expand Up @@ -205,12 +212,41 @@ impl AuxData {
///
/// # Safety
#[doc(hidden)]
pub unsafe fn from_stable(aux_data: StableAuxData) -> Result<Option<Self>, serde_json::Error> {
pub unsafe fn from_stable(aux_data: StableAuxData) -> Option<Self> {
match aux_data {
StableAuxData::None => Ok(None),
StableAuxData::Some(raw) => Some(Self::try_new(raw_to_string(raw))).transpose(),
StableAuxData::None => None,
StableAuxData::Some(raw) => Some(Self::new(raw_to_string(raw))),
}
}

/// Allocate array with FFI-safe aux_data.
///
/// # Safety
#[doc(hidden)]
pub unsafe fn allocate(aux_data: Vec<AuxData>) -> (*mut StableAuxData, usize) {
let stable_aux_data = aux_data
.into_iter()
.map(|a| a.into_stable())
.collect::<Vec<_>>();
let mut slice: Box<[StableAuxData]> = stable_aux_data.into_boxed_slice();
let ptr: *mut StableAuxData = slice.as_mut_ptr();
let len = slice.len();
assert!(!ptr.is_null(), "failed to allocate aux_data array");
std::mem::forget(slice);
(ptr, len)
}

/// Deallocate array of aux_data, returning a vector.
///
/// # Safety
pub unsafe fn deallocate(diagnostics: *mut StableAuxData, n: usize) -> Vec<AuxData> {
let aux_data = Box::from_raw(slice::from_raw_parts_mut(diagnostics, n));
aux_data
.into_vec()
.into_iter()
.filter_map(|a| unsafe { AuxData::from_stable(a) })
.collect()
}
}

impl Diagnostic {
Expand Down
25 changes: 22 additions & 3 deletions scarb/src/compiler/plugin/proc_macro/ffi.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use crate::core::{Config, Package, PackageId};
use anyhow::{Context, Result};
use cairo_lang_defs::patcher::PatchBuilder;
use cairo_lang_macro::{ProcMacroResult, TokenStream};
use cairo_lang_macro_stable::{StableProcMacroResult, StableTokenStream};
use cairo_lang_macro::{AuxData, ProcMacroResult, TokenStream};
use cairo_lang_macro_stable::{StableAuxData, StableProcMacroResult, StableTokenStream};
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::{ast, TypedSyntaxNode};
use camino::Utf8PathBuf;
use libloading::{Library, Symbol};
use std::fmt::Debug;

use crate::compiler::plugin::proc_macro::compilation::SharedLibraryProvider;
use crate::compiler::plugin::proc_macro::ProcMacroAuxData;
#[cfg(not(windows))]
use libloading::os::unix::Symbol as RawSymbol;
#[cfg(windows)]
Expand Down Expand Up @@ -70,12 +71,21 @@ impl ProcMacroInstance {
let result = (self.plugin.vtable.expand)(ffi_token_stream);
unsafe { ProcMacroResult::from_stable(result) }
}

pub(crate) fn aux_data_callback(&self, aux_data: Vec<ProcMacroAuxData>) {
let aux_data = aux_data.into_iter().map(Into::into).collect();
let (ptr, n) = unsafe { AuxData::allocate(aux_data) };

(self.plugin.vtable.aux_data_callback)(ptr, n);
}
}

type ExpandCode = extern "C" fn(StableTokenStream) -> StableProcMacroResult;
type AuxDataCallback = extern "C" fn(*mut StableAuxData, usize);

struct VTableV0 {
expand: RawSymbol<ExpandCode>,
aux_data_callback: RawSymbol<AuxDataCallback>,
}

impl VTableV0 {
Expand All @@ -84,7 +94,16 @@ impl VTableV0 {
.get(b"expand\0")
.context("failed to load expand function for procedural macro")?;
let expand = expand.into_raw();
Ok(VTableV0 { expand })

let aux_data_callback: Symbol<'_, AuxDataCallback> = library
.get(b"aux_data_callback\0")
.context("failed to load aux_data_callback function for procedural macro")?;
let aux_data_callback = aux_data_callback.into_raw();

Ok(VTableV0 {
expand,
aux_data_callback,
})
}
}

Expand Down
8 changes: 7 additions & 1 deletion scarb/src/compiler/plugin/proc_macro/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,13 @@ impl ProcMacroHostPlugin {
}
}
}
let _aux_data = data.into_iter().into_group_map_by(|d| d.macro_package_id);
let aux_data = data.into_iter().into_group_map_by(|d| d.macro_package_id);
for instance in self.macros.iter() {
let data = aux_data.get(&instance.package_id()).cloned();
if let Some(data) = data {
instance.aux_data_callback(data.clone());
}
}
Ok(())
}
}
Expand Down
77 changes: 77 additions & 0 deletions scarb/tests/build_cairo_plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ fn simple_project_with_code(t: &impl PathChild, code: impl ToString) {
[dependencies]
cairo-lang-macro = {{ path = {macro_lib_path}}}
cairo-lang-macro-stable = {{ path = {macro_stable_lib_path}}}
serde = "*"
serde_json = "*"
"#},
)
.build(t);
Expand Down Expand Up @@ -451,3 +453,78 @@ fn can_replace_original_node() {
Run completed successfully, returning [34]
"#});
}

#[test]
fn can_return_aux_data_from_plugin() {
let temp = TempDir::new().unwrap();
let t = temp.child("some");
simple_project_with_code(
&t,
indoc! {r##"
use cairo_lang_macro::{ProcMacroResult, TokenStream, attribute_macro, AuxData, aux_data_callback};
use serde::{Serialize, Deserialize};
#[derive(Debug, Serialize, Deserialize)]
struct SomeMacroDataFormat {
msg: String
}
#[attribute_macro]
pub fn some_macro(token_stream: TokenStream) -> ProcMacroResult {
let token_stream = TokenStream::new(
token_stream
.to_string()
// Remove macro call to avoid infinite loop.
.replace("#[some]", "")
.replace("12", "34")
);
let aux_data = AuxData::try_new(
SomeMacroDataFormat { msg: "Hello from some macro!".to_string() }
).unwrap();
ProcMacroResult::Replace {
token_stream,
aux_data: Some(aux_data),
diagnostics: Vec::new()
}
}
#[aux_data_callback]
pub fn callback(aux_data: Vec<AuxData>) {
let aux_data = aux_data.into_iter()
.map(AuxData::from_aux_data::<SomeMacroDataFormat>)
.collect::<Result<Vec<_>, serde_json::Error>>();
println!("{:?}", aux_data);
}
"##},
);

let project = temp.child("hello");
ProjectBuilder::start()
.name("hello")
.version("1.0.0")
.dep_starknet()
.dep("some", &t)
.lib_cairo(indoc! {r#"
#[some]
fn main() -> felt252 { 12 }
"#})
.build(&project);

Scarb::quick_snapbox()
.arg("cairo-run")
// Disable output from Cargo.
.env("CARGO_TERM_QUIET", "true")
.current_dir(&project)
.assert()
.success()
.stdout_matches(indoc! {r#"
[..]Compiling some v1.0.0 ([..]Scarb.toml)
[..]Compiling hello v1.0.0 ([..]Scarb.toml)
Ok([SomeMacroDataFormat { msg: "Hello from some macro!" }])
[..]Finished release target(s) in [..]
[..]Running hello
[..]Run completed successfully, returning [..]
"#});
}

0 comments on commit 92547c0

Please sign in to comment.