Skip to content

Commit

Permalink
Add collect aux data mechanism
Browse files Browse the repository at this point in the history
commit-id:af02bb94
  • Loading branch information
maciektr committed Mar 6, 2024
1 parent d877a2f commit 313f772
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 38 deletions.
23 changes: 16 additions & 7 deletions scarb/src/compiler/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,41 @@ use smol_str::SmolStr;
use std::sync::Arc;
use tracing::trace;

use crate::compiler::plugin::proc_macro::ProcMacroHost;
use crate::compiler::plugin::proc_macro::{ProcMacroHost, ProcMacroHostPlugin};
use crate::compiler::{CairoCompilationUnit, CompilationUnitAttributes, CompilationUnitComponent};
use crate::core::Workspace;
use crate::DEFAULT_MODULE_MAIN_FILE;

pub struct BuiltScarbDatabase {
pub db: RootDatabase,
pub proc_macro_host: Arc<ProcMacroHostPlugin>,
}

// TODO(mkaput): ScarbDatabase?
pub(crate) fn build_scarb_root_database(
unit: &CairoCompilationUnit,
ws: &Workspace<'_>,
) -> Result<RootDatabase> {
) -> Result<BuiltScarbDatabase> {
let mut b = RootDatabase::builder();
b.with_project_config(build_project_config(unit)?);
b.with_cfg(unit.cfg_set.clone());
load_plugins(unit, ws, &mut b)?;
let proc_macro_host = load_plugins(unit, ws, &mut b)?;
if !unit.compiler_config.enable_gas {
b.skip_auto_withdraw_gas();
}
let mut db = b.build()?;
inject_virtual_wrapper_lib(&mut db, unit)?;
Ok(db)
Ok(BuiltScarbDatabase {
db,
proc_macro_host,
})
}

fn load_plugins(
unit: &CairoCompilationUnit,
ws: &Workspace<'_>,
builder: &mut RootDatabaseBuilder,
) -> Result<()> {
) -> Result<Arc<ProcMacroHostPlugin>> {
let mut proc_macros = ProcMacroHost::default();
for plugin_info in &unit.cairo_plugins {
if plugin_info.builtin {
Expand All @@ -49,8 +57,9 @@ fn load_plugins(
proc_macros.register(plugin_info.package.clone(), ws.config())?;
}
}
builder.with_plugin_suite(proc_macros.into_plugin_suite());
Ok(())
let macro_host = Arc::new(proc_macros.into_plugin());
builder.with_plugin_suite(ProcMacroHostPlugin::build_plugin_suite(macro_host.clone()));
Ok(macro_host)
}

/// Generates a wrapper lib file for appropriate compilation units.
Expand Down
72 changes: 60 additions & 12 deletions scarb/src/compiler/plugin/proc_macro/host.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::compiler::plugin::proc_macro::{FromItemAst, ProcMacroInstance};
use crate::core::{Config, Package, PackageId};
use anyhow::Result;
use cairo_lang_defs::db::DefsGroup;
use cairo_lang_defs::plugin::PluginDiagnostic;
use cairo_lang_defs::plugin::{
DynGeneratedFileAuxData, GeneratedFileAuxData, MacroPlugin, MacroPluginMetadata,
Expand Down Expand Up @@ -47,11 +48,25 @@ pub struct ProcMacroInput {
}

#[derive(Clone, Debug)]
pub struct ProcMacroAuxData(String);
pub struct ProcMacroAuxData {
value: String,
macro_id: ProcMacroId,
macro_package_id: PackageId,
}

impl ProcMacroAuxData {
pub fn new(value: String, macro_id: ProcMacroId, macro_package_id: PackageId) -> Self {
Self {
value,
macro_id,
macro_package_id,
}
}
}

impl From<ProcMacroAuxData> for AuxData {
fn from(data: ProcMacroAuxData) -> Self {
Self::new(data.0)
Self::new(data.value)
}
}

Expand All @@ -61,7 +76,8 @@ impl GeneratedFileAuxData for ProcMacroAuxData {
}

fn eq(&self, other: &dyn GeneratedFileAuxData) -> bool {
self.0 == other.as_any().downcast_ref::<Self>().unwrap().0
self.value == other.as_any().downcast_ref::<Self>().unwrap().value
&& self.macro_id == other.as_any().downcast_ref::<Self>().unwrap().macro_id
}
}

Expand Down Expand Up @@ -128,6 +144,36 @@ impl ProcMacroHostPlugin {
.find(|m| m.declared_attributes().contains(&name))
.map(|m| m.package_id())
}

pub fn build_plugin_suite(macr_host: Arc<Self>) -> PluginSuite {
let mut suite = PluginSuite::default();
suite.add_plugin_ex(macr_host);
suite
}

#[tracing::instrument(level = "trace", skip_all)]
pub fn collect_aux_data(&self, db: &dyn DefsGroup) -> Result<()> {
let mut data = Vec::new();
for crate_id in db.crates() {
let crate_modules = db.crate_modules(crate_id);
for module in crate_modules.iter() {
let file_infos = db.module_generated_file_infos(*module);
if let Ok(file_infos) = file_infos {
for file_info in file_infos.iter().flatten() {
let aux_data = file_info
.aux_data
.as_ref()
.and_then(|ad| ad.as_any().downcast_ref::<ProcMacroAuxData>());
if let Some(aux_data) = aux_data {
data.push(aux_data.clone());
}
}
}
}
}
let _aux_data = data.into_iter().into_group_map_by(|d| d.macro_package_id);
Ok(())
}
}

