diff --git a/core/src/language/python.rs b/core/src/language/python.rs index 041a2bbe..cf8be267 100644 --- a/core/src/language/python.rs +++ b/core/src/language/python.rs @@ -1,11 +1,11 @@ use crate::parser::ParsedData; -use crate::rust_types::{RustType, RustTypeFormatError, SpecialRustType}; +use crate::rust_types::{RustItem, RustType, RustTypeFormatError, SpecialRustType}; +use crate::topsort::topsort; use crate::{ language::Language, rust_types::{RustEnum, RustEnumVariant, RustField, RustStruct, RustTypeAlias}, }; use once_cell::sync::Lazy; -use std::collections::hash_map::Entry; use std::collections::HashSet; use std::hash::Hash; use std::{collections::HashMap, io::Write}; @@ -13,95 +13,6 @@ use std::{collections::HashMap, io::Write}; use super::CrateTypes; use convert_case::{Case, Casing}; -use topological_sort::TopologicalSort; - -#[derive(Debug, Default)] -pub struct Module { - // HashMap - imports: HashMap>, - // HashMap> - // Used to lay out runtime references in the module - // such that it can be read top to bottom - globals: HashMap>, - type_variables: HashSet, -} - -#[derive(Debug)] -struct GenerationError; - -impl Module { - // Idempotently insert an import - fn add_import(&mut self, module: String, identifier: String) { - self.imports.entry(module).or_default().insert(identifier); - } - fn add_global(&mut self, identifier: String, deps: Vec) { - match self.globals.entry(identifier) { - Entry::Occupied(mut e) => e.get_mut().extend_from_slice(&deps), - Entry::Vacant(e) => { - e.insert(deps); - } - } - } - fn add_type_var(&mut self, name: String) { - self.add_import("typing".to_string(), "TypeVar".to_string()); - self.type_variables.insert(name); - } - fn get_type_vars(&mut self, n: usize) -> Vec { - let vars: Vec = (0..n) - .map(|i| { - if i == 0 { - "T".to_string() - } else { - format!("T{}", i) - } - }) - .collect(); - vars.iter().for_each(|tv| self.add_type_var(tv.clone())); - vars - } - // Rust lets you declare type aliases before the struct they point to. - // But in Python we need the struct to come first. - // So we need to topologically sort the globals so that type aliases - // always come _after_ the struct/enum they point to. - fn topologically_sorted_globals(&self) -> Result, GenerationError> { - let mut ts: TopologicalSort = TopologicalSort::new(); - for (identifier, dependencies) in &self.globals { - for dependency in dependencies { - ts.add_dependency(dependency.clone(), identifier.clone()) - } - } - let mut res: Vec = Vec::new(); - loop { - let mut level = ts.pop_all(); - level.sort(); - res.extend_from_slice(&level); - if level.is_empty() { - if !ts.is_empty() { - return Err(GenerationError); - } - break; - } - } - let existing: HashSet<&String> = HashSet::from_iter(res.iter()); - let mut missing: Vec = self - .globals - .keys() - .filter(|&k| !existing.contains(k)) - .cloned() - .collect(); - missing.sort(); - res.extend(missing); - Ok(res) - } -} - -#[derive(Debug, Clone)] -enum ParsedRustThing<'a> { - Struct(&'a RustStruct), - Enum(&'a RustEnum), - TypeAlias(&'a RustTypeAlias), -} - // Collect unique type vars from an enum field // Since we explode enums into unions of types, we need to extract all of the generics // used by each individual field @@ -155,8 +66,13 @@ fn dedup(v: &mut Vec) { pub struct Python { /// Mappings from Rust type names to Python type names pub type_mappings: HashMap, - /// The Python module for the generated code. - pub module: Module, + // HashMap + pub imports: HashMap>, + // HashMap> + // Used to lay out runtime references in the module + // such that it can be read top to bottom + // globals: HashMap>, + pub type_variables: HashSet, } impl Language for Python { @@ -169,61 +85,36 @@ impl Language for Python { _imports: &CrateTypes, data: ParsedData, ) -> std::io::Result<()> { - let mut globals: Vec; - { - for alias in &data.aliases { - let thing = ParsedRustThing::TypeAlias(alias); - let identifier = self.get_identifier(thing); - match &alias.r#type { - RustType::Generic { id, parameters: _ } => { - self.module.add_global(identifier, vec![id.clone()]) - } - RustType::Simple { id } => self.module.add_global(identifier, vec![id.clone()]), - RustType::Special(_) => {} - } - } - for strct in &data.structs { - let thing = ParsedRustThing::Struct(strct); - let identifier = self.get_identifier(thing); - self.module.add_global(identifier, vec![]); - } - for enm in &data.enums { - let thing = ParsedRustThing::Enum(enm); - let identifier = self.get_identifier(thing); - self.module.add_global(identifier, vec![]); - } - globals = data - .aliases - .iter() - .map(ParsedRustThing::TypeAlias) - .chain(data.structs.iter().map(ParsedRustThing::Struct)) - .chain(data.enums.iter().map(ParsedRustThing::Enum)) - .collect(); - let sorted_identifiers = self.module.topologically_sorted_globals().unwrap(); - globals.sort_by(|a, b| { - let identifier_a = self.get_identifier(a.clone()); - let identifier_b = self.get_identifier(b.clone()); - let pos_a = sorted_identifiers - .iter() - .position(|o| o.eq(&identifier_a)) - .unwrap_or(0); - let pos_b = sorted_identifiers - .iter() - .position(|o| o.eq(&identifier_b)) - .unwrap_or(0); - pos_a.cmp(&pos_b) - }); - } + self.begin_file(w, &data)?; + + let ParsedData { + structs, + enums, + aliases, + .. + } = data; + + let mut items = aliases + .into_iter() + .map(RustItem::Alias) + .chain(structs.into_iter().map(RustItem::Struct)) + .chain(enums.into_iter().map(RustItem::Enum)) + .collect::>(); + + topsort(&mut items); + let mut body: Vec = Vec::new(); - for thing in globals { + for thing in items { match thing { - ParsedRustThing::Enum(e) => self.write_enum(&mut body, e)?, - ParsedRustThing::Struct(rs) => self.write_struct(&mut body, rs)?, - ParsedRustThing::TypeAlias(t) => self.write_type_alias(&mut body, t)?, + RustItem::Enum(e) => self.write_enum(&mut body, &e)?, + RustItem::Struct(rs) => self.write_struct(&mut body, &rs)?, + RustItem::Alias(t) => self.write_type_alias(&mut body, &t)?, }; } - self.begin_file(w, &data)?; - let _ = w.write(&body)?; + + self.write_all_imports(w)?; + + w.write_all(&body)?; Ok(()) } @@ -273,15 +164,13 @@ impl Language for Python { SpecialRustType::Vec(rtype) | SpecialRustType::Array(rtype, _) | SpecialRustType::Slice(rtype) => { - self.module - .add_import("typing".to_string(), "List".to_string()); + self.add_import("typing".to_string(), "List".to_string()); Ok(format!("List[{}]", self.format_type(rtype, generic_types)?)) } // We add optionality above the type formatting level SpecialRustType::Option(rtype) => self.format_type(rtype, generic_types), SpecialRustType::HashMap(rtype1, rtype2) => { - self.module - .add_import("typing".to_string(), "Dict".to_string()); + self.add_import("typing".to_string(), "Dict".to_string()); Ok(format!( "Dict[{}, {}]", match rtype1.as_ref() { @@ -313,32 +202,9 @@ impl Language for Python { } fn begin_file(&mut self, w: &mut dyn Write, _parsed_data: &ParsedData) -> std::io::Result<()> { - let mut type_var_names: Vec = self.module.type_variables.iter().cloned().collect(); - type_var_names.sort(); - let type_vars: Vec = type_var_names - .iter() - .map(|name| format!("{} = TypeVar(\"{}\")", name, name)) - .collect(); - let mut imports = vec![]; - for (import_module, identifiers) in &self.module.imports { - let mut identifier_vec = identifiers.iter().cloned().collect::>(); - identifier_vec.sort(); - imports.push(format!( - "from {} import {}", - import_module, - identifier_vec.join(", ") - )) - } - imports.sort(); writeln!(w, "\"\"\"")?; writeln!(w, " Generated by typeshare {}", env!("CARGO_PKG_VERSION"))?; writeln!(w, "\"\"\"")?; - writeln!(w, "from __future__ import annotations\n").unwrap(); - writeln!(w, "{}\n", imports.join("\n"))?; - match type_vars.is_empty() { - true => writeln!(w).unwrap(), - false => writeln!(w, "{}\n\n", type_vars.join("\n")).unwrap(), - }; Ok(()) } @@ -367,15 +233,13 @@ impl Language for Python { rs.generic_types .iter() .cloned() - .for_each(|v| self.module.add_type_var(v)) + .for_each(|v| self.add_type_var(v)) } let bases = match rs.generic_types.is_empty() { true => "BaseModel".to_string(), false => { - self.module - .add_import("pydantic.generics".to_string(), "GenericModel".to_string()); - self.module - .add_import("typing".to_string(), "Generic".to_string()); + self.add_import("pydantic.generics".to_string(), "GenericModel".to_string()); + self.add_import("typing".to_string(), "Generic".to_string()); format!("GenericModel, Generic[{}]", rs.generic_types.join(", ")) } }; @@ -383,7 +247,7 @@ impl Language for Python { self.write_comments(w, true, &rs.comments, 1)?; - handle_model_config(w, &mut self.module, rs); + handle_model_config(w, self, rs); rs.fields .iter() @@ -393,8 +257,7 @@ impl Language for Python { write!(w, " pass")? } write!(w, "\n\n")?; - self.module - .add_import("pydantic".to_string(), "BaseModel".to_string()); + self.add_import("pydantic".to_string(), "BaseModel".to_string()); Ok(()) } @@ -410,8 +273,7 @@ impl Language for Python { // Write all the unit variants out (there can only be unit variants in // this case) RustEnum::Unit(shared) => { - self.module - .add_import("typing".to_string(), "Literal".to_string()); + self.add_import("typing".to_string(), "Literal".to_string()); write!( w, "{} = Literal[{}]", @@ -431,7 +293,7 @@ impl Language for Python { .collect::>() .join(", ") )?; - write!(w, "\n\n").unwrap(); + write!(w, "\n\n")?; } // Write all the algebraic variants out (all three variant types are possible // here) @@ -446,14 +308,13 @@ impl Language for Python { .generic_types .iter() .cloned() - .for_each(|v| self.module.add_type_var(v)) + .for_each(|v| self.add_type_var(v)) } let mut variants: Vec<(String, Vec)> = Vec::new(); shared.variants.iter().for_each(|variant| { match variant { RustEnumVariant::Unit(unit_variant) => { - self.module - .add_import("typing".to_string(), "Literal".to_string()); + self.add_import("typing".to_string(), "Literal".to_string()); let variant_name = format!("{}{}", shared.id.original, unit_variant.id.original); variants.push((variant_name.clone(), vec![])); @@ -469,8 +330,7 @@ impl Language for Python { ty, shared: variant_shared, } => { - self.module - .add_import("typing".to_string(), "Literal".to_string()); + self.add_import("typing".to_string(), "Literal".to_string()); let variant_name = format!("{}{}", shared.id.original, variant_shared.id.original); match ty { @@ -483,23 +343,22 @@ impl Language for Python { }) .collect(); dedup(&mut generic_parameters); - let type_vars = - self.module.get_type_vars(generic_parameters.len()); + let type_vars = self.get_type_vars(generic_parameters.len()); variants.push((variant_name.clone(), type_vars)); { if generic_parameters.is_empty() { - self.module.add_import( + self.add_import( "pydantic".to_string(), "BaseModel".to_string(), ); writeln!(w, "class {}(BaseModel):", variant_name) .unwrap(); } else { - self.module.add_import( + self.add_import( "typing".to_string(), "Generic".to_string(), ); - self.module.add_import( + self.add_import( "pydantic.generics".to_string(), "GenericModel".to_string(), ); @@ -525,18 +384,18 @@ impl Language for Python { variants.push((variant_name.clone(), generics.clone())); { if generics.is_empty() { - self.module.add_import( + self.add_import( "pydantic".to_string(), "BaseModel".to_string(), ); writeln!(w, "class {}(BaseModel):", variant_name) .unwrap(); } else { - self.module.add_import( + self.add_import( "typing".to_string(), "Generic".to_string(), ); - self.module.add_import( + self.add_import( "pydantic.generics".to_string(), "GenericModel".to_string(), ); @@ -585,7 +444,7 @@ impl Language for Python { collect_generics_for_variant(&f.ty, &shared.generic_types) }) .count(); - let type_vars = self.module.get_type_vars(num_generic_parameters); + let type_vars = self.get_type_vars(num_generic_parameters); let name = make_anonymous_struct_name(&variant_shared.id.original); variants.push((name, type_vars)); } @@ -603,10 +462,9 @@ impl Language for Python { }) .collect::>() .join(" | ") - ) - .unwrap(); + )?; self.write_comments(w, true, &e.shared().comments, 0)?; - writeln!(w).unwrap(); + writeln!(w)?; } }; Ok(()) @@ -625,32 +483,15 @@ impl Python { fn add_imports(&mut self, tp: &str) { match tp { "Url" => { - self.module - .add_import("pydantic.networks".to_string(), "AnyUrl".to_string()); + self.add_import("pydantic.networks".to_string(), "AnyUrl".to_string()); } "DateTime" => { - self.module - .add_import("datetime".to_string(), "datetime".to_string()); + self.add_import("datetime".to_string(), "datetime".to_string()); } _ => {} } } - fn get_identifier(&self, thing: ParsedRustThing) -> String { - match thing { - ParsedRustThing::TypeAlias(alias) => alias.id.original.clone(), - ParsedRustThing::Struct(strct) => strct.id.original.clone(), - ParsedRustThing::Enum(enm) => match enm { - RustEnum::Unit(u) => u.id.original.clone(), - RustEnum::Algebraic { - tag_key: _, - content_key: _, - shared, - } => shared.id.original.clone(), - }, - } - } - fn write_field( &mut self, w: &mut dyn Write, @@ -663,16 +504,13 @@ impl Python { let python_field_name = python_property_aware_rename(&field.id.original); if field.ty.is_optional() { python_type = format!("Optional[{}]", python_type); - self.module - .add_import("typing".to_string(), "Optional".to_string()); + self.add_import("typing".to_string(), "Optional".to_string()); } python_type = match python_field_name == field.id.renamed { true => python_type, false => { - self.module - .add_import("typing".to_string(), "Annotated".to_string()); - self.module - .add_import("pydantic".to_string(), "Field".to_string()); + self.add_import("typing".to_string(), "Annotated".to_string()); + self.add_import("pydantic".to_string(), "Field".to_string()); format!( "Annotated[{}, Field(alias=\"{}\")]", python_type, field.id.renamed @@ -730,6 +568,59 @@ impl Python { } Ok(()) } + + // Idempotently insert an import + fn add_import(&mut self, module: String, identifier: String) { + self.imports.entry(module).or_default().insert(identifier); + } + + fn add_type_var(&mut self, name: String) { + self.add_import("typing".to_string(), "TypeVar".to_string()); + self.type_variables.insert(name); + } + + fn get_type_vars(&mut self, n: usize) -> Vec { + let vars: Vec = (0..n) + .map(|i| { + if i == 0 { + "T".to_string() + } else { + format!("T{}", i) + } + }) + .collect(); + vars.iter().for_each(|tv| self.add_type_var(tv.clone())); + vars + } + + fn write_all_imports(&self, w: &mut dyn Write) -> std::io::Result<()> { + let mut type_var_names: Vec = self.type_variables.iter().cloned().collect(); + type_var_names.sort(); + let type_vars: Vec = type_var_names + .iter() + .map(|name| format!("{} = TypeVar(\"{}\")", name, name)) + .collect(); + let mut imports = vec![]; + for (import_module, identifiers) in &self.imports { + let mut identifier_vec = identifiers.iter().cloned().collect::>(); + identifier_vec.sort(); + imports.push(format!( + "from {} import {}", + import_module, + identifier_vec.join(", ") + )) + } + imports.sort(); + + writeln!(w, "from __future__ import annotations\n")?; + writeln!(w, "{}\n", imports.join("\n"))?; + + match type_vars.is_empty() { + true => writeln!(w)?, + false => writeln!(w, "{}\n\n", type_vars.join("\n"))?, + }; + Ok(()) + } } static PYTHON_KEYWORDS: Lazy> = Lazy::new(|| { @@ -754,7 +645,7 @@ fn python_property_aware_rename(name: &str) -> String { } // If at least one field from within a class is changed when the serde rename is used (a.k.a the field has 2 words) then we must use aliasing and we must also use a config dict at the top level of the class. -fn handle_model_config(w: &mut dyn Write, python_module: &mut Module, rs: &RustStruct) { +fn handle_model_config(w: &mut dyn Write, python_module: &mut Python, rs: &RustStruct) { let visibly_renamed_field = rs.fields.iter().find(|f| { let python_field_name = python_property_aware_rename(&f.id.original); python_field_name != f.id.renamed