From 9ca93cc2f8a4448702bfa8ad8f9242dad57f0b0d Mon Sep 17 00:00:00 2001 From: Antonio Valentino Date: Thu, 28 Dec 2023 17:29:07 +0100 Subject: [PATCH] Hooks for code generator --- bpack/tools/codegen.py | 87 +++++++++++++++++++++++++++--------------- 1 file changed, 57 insertions(+), 30 deletions(-) diff --git a/bpack/tools/codegen.py b/bpack/tools/codegen.py index 36e4992..75453a4 100644 --- a/bpack/tools/codegen.py +++ b/bpack/tools/codegen.py @@ -99,8 +99,11 @@ class name to be used for the generated flat record descriptor """ self._indent = " " * indent if isinstance(indent, int) else indent self._filed_names = set() - self._imports: dict = collections.defaultdict(set) - self._imports[None].add("bpack") + self.imports: dict = collections.defaultdict(set) + self.imports[None].add("bpack") + self.module_docstring = None + self.pre_code = None + self.post_code = None self._lines = [] self._setup_class_declaration(descriptor, name) @@ -113,17 +116,18 @@ def _setup_class_declaration(self, descriptor, name: Optional[str] = None): backend = get_codec_type(descriptor).__module__ codec_type = "codec" if has_codec(descriptor, Codec) else "decoder" self._lines.append(f"@{backend}.{codec_type}") + self.imports[None].add(backend) descriptor_args = [] baseunits = bpack.baseunits(descriptor).name descriptor_args.append(f"baseunits=EBaseUnits.{baseunits}") - self._imports["bpack"].add("EBaseUnits") + self.imports["bpack"].add("EBaseUnits") byteorder = bpack.byteorder(descriptor).name if bpack.byteorder(descriptor) != bpack.EByteOrder.DEFAULT: descriptor_args.append(f"byteorder=EByteOrder.{byteorder}") - self._imports["bpack"].add("EByteOrder") + self.imports["bpack"].add("EByteOrder") # TODO: bitorder @@ -144,12 +148,12 @@ def _setup_fields(self, descriptor): for fld in flat_fields_iterator(descriptor): if bpack.typing.is_annotated(fld.type): typestr = f'T["{annotated_to_str(fld.type)}"]' - self._imports["bpack"].add("T") + self.imports["bpack"].add("T") elif fld.type is bool: typestr = "bool" elif issubclass(fld.type, enum.Enum): typestr = fld.type.__name__ - self._imports[fld.type.__module__].add(typestr) + self.imports[fld.type.__module__].add(typestr) else: raise TypeError(f"unsupported field type: {fld.type!r}") @@ -166,13 +170,13 @@ def _setup_fields(self, descriptor): if fld.default_factory is not MISSING: default_str = f"default_factory={fld.default_factory}" module = fld.default_factory.__module__ - self._imports[module].add(fld.default_factory.__name__) + self.imports[module].add(fld.default_factory.__name__) elif fld.default is not MISSING: default_str = f"default={get_default_str(fld.default)}" if hasattr(fld.default, "__class__"): module = fld.default.__class__.__module__ name = fld.default.__class__.__name__ - self._imports[module].add(name) + self.imports[module].add(name) else: default_str = "" @@ -191,7 +195,7 @@ def _setup_fields(self, descriptor): if hasattr(fld.default, "__class__"): module = fld.default.__class__.__module__ name = fld.default.__class__.__name__ - self._imports[module].add(name) + self.imports[module].add(name) else: field_str = "" @@ -230,7 +234,7 @@ def _setup_methods(self, descriptor): for var in annotations.values(): module = inspect.getmodule(var) if module is not None: - self._imports[module.__name__].add(var.__name__) + self.imports[module.__name__].add(var.__name__) else: warnings.warn( f"(annotations) unable to determine the module of " @@ -246,16 +250,16 @@ def _setup_methods(self, descriptor): ) except AttributeError: module = descriptor.__module__ - self._imports[module].add(name) + self.imports[module].add(name) else: module = inspect.getmodule(var) if module is not None: if module.__name__ == name: - self._imports[None].add(name) + self.imports[None].add(name) else: - self._imports[module.__name__].add(name) + self.imports[module.__name__].add(name) elif name in global_ctx: - self._imports[descriptor.__module__].add(name) + self.imports[descriptor.__module__].add(name) elif name not in local_ctx: warnings.warn( f"({var_type}) unable to determine the " @@ -263,13 +267,13 @@ def _setup_methods(self, descriptor): ) def _get_classified_imports(self): - if self._imports.get(None): - stdlib, others = _classify_modules(self._imports[None]) + if self.imports.get(None): + stdlib, others = _classify_modules(self.imports[None]) else: stdlib = others = None from_stdlib, from_others = _classify_modules( - key for key in self._imports if key and key != "builtins" + key for key in self.imports if key and key != "builtins" ) return stdlib, others, from_stdlib, from_others @@ -285,7 +289,7 @@ def _get_import_lines( ) if from_stdlib: for key in sorted(sorted(from_stdlib), key=len): - values = sorted(self._imports[key]) + values = sorted(self.imports[key]) s = f"from {key} import {', '.join(values)}" if len(s) < line_length: import_lines.append(s) @@ -305,7 +309,7 @@ def _get_import_lines( ) if from_others: for key in sorted(sorted(from_others), key=len): - values = sorted(self._imports[key]) + values = sorted(self.imports[key]) s = f"from {key} import {', '.join(values)}" if len(s) < line_length: import_lines.append(s) @@ -318,7 +322,16 @@ def _get_import_lines( return import_lines - def get_code(self, imports: bool = False, line_length: int = 80): + def patch(self, code): + """Patch the generated code.""" + return code + + def get_code( + self, + imports: bool = False, + line_length: int = 80, + beautify: bool = True, + ): """Generate the Python source code. By default only the code for the binary record descriptor is generated. @@ -326,6 +339,10 @@ def get_code(self, imports: bool = False, line_length: int = 80): types used by the descriptor are included in the generated code. """ lines = [] + + if self.module_docstring: + lines.extend(f'"""{self.module_docstring}"""'.splitlines()) + lines.append("") if imports: import_lines = self._get_import_lines( *self._get_classified_imports(), line_length=line_length @@ -333,19 +350,29 @@ def get_code(self, imports: bool = False, line_length: int = 80): lines.extend(import_lines) lines.append("") lines.append("") + + if self.pre_code: + lines.extend(self.pre_code.splitlines()) + lines.extend(self._lines) + if self.post_code: + lines.extend(self.post_code.splitlines()) + code = "\n".join(lines) - try: - import black - except ImportError: - pass - else: - mode = black.Mode( - target_versions={black.TargetVersion.PY311}, - line_length=line_length, - ) - code = black.format_str(code, mode=mode) + code = self.patch(code) + + if beautify: + try: + import black + except ImportError: + pass + else: + mode = black.Mode( + target_versions={black.TargetVersion.PY311}, + line_length=line_length, + ) + code = black.format_str(code, mode=mode) return code