diff --git a/src/styx/backend/python/documentation.py b/src/styx/backend/python/documentation.py index 70ba38d..7e83e4b 100644 --- a/src/styx/backend/python/documentation.py +++ b/src/styx/backend/python/documentation.py @@ -38,7 +38,7 @@ def docs_to_docstring(docs: Documentation) -> str | None: if len(docs.literature) == 1: re += f"Literature: {docs.literature[0]}" else: - entries = '\n'.join(docs.literature) + entries = "\n".join(docs.literature) re += f"Literature:\n{entries}" if docs.urls: @@ -46,7 +46,7 @@ def docs_to_docstring(docs: Documentation) -> str | None: if len(docs.urls) == 1: re += f"URL: {docs.urls[0]}" else: - entries = '\n'.join(docs.urls) + entries = "\n".join(docs.urls) re += f"URLs:\n{entries}" if re: diff --git a/src/styx/backend/python/interface.py b/src/styx/backend/python/interface.py index b0cf777..3c18234 100644 --- a/src/styx/backend/python/interface.py +++ b/src/styx/backend/python/interface.py @@ -30,13 +30,7 @@ def _compile_struct( metadata_symbol: str, root_function: bool, ) -> None: - has_outputs = struct_has_outputs(param) - if has_outputs: - _compile_outputs_class( - param=param, - interface_module=interface_module, - lookup=lookup, - ) + has_outputs = root_function or struct_has_outputs(param) outputs_type = lookup.py_output_type[param.param.id_] @@ -53,14 +47,14 @@ def _compile_struct( name="run", docstring_body="Build command line arguments. This method is called by the main command.", return_type="list[str]", - return_descr=None, + return_descr="Command line arguments", args=[ PyArg(name="self", type=None, default=None, docstring="The sub-command object."), PyArg(name="execution", type="Execution", default=None, docstring="The execution object."), ], ) struct_class = PyDataClass( - name=lookup.py_type[param.param.id_], + name=lookup.py_struct_type[param.param.id_], docstring=docs_to_docstring(param.param.docs), methods=[func_cargs_building], ) @@ -75,7 +69,6 @@ def _compile_struct( PyArg(name="execution", type="Execution", default=None, docstring="The execution object."), ], ) - struct_class.methods.append(func_outputs) pyargs = struct_class.fields # Collect param python symbols @@ -110,10 +103,18 @@ def _compile_struct( struct_compile_constraint_checks(func=func_cargs_building, struct=param, lookup=lookup) - func_cargs_building.body.extend([ - "runner = runner or get_global_runner()", - f"execution = runner.start_execution({metadata_symbol})", - ]) + if has_outputs: + _compile_outputs_class( + param=param, + interface_module=interface_module, + lookup=lookup, + ) + + if root_function: + func_cargs_building.body.extend([ + "runner = runner or get_global_runner()", + f"execution = runner.start_execution({metadata_symbol})", + ]) _compile_cargs_building(param, lookup, func_cargs_building, access_via_self=not root_function) @@ -141,10 +142,12 @@ def _compile_struct( func_outputs.body.extend([ "return ret", ]) + struct_class.methods.append(func_outputs) func_cargs_building.body.extend([ "return cargs", ]) interface_module.funcs_and_classes.append(struct_class) + interface_module.exports.append(struct_class.name) def _compile_cargs_building( @@ -209,6 +212,7 @@ def _compile_outputs_class( outputs_class = PyDataClass( name=lookup.py_output_type[param.param.id_], docstring=f"Output object returned when calling `{lookup.py_type[param.param.id_]}(...)`.", + is_named_tuple=True, ) outputs_class.fields.append( PyArg( @@ -243,7 +247,7 @@ def _compile_outputs_class( ) ) - interface_module.header.extend(blank_before(outputs_class.generate(True), 2)) + interface_module.funcs_and_classes.append(outputs_class) interface_module.exports.append(outputs_class.name) diff --git a/src/styx/backend/python/lookup.py b/src/styx/backend/python/lookup.py index 0985e4c..e8399cb 100644 --- a/src/styx/backend/python/lookup.py +++ b/src/styx/backend/python/lookup.py @@ -32,6 +32,9 @@ def _collect_py_symbol(param: ir.IStruct | ir.IParam, lookup_py_symbol: dict[ir. self.param: dict[ir.IdType, ir.IParam] = {interface.command.param.id_: interface.command} """Find param object by its ID. IParam.id_ -> IParam""" + self.py_struct_type: dict[ir.IdType, str] = {interface.command.param.id_: function_symbol} + """Find Python struct type by param id. IParam.id_ -> Python type + (this is different from py_type because of optionals and lists)""" self.py_type: dict[ir.IdType, str] = {interface.command.param.id_: function_symbol} """Find Python type by param id. IParam.id_ -> Python type""" self.py_symbol: dict[ir.IdType, str] = {} @@ -58,10 +61,11 @@ def _collect_py_symbol(param: ir.IStruct | ir.IParam, lookup_py_symbol: dict[ir. self.param[elem.param.id_] = elem if isinstance(elem, ir.IStruct): - if elem.param.id_ not in self.py_type: # Struct unions may resolve these first - self.py_type[elem.param.id_] = package_scope.add_or_dodge( + if elem.param.id_ not in self.py_struct_type: # Struct unions may resolve these first + self.py_struct_type[elem.param.id_] = package_scope.add_or_dodge( python_pascalize(f"{interface.command.struct.name}_{elem.struct.name}") ) + self.py_type[elem.param.id_] = param_py_type(elem, self.py_struct_type) self.py_output_type[elem.param.id_] = package_scope.add_or_dodge( python_pascalize(f"{interface.command.struct.name}_{elem.struct.name}_Outputs") ) @@ -75,9 +79,10 @@ def _collect_py_symbol(param: ir.IStruct | ir.IParam, lookup_py_symbol: dict[ir. ) elif isinstance(elem, ir.IStructUnion): for alternative in elem.alts: - self.py_type[alternative.param.id_] = package_scope.add_or_dodge( + self.py_struct_type[alternative.param.id_] = package_scope.add_or_dodge( python_pascalize(f"{interface.command.struct.name}_{alternative.struct.name}") ) - self.py_type[elem.param.id_] = param_py_type(elem, self.py_type) + self.py_type[alternative.param.id_] = param_py_type(alternative, self.py_struct_type) + self.py_type[elem.param.id_] = param_py_type(elem, self.py_struct_type) else: - self.py_type[elem.param.id_] = param_py_type(elem, self.py_type) + self.py_type[elem.param.id_] = param_py_type(elem, self.py_struct_type) diff --git a/src/styx/backend/python/pycodegen/core.py b/src/styx/backend/python/pycodegen/core.py index df67bde..d533591 100644 --- a/src/styx/backend/python/pycodegen/core.py +++ b/src/styx/backend/python/pycodegen/core.py @@ -153,8 +153,9 @@ class PyDataClass(PyGen): docstring: str fields: list[PyArg] = field(default_factory=list) methods: list[PyFunc] = field(default_factory=list) + is_named_tuple: bool = False - def generate(self, named_tuple: bool = False) -> LineBuffer: + def generate(self) -> LineBuffer: # Sort fields so default arguments come last self.fields.sort(key=lambda a: a.default is not None) @@ -166,7 +167,7 @@ def _arg_docstring(arg: PyArg) -> LineBuffer: args = concat([[f.declaration(), *_arg_docstring(f)] for f in self.fields]) methods = concat([method.generate() for method in self.methods], [""]) - if not named_tuple: + if not self.is_named_tuple: buf = [ "@dataclasses.dataclass", f"class {self.name}:", diff --git a/src/styx/backend/python/utils.py b/src/styx/backend/python/utils.py index 39771a7..3b9e1c1 100644 --- a/src/styx/backend/python/utils.py +++ b/src/styx/backend/python/utils.py @@ -37,7 +37,7 @@ def _base() -> str: if isinstance(param, ir.IInt): if param.choices: return f"typing.Literal[{', '.join(map(as_py_literal, param.choices))}]" - return "str" + return "int" if isinstance(param, ir.IFloat): return "float" if isinstance(param, ir.IFile): @@ -55,6 +55,7 @@ def _base() -> str: type_ = f"list[{type_}]" if isinstance(param, ir.IOptional): type_ = f"{type_} | None" + return type_ diff --git a/src/styx/frontend/boutiques/core.py b/src/styx/frontend/boutiques/core.py index 6596dd5..443097e 100644 --- a/src/styx/frontend/boutiques/core.py +++ b/src/styx/frontend/boutiques/core.py @@ -190,6 +190,7 @@ def _arg_elem_from_bt_elem( default_value=d.get("default-value", ir.IOptional.SetToNone) if input_type.is_optional else d.get("default-value"), + choices=choices, ) case InputTypePrimitive.Integer: @@ -209,6 +210,7 @@ def _arg_elem_from_bt_elem( else d.get("default-value"), min_value=constraints.value_min, max_value=constraints.value_max, + choices=choices, ) case InputTypePrimitive.Float: