From 4b5fa8a864cdbd358b866a6fb363cc240113f3aa Mon Sep 17 00:00:00 2001 From: mstechly Date: Wed, 15 May 2024 07:31:11 -0700 Subject: [PATCH] fix latex issue for names with underscores --- src/bartiq/integrations/latex.py | 64 ++++++++++++++------------ tests/{ => integrations}/test_latex.py | 0 2 files changed, 34 insertions(+), 30 deletions(-) rename tests/{ => integrations}/test_latex.py (100%) diff --git a/src/bartiq/integrations/latex.py b/src/bartiq/integrations/latex.py index 45cfcf1..13d6afb 100644 --- a/src/bartiq/integrations/latex.py +++ b/src/bartiq/integrations/latex.py @@ -19,6 +19,26 @@ from ..symbolics.sympy_interpreter import parse_to_sympy +def represent_routine_in_latex(routine: Routine, show_non_root_resources: bool = True) -> str: + """Returns a snippet of LaTeX used to render the routine using clear LaTeX. + + Args: + routine: The routine to render. + show_non_root_costs: If ``True`` (default), displays all costs, otherwise only includes costs + from the root node. + + Returns: + A LaTeX snippet of the routine. + """ + lines = [format_line(data) for attr_name, format_line in SECTIONS if (data := getattr(routine, attr_name))] + + # We deal with resources separately due to show_non_root_resources option + if resource_section := _format_resources(routine, show_non_root_resources): + lines.append(resource_section) + + return "\\begin{align}\n" + "\\\\\n".join(lines) + "\n\\end{align}" + + def _format_input_params(input_params: list[str]): """Formats estimator input parameters to LaTeX.""" input_params = [_format_param(input_param) for input_param in input_params] @@ -37,19 +57,19 @@ def _format_linked_params(linked_params): def _format_input_port_sizes(ports): - values = [] - for port in ports.values(): - values.append(rf"{_format_param_text(port.name)}.\!{_format_param_math(port.size)}") - return _format_section_one_line("Input ports", values) + _format_port_sizes(ports, "Input") def _format_output_port_sizes(ports): - """Returns the output register sizes formatted in LaTeX.""" + _format_port_sizes(ports, "Output") + + +def _format_port_sizes(ports, label): lines = [] for port in ports.values(): port_name = port.name lines.append(f"&{_format_param_text(port_name)} = {_latex_expression(port.size)}") - return _format_section_multi_line("Output ports", lines) + return _format_section_multi_line(f"{label} ports", lines) def _format_local_variables(local_variables): @@ -62,9 +82,6 @@ def _format_local_variables(local_variables): SECTIONS = [ - # pairs of the form (get_line_data, format_line_data) - # TODO: actually implement the functions listed below, base on the estimator-based ones further in this file - # TODO: ordering of this list matters, make sure it is correct ("input_params", _format_input_params), ("linked_params", _format_linked_params), ("input_ports", _format_input_port_sizes), @@ -73,26 +90,6 @@ def _format_local_variables(local_variables): ] -def represent_routine_in_latex(routine: Routine, show_non_root_resources: bool = True) -> str: - """Returns a snippet of LaTeX used to render the routine using clear LaTeX. - - Args: - routine: The routine to render. - show_non_root_costs: If ``True`` (default), displays all costs, otherwise only includes costs - from the root node. - - Returns: - A LaTeX snippet of the routine. - """ - lines = [format_line(data) for attr_name, format_line in SECTIONS if (data := getattr(routine, attr_name))] - - # We deal with resources separately due to show_non_root_resources option - if resource_section := _format_resources(routine, show_non_root_resources): - lines.append(resource_section) - - return "\\begin{align}\n" + "\\\\\n".join(lines) + "\n\\end{align}" - - def _format_section_one_line(header, entries): """Formats a parameter section into a bolded header followed by a comma-separated list of entries.""" return f"&\\bf\\text{{{header}:}}\\\\\n&" + ", ".join(entries) @@ -117,7 +114,14 @@ def _format_local_param(param): def _format_param_text(param): """Formats a param as text.""" - return rf"\text{{{param}}}" + if param.count("_") == 0: + return rf"\text{{{param}}}" + elif param.count("_") == 1: + splitted_param = param.split("_") + return rf"\text{{{splitted_param[0]}}}_\text{{{splitted_param[1]}}}" + else: + # TODO: add test case with more underscores to test it + return rf"\text{{{param}}}" # .replace("_", "\\_") def _format_param_math(param): diff --git a/tests/test_latex.py b/tests/integrations/test_latex.py similarity index 100% rename from tests/test_latex.py rename to tests/integrations/test_latex.py