impl MacroPlugin for ProcMacroHostPlugin {
Expand All @@ -146,7 +192,7 @@ impl MacroPlugin for ProcMacroHostPlugin {
let stable_ptr = item_ast.clone().stable_ptr().untyped();

let mut token_stream = TokenStream::from_item_ast(db, item_ast);
let mut aux_data: Option<AuxData> = None;
let mut aux_data: Option<ProcMacroAuxData> = None;
let mut modified = false;
let mut all_diagnostics: Vec<Diagnostic> = Vec::new();
for input in expansions {
Expand All @@ -162,7 +208,13 @@ impl MacroPlugin for ProcMacroHostPlugin {
diagnostics,
} => {
token_stream = new_token_stream;
aux_data = new_aux_data;
if let Some(new_aux_data) = new_aux_data {
aux_data = Some(ProcMacroAuxData::new(
new_aux_data.to_string(),
input.id,
input.macro_package_id,
));
}
modified = true;
all_diagnostics.extend(diagnostics);
}
Expand All @@ -185,8 +237,7 @@ impl MacroPlugin for ProcMacroHostPlugin {
name: "proc_macro".into(),
content: token_stream.to_string(),
code_mappings: Default::default(),
aux_data: aux_data
.map(|ad| DynGeneratedFileAuxData::new(ProcMacroAuxData(ad.to_string()))),
aux_data: aux_data.map(DynGeneratedFileAuxData::new),
}),
diagnostics: into_cairo_diagnostics(all_diagnostics, stable_ptr),
remove_original_item: true,
Expand Down Expand Up @@ -241,10 +292,7 @@ impl ProcMacroHost {
Ok(())
}

pub fn into_plugin_suite(self) -> PluginSuite {
let macro_host = ProcMacroHostPlugin::new(self.macros);
let mut suite = PluginSuite::default();
suite.add_plugin_ex(Arc::new(macro_host));
suite
pub fn into_plugin(self) -> ProcMacroHostPlugin {
ProcMacroHostPlugin::new(self.macros)
}
}
47 changes: 28 additions & 19 deletions scarb/src/ops/compile.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use anyhow::{anyhow, Result};
use cairo_lang_compiler::db::RootDatabase;
use cairo_lang_compiler::diagnostics::DiagnosticsError;
use cairo_lang_utils::Upcast;
use indoc::formatdoc;
use itertools::Itertools;

use scarb_ui::components::Status;
use scarb_ui::HumanDuration;

use crate::compiler::db::{build_scarb_root_database, has_starknet_plugin};
use crate::compiler::db::{build_scarb_root_database, has_starknet_plugin, BuiltScarbDatabase};
use crate::compiler::helpers::build_compiler_config;
use crate::compiler::plugin::proc_macro;
use crate::compiler::{CairoCompilationUnit, CompilationUnit, CompilationUnitAttributes};
Expand Down Expand Up @@ -104,9 +105,14 @@ fn compile_unit(unit: CompilationUnit, ws: &Workspace<'_>) -> Result<()> {
let result = match unit {
CompilationUnit::ProcMacro(unit) => proc_macro::compile_unit(unit, ws),
CompilationUnit::Cairo(unit) => {
let mut db = build_scarb_root_database(&unit, ws)?;
let BuiltScarbDatabase {
mut db,
proc_macro_host,
} = build_scarb_root_database(&unit, ws)?;
check_starknet_dependency(&unit, ws, &db, &package_name);
ws.config().compilers().compile(unit, &mut db, ws)
let result = ws.config().compilers().compile(unit, &mut db, ws);
proc_macro_host.collect_aux_data(db.upcast())?;
result
}
};

Expand All @@ -126,28 +132,31 @@ fn check_unit(unit: CompilationUnit, ws: &Workspace<'_>) -> Result<()> {
.ui()
.print(Status::new("Checking", &unit.name()));

match unit {
CompilationUnit::ProcMacro(unit) => proc_macro::check_unit(unit, ws)?,
let result = match unit {
CompilationUnit::ProcMacro(unit) => proc_macro::check_unit(unit, ws),
CompilationUnit::Cairo(unit) => {
let db = build_scarb_root_database(&unit, ws)?;

let BuiltScarbDatabase {
db,
proc_macro_host,
} = build_scarb_root_database(&unit, ws)?;
check_starknet_dependency(&unit, ws, &db, &package_name);

let mut compiler_config = build_compiler_config(&unit, ws);

compiler_config
let result = compiler_config
.diagnostics_reporter
.ensure(&db)
.map_err(|err| {
let valid_error = err.into();
if !suppress_error(&valid_error) {
ws.config().ui().anyhow(&valid_error);
}

anyhow!("could not check `{package_name}` due to previous error")
})?;
.map_err(|err| err.into());
proc_macro_host.collect_aux_data(db.upcast())?;
result
}
}
};

result.map_err(|err| {
if !suppress_error(&err) {
ws.config().ui().anyhow(&err);
}

anyhow!("could not check `{package_name}` due to previous error")
})?;

Ok(())
}
Expand Down

0 comments on commit 313f772

Please sign in to comment.