Skip to content

Commit

Permalink
Hooks for code generator
Browse files Browse the repository at this point in the history
  • Loading branch information
avalentino committed Dec 28, 2023
1 parent 6694606 commit 9ca93cc
Showing 1 changed file with 57 additions and 30 deletions.
87 changes: 57 additions & 30 deletions bpack/tools/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,11 @@ class name to be used for the generated flat record descriptor
"""
self._indent = " " * indent if isinstance(indent, int) else indent
self._filed_names = set()
self._imports: dict = collections.defaultdict(set)
self._imports[None].add("bpack")
self.imports: dict = collections.defaultdict(set)
self.imports[None].add("bpack")
self.module_docstring = None
self.pre_code = None
self.post_code = None
self._lines = []

self._setup_class_declaration(descriptor, name)
Expand All @@ -113,17 +116,18 @@ def _setup_class_declaration(self, descriptor, name: Optional[str] = None):
backend = get_codec_type(descriptor).__module__
codec_type = "codec" if has_codec(descriptor, Codec) else "decoder"
self._lines.append(f"@{backend}.{codec_type}")
self.imports[None].add(backend)

descriptor_args = []

baseunits = bpack.baseunits(descriptor).name
descriptor_args.append(f"baseunits=EBaseUnits.{baseunits}")
self._imports["bpack"].add("EBaseUnits")
self.imports["bpack"].add("EBaseUnits")

byteorder = bpack.byteorder(descriptor).name
if bpack.byteorder(descriptor) != bpack.EByteOrder.DEFAULT:
descriptor_args.append(f"byteorder=EByteOrder.{byteorder}")
self._imports["bpack"].add("EByteOrder")
self.imports["bpack"].add("EByteOrder")

# TODO: bitorder

Expand All @@ -144,12 +148,12 @@ def _setup_fields(self, descriptor):
for fld in flat_fields_iterator(descriptor):
if bpack.typing.is_annotated(fld.type):
typestr = f'T["{annotated_to_str(fld.type)}"]'
self._imports["bpack"].add("T")
self.imports["bpack"].add("T")
elif fld.type is bool:
typestr = "bool"
elif issubclass(fld.type, enum.Enum):
typestr = fld.type.__name__
self._imports[fld.type.__module__].add(typestr)
self.imports[fld.type.__module__].add(typestr)
else:
raise TypeError(f"unsupported field type: {fld.type!r}")

Expand All @@ -166,13 +170,13 @@ def _setup_fields(self, descriptor):
if fld.default_factory is not MISSING:
default_str = f"default_factory={fld.default_factory}"
module = fld.default_factory.__module__
self._imports[module].add(fld.default_factory.__name__)
self.imports[module].add(fld.default_factory.__name__)
elif fld.default is not MISSING:
default_str = f"default={get_default_str(fld.default)}"
if hasattr(fld.default, "__class__"):
module = fld.default.__class__.__module__
name = fld.default.__class__.__name__
self._imports[module].add(name)
self.imports[module].add(name)
else:
default_str = ""

Expand All @@ -191,7 +195,7 @@ def _setup_fields(self, descriptor):
if hasattr(fld.default, "__class__"):
module = fld.default.__class__.__module__
name = fld.default.__class__.__name__
self._imports[module].add(name)
self.imports[module].add(name)
else:
field_str = ""

Expand Down Expand Up @@ -230,7 +234,7 @@ def _setup_methods(self, descriptor):
for var in annotations.values():
module = inspect.getmodule(var)
if module is not None:
self._imports[module.__name__].add(var.__name__)
self.imports[module.__name__].add(var.__name__)
else:
warnings.warn(
f"(annotations) unable to determine the module of "
Expand All @@ -246,30 +250,30 @@ def _setup_methods(self, descriptor):
)
except AttributeError:
module = descriptor.__module__
self._imports[module].add(name)
self.imports[module].add(name)
else:
module = inspect.getmodule(var)
if module is not None:
if module.__name__ == name:
self._imports[None].add(name)
self.imports[None].add(name)
else:
self._imports[module.__name__].add(name)
self.imports[module.__name__].add(name)
elif name in global_ctx:
self._imports[descriptor.__module__].add(name)
self.imports[descriptor.__module__].add(name)
elif name not in local_ctx:
warnings.warn(
f"({var_type}) unable to determine the "
f"module of {name!r}"
)

def _get_classified_imports(self):
if self._imports.get(None):
stdlib, others = _classify_modules(self._imports[None])
if self.imports.get(None):
stdlib, others = _classify_modules(self.imports[None])
else:
stdlib = others = None

from_stdlib, from_others = _classify_modules(
key for key in self._imports if key and key != "builtins"
key for key in self.imports if key and key != "builtins"
)

return stdlib, others, from_stdlib, from_others
Expand All @@ -285,7 +289,7 @@ def _get_import_lines(
)
if from_stdlib:
for key in sorted(sorted(from_stdlib), key=len):
values = sorted(self._imports[key])
values = sorted(self.imports[key])
s = f"from {key} import {', '.join(values)}"
if len(s) < line_length:
import_lines.append(s)
Expand All @@ -305,7 +309,7 @@ def _get_import_lines(
)
if from_others:
for key in sorted(sorted(from_others), key=len):
values = sorted(self._imports[key])
values = sorted(self.imports[key])
s = f"from {key} import {', '.join(values)}"
if len(s) < line_length:
import_lines.append(s)
Expand All @@ -318,34 +322,57 @@ def _get_import_lines(

return import_lines

def get_code(self, imports: bool = False, line_length: int = 80):
def patch(self, code):
"""Patch the generated code."""
return code

def get_code(
self,
imports: bool = False,
line_length: int = 80,
beautify: bool = True,
):
"""Generate the Python source code.
By default only the code for the binary record descriptor is generated.
If the `imports` is set to `True`, also the import statements for all
types used by the descriptor are included in the generated code.
"""
lines = []

if self.module_docstring:
lines.extend(f'"""{self.module_docstring}"""'.splitlines())
lines.append("")
if imports:
import_lines = self._get_import_lines(
*self._get_classified_imports(), line_length=line_length
)
lines.extend(import_lines)
lines.append("")
lines.append("")

if self.pre_code:
lines.extend(self.pre_code.splitlines())

lines.extend(self._lines)

if self.post_code:
lines.extend(self.post_code.splitlines())

code = "\n".join(lines)

try:
import black
except ImportError:
pass
else:
mode = black.Mode(
target_versions={black.TargetVersion.PY311},
line_length=line_length,
)
code = black.format_str(code, mode=mode)
code = self.patch(code)

if beautify:
try:
import black
except ImportError:
pass
else:
mode = black.Mode(
target_versions={black.TargetVersion.PY311},
line_length=line_length,
)
code = black.format_str(code, mode=mode)

return code

0 comments on commit 9ca93cc

Please sign in to comment.