From 1c93f0ee35bd2eee233db9f773d02fa4be53619d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konrad=20Ja=C5=82owiecki?= Date: Wed, 3 Jul 2024 17:15:00 +0200 Subject: [PATCH] Make local_variables a dictionary instead of a list of strings --- src/bartiq/_routine.py | 2 +- src/bartiq/compilation/_compile.py | 2 +- src/bartiq/compilation/_symbolic_function.py | 45 ++-- src/bartiq/integrations/latex.py | 9 +- src/bartiq/symbolics/utilities.py | 17 +- src/bartiq/verification.py | 4 +- tests/compilation/data/compile_test_data.yaml | 10 +- .../compilation/data/evaluate_test_data.yaml | 4 +- tests/compilation/test_compile.py | 10 +- tests/compilation/test_core.py | 253 +++++++++--------- tests/compilation/test_symbolic_function.py | 94 +++---- tests/integrations/test_latex.py | 24 +- tests/test_verification.py | 2 +- 13 files changed, 231 insertions(+), 245 deletions(-) diff --git a/src/bartiq/_routine.py b/src/bartiq/_routine.py index 886acc3..8fe87e8 100644 --- a/src/bartiq/_routine.py +++ b/src/bartiq/_routine.py @@ -217,7 +217,7 @@ class Routine(BaseModel): connections: list[Connection] = Field(default_factory=list) resources: dict[str, Resource] = Field(default_factory=dict) input_params: Sequence[Symbol] = Field(default_factory=list) - local_variables: list[str] = Field(default_factory=list) + local_variables: dict[str, str] = Field(default_factory=dict) linked_params: dict[Symbol, list[tuple[str, Symbol]]] = Field(default_factory=dict) meta: Optional[dict[str, Any]] = Field(default_factory=dict) diff --git a/src/bartiq/compilation/_compile.py b/src/bartiq/compilation/_compile.py index e11100e..f49dce5 100644 --- a/src/bartiq/compilation/_compile.py +++ b/src/bartiq/compilation/_compile.py @@ -433,7 +433,7 @@ def _compile_routine_with_functions( routine.symbolic_function = None for subroutine in routine.walk(): - subroutine.local_variables = [] + subroutine.local_variables = {} return routine diff --git a/src/bartiq/compilation/_symbolic_function.py b/src/bartiq/compilation/_symbolic_function.py index 6c5d502..b9a906c 100644 --- a/src/bartiq/compilation/_symbolic_function.py +++ b/src/bartiq/compilation/_symbolic_function.py @@ -86,8 +86,8 @@ def __eq__(self, other: object) -> bool: # def from_str(cls: Type[T], ...) -> T[T_expr] # But TypeVars cannot have arguments. @classmethod - def from_str( - cls, inputs: list[str], outputs: list[str], backend: SymbolicBackend[T_expr] + def assemble( + cls, inputs: list[str], outputs: dict[str, str], backend: SymbolicBackend[T_expr] ) -> SymbolicFunction[T_expr]: """Creates a SymbolicFunction instance from lists of easy-to-write strings. @@ -100,9 +100,12 @@ def from_str( """ return cls(_parse_input_expressions(inputs), parse_output_expressions(outputs, backend)) - def to_str(self) -> tuple[list[str], list[str]]: + def to_str(self) -> tuple[list[str], dict[str, str]]: """Serialises the SymbolicFunction to a string (in the format required by ``from_str``).""" - return (_serialize_variables(self.inputs), _serialize_variables(self.outputs)) + return ( + [var.symbol for var in self.inputs.values()], + {var.symbol: str(var.expression) for var in self.outputs.values()}, + ) def __repr__(self) -> str: inputs = list(self._inputs.values()) @@ -160,7 +163,7 @@ def parse_output_expressions( output_expressions: list[str], backend: SymbolicBackend[T_expr] ) -> list[DependentVariable[T_expr]]: """Parses a list of output expressions to a dictionary mapping output symbols to their expressions.""" - return [parse_output_expression(output_expression, backend) for output_expression in output_expressions] + return [DependentVariable(key, backend.as_expression(value), backend) for key, value in output_expressions.items()] def parse_output_expression(output_expression: str, backend: SymbolicBackend[T_expr]) -> DependentVariable[T_expr]: @@ -569,7 +572,7 @@ def _make_cost_variables( """Compiles a cost variable, taking into account any local parameters.""" # This allows users to reuse costs in subsequent expressions. known_params = {local_param.symbol: local_param for local_param in local_params} - costs = _resources_to_cost_expressions(resources) + costs = {resource.name: resource.value for resource in resources} new_cost_variables = [] for old_output_variable in parse_output_expressions(costs, backend): # Substitute any local parameters @@ -587,13 +590,6 @@ def _make_cost_variables( return new_cost_variables -def _resources_to_cost_expressions(resources: list[Resource]) -> list[str]: - expressions = [] - for resource in resources: - expressions.append(f"{resource.name} = {resource.value}") - return expressions - - def _make_output_register_size_variables( output_ports: dict[str, Port], local_params: list[DependentVariable[T_expr]], @@ -602,14 +598,14 @@ def _make_output_register_size_variables( """Compiles an output register size variables, taking into account any local parameters.""" output_register_sizes = {key: port.size for key, port in output_ports.items() if port.size is not None} - output_expression_strs = [ - f"{_get_output_name(output)} = {expression_str}" for output, expression_str in output_register_sizes.items() - ] + output_expression_map = { + _get_output_name(output): expression_str for output, expression_str in output_register_sizes.items() + } # Next, substitute in any local params known_params = {local_param.symbol: local_param for local_param in local_params} return [ _substitute_local_parameters(output, known_params) - for output in parse_output_expressions(output_expression_strs, backend) + for output in parse_output_expressions(output_expression_map, backend) ] @@ -631,16 +627,13 @@ def _make_input_register_size_variables( """ input_register_sizes = {key: port.size for key, port in input_ports.items() if port.size is not None} # First, remove all the #in_ prefixes to get the corresponding output variables - output_expression_strs = [ - f"{input} = {expression_str}" - for input, expression_str in _get_non_trivial_input_register_sizes(input_register_sizes).items() - ] + output_expression_map = _get_non_trivial_input_register_sizes(input_register_sizes) # Next, substitute in any local params known_params = {local_param.symbol: local_param for local_param in local_params} return [ _substitute_local_parameters(output, known_params) - for output in parse_output_expressions(output_expression_strs, backend) + for output in parse_output_expressions(output_expression_map, backend) ] @@ -666,12 +659,12 @@ def _make_input_register_constants( """Identifies constant input register sizes and formats them as output variables.""" input_register_sizes = {key: port.size for key, port in input_ports.items() if port.size is not None} - output_expression_strs = [ - f"#{inpt} = {expression_str}" + output_expressions = { + f"#{inpt}": expression_str for inpt, expression_str in input_register_sizes.items() if is_constant_int(expression_str) - ] - return parse_output_expressions(output_expression_strs, backend) + } + return parse_output_expressions(output_expressions, backend) def _substitute_local_parameters(output_variable, local_params): diff --git a/src/bartiq/integrations/latex.py b/src/bartiq/integrations/latex.py index 51367a2..3c5ba12 100644 --- a/src/bartiq/integrations/latex.py +++ b/src/bartiq/integrations/latex.py @@ -15,7 +15,6 @@ from sympy import latex, symbols from .._routine import Routine -from ..compilation._utilities import split_equation from ..symbolics.sympy_interpreter import parse_to_sympy @@ -82,10 +81,10 @@ def _format_port_sizes(ports, label): def _format_local_variables(local_variables): """Formats routine's local variables to LaTeX.""" - lines = [] - for variable in local_variables: - assignment, expression = split_equation(variable) - lines.append(f"&{_format_param_math(assignment)} = {_latex_expression(expression)}") + lines = [ + f"&{_format_param_math(symbol)} = {_latex_expression(expression)}" + for symbol, expression in local_variables.items() + ] return _format_section_multi_line("Local variables", lines) diff --git a/src/bartiq/symbolics/utilities.py b/src/bartiq/symbolics/utilities.py index c904a31..af751a2 100644 --- a/src/bartiq/symbolics/utilities.py +++ b/src/bartiq/symbolics/utilities.py @@ -32,21 +32,18 @@ def _split_equation(equation: str) -> tuple[str, str]: def infer_subresources(routine: Routine, backend): """Infer what are the resources of a routine's children.""" - expressions = [resource.value for resource in routine.resources.values()] - for variable in routine.local_variables: - _, rhs = _split_equation(variable) - expressions.append(rhs) + expressions = [*[resource.value for resource in routine.resources.values()], *routine.local_variables.values()] # Any path-prefixed variable (i.e. prefixed by a .-separated path) not # in subresources, but found in the RHS of an expression in either costs, # local_variables, or output ports. - subresources = [] - for expr in expressions: - vars = _extract_input_variables_from_expression(expr, backend) + subresources = [ + var + for expr in expressions + for var in _extract_input_variables_from_expression(expr, backend) # Only consider variables that are subresources (ones that have a "." in the name). - for var in vars: - if "." in var: - subresources.append(var) + if "." in var + ] return sorted(set(subresources)) diff --git a/src/bartiq/verification.py b/src/bartiq/verification.py index d3902bf..1de1a40 100644 --- a/src/bartiq/verification.py +++ b/src/bartiq/verification.py @@ -104,9 +104,9 @@ def _verify_expressions_parsable(routine: Routine, backend: SymbolicBackend) -> ] local_variable_problems = [ _verify_expression( - backend, local_variable, local_variable.split("=")[1], "local_variable", subroutine.absolute_path() + backend, f"{variable} = {expression}", expression, "local_variable", subroutine.absolute_path() ) - for local_variable in subroutine.local_variables + for variable, expression in subroutine.local_variables.items() ] port_problems = [ _verify_expression(backend, port, port.size, "port size", subroutine.absolute_path()) diff --git a/tests/compilation/data/compile_test_data.yaml b/tests/compilation/data/compile_test_data.yaml index 978abe4..71776ca 100644 --- a/tests/compilation/data/compile_test_data.yaml +++ b/tests/compilation/data/compile_test_data.yaml @@ -648,7 +648,7 @@ type: other value: {type: str, value: k} input_params: [x, y] - local_variables: [i = x + y, j = x - y, k = (i + j) / 2 + (i - j) / 2] + local_variables: {"i": "x + y", "j": "x - y", "k": "(i + j) / 2 + (i - j) / 2"} b: name: b type: null @@ -658,7 +658,7 @@ type: other value: {type: str, value: k} input_params: [x, y] - local_variables: [i = x + y, j = x - y, k = (i + j) / 2 + (i - j) / 2] + local_variables: {"i": "x + y", "j": "x - y", "k": "(i + j) / 2 + (i - j) / 2"} resources: z: name: z @@ -677,7 +677,7 @@ type: other value: {type: str, value: k} input_params: [x, y] - local_variables: [i = x + y, j = x - y, k = (i + j) / 2 + (i - j) / 2] + local_variables: {"i": "x + y", "j": "x - y", "k": "(i + j) / 2 + (i - j) / 2"} b: name: b type: null @@ -687,7 +687,7 @@ type: other value: {type: str, value: k} input_params: [x, y] - local_variables: [i = x + y, j = x - y, k = (i + j) / 2 + (i - j) / 2] + local_variables: {"i": "x + y", "j": "x - y", "k": "(i + j) / 2 + (i - j) / 2"} resources: z: name: z @@ -827,7 +827,7 @@ name: z type: other value: {type: str, value: 3 * M} - local_variables: [M = 2 * N] + local_variables: {"M": "2 * N"} - name: root type: null ports: diff --git a/tests/compilation/data/evaluate_test_data.yaml b/tests/compilation/data/evaluate_test_data.yaml index 9dab1da..c7fefd9 100644 --- a/tests/compilation/data/evaluate_test_data.yaml +++ b/tests/compilation/data/evaluate_test_data.yaml @@ -929,13 +929,13 @@ - - name: root type: null input_params: [L] - local_variables: ['N=L/multiplicity(2,L)'] + local_variables: {'N': 'L/multiplicity(2,L)'} resources: T_gates: {name: T_gates, type: additive, value: 8*ceiling(log_2(N))} - [L = 10] - name: root type: null - local_variables: ['N=L/multiplicity(2,L)'] + local_variables: {'N': 'L/multiplicity(2,L)'} resources: T_gates: {name: T_gates, type: additive, value: '32'} # Linked params through multiple generations diff --git a/tests/compilation/test_compile.py b/tests/compilation/test_compile.py index 0e75abf..4876424 100644 --- a/tests/compilation/test_compile.py +++ b/tests/compilation/test_compile.py @@ -62,25 +62,25 @@ def f_3_optional_inputs(a, b=2, c=3): DEFINED_EXPRESSION_FUNCTIONS_TEST_DATA = [ # Testing function which can be interpreted symbolically ( - SymbolicFunction.from_str(["a", "b"], ["x = f(a) + f(b)"], BACKEND), + SymbolicFunction.assemble(["a", "b"], {"x": "f(a) + f(b)"}, BACKEND), {"f": f_1_simple}, {"x": "a + b + 2"}, ), # Testing function which cannot be interpreted symbolically due to having a condition ( - SymbolicFunction.from_str(["a"], ["x = f(a) + a"], BACKEND), + SymbolicFunction.assemble(["a"], {"x": "f(a) + a"}, BACKEND), {"f": f_2_conditional}, {"x": "a + f(a)"}, ), # Testing function with multiple inputs, some with default values ["x = a**2 + 2*a + 3*b"] ( - SymbolicFunction.from_str(["a", "b"], ["x = f(a, b) + f(a, a, a)"], BACKEND), + SymbolicFunction.assemble(["a", "b"], {"x": "f(a, b) + f(a, a, a)"}, BACKEND), {"f": f_3_optional_inputs}, {"x": "a ^ 2 + 2*a + 3*b"}, ), # Testing nested calls of a simple function ( - SymbolicFunction.from_str(["a"], ["x = f(f(f(a)))"], BACKEND), + SymbolicFunction.assemble(["a"], {"x": "f(f(f(a)))"}, BACKEND), {"f": f_1_simple}, {"x": "a + 3"}, ), @@ -106,7 +106,7 @@ def mad_max(a, b): "function, functions_map, expected_error", [ ( - SymbolicFunction.from_str(["a", "b"], ["x = max(a, b)"], BACKEND), + SymbolicFunction.assemble(["a", "b"], {"x": "max(a, b)"}, BACKEND), {"max": mad_max}, "Attempted to redefine built-in function max", ) diff --git a/tests/compilation/test_core.py b/tests/compilation/test_core.py index 590a30a..8d036c6 100644 --- a/tests/compilation/test_core.py +++ b/tests/compilation/test_core.py @@ -20,7 +20,6 @@ SymbolicFunction, _get_renamed_inputs_and_outputs, _merge_functions, - _serialize_variables, compile_functions, rename_variables, ) @@ -36,14 +35,14 @@ # Null case ( [], # Inputs - [], # Outputs + {}, # Outputs {}, # Expected inputs {}, # Expected outputs ), # Input-only case ( ["a", "b", "c"], # Inputs - [], # Outputs + {}, # Outputs { # Expected inputs "a": IndependentVariable("a"), "b": IndependentVariable("b"), @@ -54,7 +53,7 @@ # Output-only case ( [], # Inputs - ["a = 1", "b = 2"], # Outputs + {"a": "1", "b": "2"}, # Outputs {}, # Expected inputs { # Expected outputs "a": "a = 1", @@ -64,7 +63,7 @@ # Full case ( ["x", "y"], # Inputs - ["a = x + y", "b = x - y"], # Outputs + {"a": "x + y", "b": "x - y"}, # Outputs { # Expected inputs "x": IndependentVariable("x"), "y": IndependentVariable("y"), @@ -78,9 +77,9 @@ @pytest.mark.parametrize("inputs, outputs, expected_inputs, expected_outputs", FROM_STR_TEST_CASES) -def test_SymbolicFunction_from_str(inputs, outputs, expected_inputs, expected_outputs, backend): +def test_SymbolicFunction_assemble(inputs, outputs, expected_inputs, expected_outputs, backend): expected_outputs = {k: DependentVariable.from_str(v, backend) for k, v in expected_outputs.items()} - function = SymbolicFunction.from_str(inputs, outputs, backend) + function = SymbolicFunction.assemble(inputs, outputs, backend) assert function.inputs == expected_inputs assert function.outputs == expected_outputs @@ -93,157 +92,155 @@ def test_SymbolicFunction_from_str(inputs, outputs, expected_inputs, expected_ou EQUALITY_TEST_CASES = [ # Null case ( - ([], []), - ([], []), + ([], {}), + ([], {}), ), # Permuted inputs ( - (["a", "b"], []), - (["b", "a"], []), + (["a", "b"], {}), + (["b", "a"], {}), ), # Permuted outputs ( - ([], ["a = 42", "b = 24"]), - ([], ["b = 24", "a = 42"]), + ([], {"a": "42", "b": "24"}), + ([], {"b": "24", "a": "42"}), ), # Permuted inputs and outputs ( - (["a", "b"], ["c = a", "d = b"]), - (["b", "a"], ["d = b", "c = a"]), + (["a", "b"], {"c": "a", "d": "b"}), + (["b", "a"], {"d": "b", "c": "a"}), ), # Addition vs multiplication ( - (["a"], ["b = a + a"]), - (["a"], ["b = 2 * a"]), + (["a"], {"b": "a + a"}), + (["a"], {"b": "2 * a"}), ), ] @pytest.mark.parametrize("function_1, function_2", EQUALITY_TEST_CASES) def test_SymbolicFunction_equality(function_1, function_2, backend): - function_1 = SymbolicFunction.from_str(*function_1, backend) - function_2 = SymbolicFunction.from_str(*function_2, backend) + function_1 = SymbolicFunction.assemble(*function_1, backend) + function_2 = SymbolicFunction.assemble(*function_2, backend) assert function_1 == function_2 ERRORS_TEST_CASES = [ # Output references unknown variables - (([], ["b = a"]), BartiqCompilationError, "Expressions must not contain unknown variables"), + (([], {"b": "a"}), BartiqCompilationError, "Expressions must not contain unknown variables"), # No duplicate inputs - ((["a", "a"], []), BartiqCompilationError, "Variable list contains repeated symbol"), - # No duplicate outputs - (([], ["a = 0", "a = 1"]), BartiqCompilationError, "Variable list contains repeated symbol"), + ((["a", "a"], {}), BartiqCompilationError, "Variable list contains repeated symbol"), # Outputs cannot share names with inputs - ((["a"], ["a = 0"]), BartiqCompilationError, "Outputs must not reuse input symbols"), + ((["a"], {"a": "0"}), BartiqCompilationError, "Outputs must not reuse input symbols"), ] @pytest.mark.parametrize("function, exception, match", ERRORS_TEST_CASES) def test_SymbolicFunction_errors(function, exception, match, backend): with pytest.raises(exception, match=match): - SymbolicFunction.from_str(*function, backend) + SymbolicFunction.assemble(*function, backend) MERGE_FUNCTIONS_TEST_CASES = [ # Null case ( - ([], []), - ([], []), - ([], []), + ([], {}), + ([], {}), + ([], {}), ), # Constant function from two unconnected functions ( - (["a"], []), - ([], ["b = 42"]), - (["a"], ["b = 42"]), + (["a"], {}), + ([], {"b": "42"}), + (["a"], {"b": "42"}), ), # Constant function from connected functions ( - (["a"], ["b = a + 1"]), - (["b"], ["c = 42"]), - (["a"], ["c = 42"]), + (["a"], {"b": "a + 1"}), + (["b"], {"c": "42"}), + (["a"], {"c": "42"}), ), # Null outer function from non-null functions ( - ([], ["a = 42"]), - (["a"], []), - ([], []), + ([], {"a": "42"}), + (["a"], {}), + ([], {}), ), # Trivial case 1 ( - (["a"], ["b = a"]), - ([], []), - (["a"], ["b = a"]), + (["a"], {"b": "a"}), + ([], {}), + (["a"], {"b": "a"}), ), # Trivial case 2 ( - ([], []), - (["a"], ["b = a"]), - (["a"], ["b = a"]), + ([], {}), + (["a"], {"b": "a"}), + (["a"], {"b": "a"}), ), # Simple linear case ( - (["a"], ["b = a"]), - (["b"], ["c = b"]), - (["a"], ["c = a"]), + (["a"], {"b": "a"}), + (["b"], {"c": "b"}), + (["a"], {"c": "a"}), ), # Tensor of two functions ( - (["a"], ["b = a + 1"]), - (["c"], ["d = c + 2"]), - (["a", "c"], ["b = a + 1", "d = c + 2"]), + (["a"], {"b": "a + 1"}), + (["c"], {"d": "c + 2"}), + (["a", "c"], {"b": "a + 1", "d": "c + 2"}), ), # Simple single-variable addition ( - (["x"], ["y = x + 1"]), - (["y"], ["z = y + 1"]), - (["x"], ["z = x + 2"]), + (["x"], {"y": "x + 1"}), + (["y"], {"z": "y + 1"}), + (["x"], {"z": "x + 2"}), ), # Nested functions ( - (["x"], ["y = f(x)"]), - (["y"], ["z = g(y)"]), - (["x"], ["z = g(f(x))"]), + (["x"], {"y": "f(x)"}), + (["y"], {"z": "g(y)"}), + (["x"], {"z": "g(f(x))"}), ), # Both functions introduce new variables ( - (["a", "b"], ["c = a + b"]), - (["c", "d"], ["e = c + d"]), - (["a", "b", "d"], ["e = a + b + d"]), + (["a", "b"], {"c": "a + b"}), + (["c", "d"], {"e": "c + d"}), + (["a", "b", "d"], {"e": "a + b + d"}), ), # Both functions define outputs ( - (["a"], ["b = f(a)", "c = g(a)"]), - (["b"], ["d = h(b)"]), - (["a"], ["c = g(a)", "d = h(f(a))"]), + (["a"], {"b": "f(a)", "c": "g(a)"}), + (["b"], {"d": "h(b)"}), + (["a"], {"c": "g(a)", "d": "h(f(a))"}), ), # Automatic simplification (woooooahhh!) ( - (["x"], ["y = log(x)"]), - (["y"], ["z = exp(y)"]), - (["x"], ["z = x"]), + (["x"], {"y": "log(x)"}), + (["y"], {"z": "exp(y)"}), + (["x"], {"z": "x"}), ), # Allow for merged functions to take same inputs ( - (["x"], []), - (["x"], []), - (["x"], []), + (["x"], {}), + (["x"], {}), + (["x"], {}), ), # Allow for merged functions to take same outputs if they have the same expressions ( - ([], ["x = 0"]), - ([], ["x = 0"]), - ([], ["x = 0"]), + ([], {"x": "0"}), + ([], {"x": "0"}), + ([], {"x": "0"}), ), ] @pytest.mark.parametrize("base_func, target_func, expected_func", MERGE_FUNCTIONS_TEST_CASES) def test_merge_functions(base_func, target_func, expected_func, backend): - base_func = SymbolicFunction.from_str(*base_func, backend) - target_func = SymbolicFunction.from_str(*target_func, backend) - expected_func = SymbolicFunction.from_str(*expected_func, backend) + base_func = SymbolicFunction.assemble(*base_func, backend) + target_func = SymbolicFunction.assemble(*target_func, backend) + expected_func = SymbolicFunction.assemble(*expected_func, backend) assert _merge_functions(base_func, target_func) == expected_func @@ -251,14 +248,14 @@ def test_merge_functions(base_func, target_func, expected_func, backend): MERGE_FUNCTION_ERRORS_TEST_CASES = [ # Target function has output already defined as base input ( - (["y"], ["z = g(y)"]), - (["x"], ["y = f(x)"]), + (["y"], {"z": "g(y)"}), + (["x"], {"y": "f(x)"}), "Target function outputs must not reference base function inputs when merging", ), # Functions have same output symbols, but different expressions ( - ([], ["x = 1"]), - ([], ["x = 0"]), + ([], {"x": "1"}), + ([], {"x": "0"}), "Merging functions may only have same outputs if the outputs share the same expression", ), ] @@ -266,8 +263,8 @@ def test_merge_functions(base_func, target_func, expected_func, backend): @pytest.mark.parametrize("base_func, target_func, match", MERGE_FUNCTION_ERRORS_TEST_CASES) def test_merge_functions_errors(base_func, target_func, match, backend): - base_func = SymbolicFunction.from_str(*base_func, backend) - target_func = SymbolicFunction.from_str(*target_func, backend) + base_func = SymbolicFunction.assemble(*base_func, backend) + target_func = SymbolicFunction.assemble(*target_func, backend) with pytest.raises(BartiqCompilationError, match=match): _merge_functions(base_func, target_func) @@ -277,55 +274,55 @@ def test_merge_functions_errors(base_func, target_func, match, backend): # Null case ( [], - ([], []), + ([], {}), ), # Nuller case ( [ - ([], []), + ([], {}), ], - ([], []), + ([], {}), ), # Nullest case ( [ - ([], []), - ([], []), + ([], {}), + ([], {}), ], - ([], []), + ([], {}), ), # The "Dude, I heard you like functions..." case ( [ - (["a"], ["b = f(a)"]), - (["b"], ["c = g(b)"]), - (["c"], ["d = h(c)"]), + (["a"], {"b": "f(a)"}), + (["b"], {"c": "g(b)"}), + (["c"], {"d": "h(c)"}), ], - (["a"], ["d = h(g(f(a)))"]), + (["a"], {"d": "h(g(f(a)))"}), ), # Expanding outputs case (via generating binary number additions) ( [ - (["a"], ["b0 = a + 0", "b1 = a + 1"]), - (["b0"], ["c00 = b0 + 0", "c10 = b0 + 2"]), - (["b1"], ["c01 = b1 + 0", "c11 = b1 + 2"]), + (["a"], {"b0": "a + 0", "b1": "a + 1"}), + (["b0"], {"c00": "b0 + 0", "c10": "b0 + 2"}), + (["b1"], {"c01": "b1 + 0", "c11": "b1 + 2"}), ], - (["a"], ["c00 = a + 0", "c01 = a + 1", "c10 = a + 2", "c11 = a + 3"]), + (["a"], {"c00": "a + 0", "c01": "a + 1", "c10": "a + 2", "c11": "a + 3"}), ), # Decreasing inputs case (via log-tree adder) ( [ - (["a000", "a001"], ["b00 = a000 + a001"]), - (["a010", "a011"], ["b01 = a010 + a011"]), - (["a100", "a101"], ["b10 = a100 + a101"]), - (["a110", "a111"], ["b11 = a110 + a111"]), - (["b00", "b01"], ["c0 = b00 + b01"]), - (["b10", "b11"], ["c1 = b10 + b11"]), - (["c0", "c1"], ["d = c0 + c1"]), + (["a000", "a001"], {"b00": "a000 + a001"}), + (["a010", "a011"], {"b01": "a010 + a011"}), + (["a100", "a101"], {"b10": "a100 + a101"}), + (["a110", "a111"], {"b11": "a110 + a111"}), + (["b00", "b01"], {"c0": "b00 + b01"}), + (["b10", "b11"], {"c1": "b10 + b11"}), + (["c0", "c1"], {"d": "c0 + c1"}), ], ( ["a000", "a001", "a010", "a011", "a100", "a101", "a110", "a111"], - ["d = a000 + a001 + a010 + a011 + a100 + a101 + a110 + a111"], + {"d": "a000 + a001 + a010 + a011 + a100 + a101 + a110 + a111"}, ), ), ] @@ -333,8 +330,8 @@ def test_merge_functions_errors(base_func, target_func, match, backend): @pytest.mark.parametrize("functions, expected_function", COMPILE_FUNCTIONS_TEST_CASES) def test_compile_functions(functions, expected_function, backend): - functions = list(map(lambda func: SymbolicFunction.from_str(*func, backend), functions)) - expected_function = SymbolicFunction.from_str(*expected_function, backend) + functions = list(map(lambda func: SymbolicFunction.assemble(*func, backend), functions)) + expected_function = SymbolicFunction.assemble(*expected_function, backend) assert compile_functions(functions) == expected_function @@ -342,35 +339,35 @@ def test_compile_functions(functions, expected_function, backend): RENAME_VARIABLES_TEST_CASES = [ # Null case ( - ([], []), + ([], {}), {}, - ([], []), + ([], {}), ), # Renaming inputs ( - (["a", "b", "c"], []), + (["a", "b", "c"], {}), {"a": "x", "c": "y"}, - (["x", "b", "y"], []), + (["x", "b", "y"], {}), ), # Renaming outputs ( - ([], ["a = 42", "b = 3.141", "c = 101"]), + ([], {"a": "42", "b": "3.141", "c": "101"}), {"a": "x", "c": "y"}, - ([], ["x = 42", "b = 3.141", "y = 101"]), + ([], {"x": "42", "b": "3.141", "y": "101"}), ), # Renaming inputs and outputs ( - (["a", "b", "c"], ["d = a + b", "e = b + c"]), + (["a", "b", "c"], {"d": "a + b", "e": "b + c"}), {"a": "x", "c": "y", "e": "z"}, - (["x", "b", "y"], ["d = x + b", "z = b + y"]), + (["x", "b", "y"], {"d": "x + b", "z": "b + y"}), ), ] @pytest.mark.parametrize("function, variable_map, expected_function", RENAME_VARIABLES_TEST_CASES) def test_rename_variables(function, variable_map, expected_function, backend): - function = SymbolicFunction.from_str(*function, backend) - expected_function = SymbolicFunction.from_str(*expected_function, backend) + function = SymbolicFunction.assemble(*function, backend) + expected_function = SymbolicFunction.assemble(*expected_function, backend) assert rename_variables(function, variable_map) == expected_function @@ -378,43 +375,43 @@ def test_rename_variables(function, variable_map, expected_function, backend): RENAME_INPUTS_AND_OUTPUTS_TEST_CASES = [ # Ensure non-duplication of inputs ( - (["a", "b"], []), + (["a", "b"], {}), {"a": "b"}, - (["b"], []), + (["b"], {}), ), # Ensure non-duplication of outputs ( - ([], ["x = 1", "y = 1"]), + ([], {"x": "1", "y": "1"}), {"x": "y"}, - ([], ["y = 1"]), + ([], {"y": "1"}), ), # Ensure simultaneous non-duplication of outputs ( - (["a", "b"], ["x = a", "y = b"]), + (["a", "b"], {"x": "a", "y": "b"}), {"a": "b", "x": "y"}, - (["b"], ["y = b"]), + (["b"], {"y": "b"}), ), # Renaming inputs and outputs ( - (["a", "b", "c"], ["d = a + b", "e = b + c"]), + (["a", "b", "c"], {"d": "a + b", "e": "b + c"}), {"a": "x", "c": "x", "e": "z"}, - (["x", "b"], ["d = b + x", "z = b + x"]), + (["x", "b"], {"d": "b + x", "z": "b + x"}), ), # Renaming with cycle ( - (["a", "b", "c"], ["d = a + b", "e = b + c"]), + (["a", "b", "c"], {"d": "a + b", "e": "b + c"}), {"a": "x", "c": "x", "x": "z", "z": "a"}, - (["a", "b"], ["d = a + b", "e = a + b"]), + (["a", "b"], {"d": "a + b", "e": "a + b"}), ), ] @pytest.mark.parametrize("function, variable_map, expected_results", RENAME_INPUTS_AND_OUTPUTS_TEST_CASES) def test_rename_inputs_and_outputs(function, variable_map, expected_results, backend): - function = SymbolicFunction.from_str(*function, backend) + function = SymbolicFunction.assemble(*function, backend) new_inputs, new_outputs = _get_renamed_inputs_and_outputs(function, variable_map) - serialized_inputs = _serialize_variables(new_inputs) - serialized_outputs = _serialize_variables(new_outputs) + serialized_inputs = [var.symbol for var in new_inputs.values()] + serialized_outputs = {var.symbol: str(var.expression) for var in new_outputs.values()} expected_inputs, expected_outputs = expected_results assert serialized_inputs == expected_inputs @@ -423,13 +420,13 @@ def test_rename_inputs_and_outputs(function, variable_map, expected_results, bac RENAME_INPUTS_AND_OUTPUTS_ERRORS_TEST_CASES: list[tuple[tuple, dict, str]] = [ # Output renaming would cause a conflict - (([], ["x = 1", "y = 2"]), {"x": "y"}, "Cannot rename output variable"), + (([], {"x": "1", "y": "2"}), {"x": "y"}, "Cannot rename output variable"), ] @pytest.mark.parametrize("function, variable_map, expected_error", RENAME_INPUTS_AND_OUTPUTS_ERRORS_TEST_CASES) def test_rename_inputs_and_outputs_errors(function, variable_map, expected_error, backend): - function = SymbolicFunction.from_str(*function, backend) + function = SymbolicFunction.assemble(*function, backend) with pytest.raises(BartiqCompilationError, match=re.escape(expected_error)): _get_renamed_inputs_and_outputs(function, variable_map) diff --git a/tests/compilation/test_symbolic_function.py b/tests/compilation/test_symbolic_function.py index 7eb6697..36fb845 100644 --- a/tests/compilation/test_symbolic_function.py +++ b/tests/compilation/test_symbolic_function.py @@ -51,29 +51,29 @@ def _dummy_resources(cost_strs): TO_SYMBOLIC_FUNCTION_TEST_CASES = [ # Null case - (_make_routine(), ([], [])), + (_make_routine(), ([], {})), # Simple case with no register sizes ( _make_routine(input_params=["a", "b"], resources=_dummy_resources(["x = a + b", "y = a - b"])), - (["a", "b"], ["x = a + b", "y = a - b"]), + (["a", "b"], {"x": "a + b", "y": "a - b"}), ), # No register sizes, but including local parameters ( _make_routine( input_params=["a", "b"], - local_variables=["m = a + b", "n = a - b"], + local_variables={"m": "a + b", "n": "a - b"}, resources=_dummy_resources(["x = m + n", "y = m - n"]), ), - (["a", "b"], ["x = 2 * a", "y = 2 * b"]), + (["a", "b"], {"x": "2 * a", "y": "2 * b"}), ), # No register sizes, but including self-referential local parameters ( _make_routine( input_params=["a", "b"], - local_variables=["m = a ** 2", "n = b ** 2", "w = m + n"], + local_variables={"m": "a ** 2", "n": "b ** 2", "w": "m + n"}, resources=_dummy_resources(["x = w - m", "y = w - n"]), ), - (["a", "b"], ["x = b ** 2", "y = a ** 2"]), + (["a", "b"], {"x": "b ** 2", "y": "a ** 2"}), ), # Input and output register sizes ( @@ -83,18 +83,18 @@ def _dummy_resources(cost_strs): **_ports_from_reg_sizes({"psi": "3 * N"}, "out"), } ), - (["#in_psi.N"], ["#out_psi = 3 * #in_psi.N"]), + (["#in_psi.N"], {"#out_psi": "3 * #in_psi.N"}), ), # Input and output register sizes with local parameters ( _make_routine( - local_variables=["M = 3 * N"], + local_variables={"M": "3 * N"}, ports={ **_ports_from_reg_sizes({"psi": "N"}, "in"), **_ports_from_reg_sizes({"psi": "M"}, "out"), }, ), - (["#in_psi.N"], ["#out_psi = 3 * #in_psi.N"]), + (["#in_psi.N"], {"#out_psi": "3 * #in_psi.N"}), ), # Both inputs have the same size ( @@ -104,7 +104,7 @@ def _dummy_resources(cost_strs): **_ports_from_reg_sizes({"0": "2*N"}, "out"), } ), - (["#in_0.N", "#in_1.N"], ["#out_0 = 2*#in_0.N"]), + (["#in_0.N", "#in_1.N"], {"#out_0": "2*#in_0.N"}), ), # Multiple inputs have the same size ( @@ -116,7 +116,7 @@ def _dummy_resources(cost_strs): ), ( ["#in_0.A", "#in_1.A", "#in_2.B", "#in_3.C", "#in_4.B", "#in_5.A", "#in_6.C"], - ["#out_0 = #in_0.A + #in_2.B + 2*#in_3.C"], + {"#out_0": "#in_0.A + #in_2.B + 2*#in_3.C"}, ), ), # Multiple inputs have the same size and we use input params @@ -128,26 +128,26 @@ def _dummy_resources(cost_strs): **_ports_from_reg_sizes({"0": "2*N"}, "out"), }, ), - (["#in_0.N", "#in_1.N", "a", "b"], ["#out_0 = 2*#in_0.N"]), + (["#in_0.N", "#in_1.N", "a", "b"], {"#out_0": "2*#in_0.N"}), ), # Only inputs are subresources ( _make_routine(resources=_dummy_resources(["x = a.N + a.b.N"])), - (["a.N", "a.b.N"], ["x = a.N + a.b.N"]), + (["a.N", "a.b.N"], {"x": "a.N + a.b.N"}), ), # Input params and subresources ( _make_routine(input_params=["N"], resources=_dummy_resources(["x = N + a.N + a.b.N"])), - (["N", "a.N", "a.b.N"], ["x = N + a.N + a.b.N"]), + (["N", "a.N", "a.b.N"], {"x": "N + a.N + a.b.N"}), ), # Input params, subresources, and local parameters ( _make_routine( input_params=["N"], - local_variables=["M = N + a.N + a.b.N"], + local_variables={"M": "N + a.N + a.b.N"}, resources=_dummy_resources(["x = M"]), ), - (["N", "a.N", "a.b.N"], ["x = N + a.N + a.b.N"]), + (["N", "a.N", "a.b.N"], {"x": "N + a.N + a.b.N"}), ), # The whole shebang ( @@ -171,10 +171,10 @@ def _dummy_resources(cost_strs): "out", ), }, - local_variables=[ - "M = N + a.N + a.b.N", - "C = a + b", - ], + local_variables={ + "M": "N + a.N + a.b.N", + "C": "a + b", + }, resources=_dummy_resources(["x = M + C"]), children={ "a": { @@ -186,11 +186,11 @@ def _dummy_resources(cost_strs): ), ( ["N", "a.N", "a.b.N", "#in_psi.a", "#in_phi.b"], - [ - "x = N + a.N + a.b.N + #in_psi.a + #in_phi.b", - "#out_psi = N + a.N + a.b.N", - "#out_phi = #in_psi.a + #in_phi.b", - ], + { + "x": "N + a.N + a.b.N + #in_psi.a + #in_phi.b", + "#out_psi": "N + a.N + a.b.N", + "#out_phi": "#in_psi.a + #in_phi.b", + }, ), ), # Allow reuse of cost in subsequent costs expressions @@ -206,7 +206,7 @@ def _dummy_resources(cost_strs): "out", ), }, - local_variables=["b_anc = 1"], + local_variables={"b_anc": "1"}, resources=_dummy_resources( [ "Q_anc = b_0", @@ -217,13 +217,13 @@ def _dummy_resources(cost_strs): ), ( ["#in_comp_0.b_0"], - [ - "Q_anc = #in_comp_0.b_0", - "B_anc = 1", - "Q = 2*#in_comp_0.b_0 + 1", - "#out_comp_0 = #in_comp_0.b_0", - "#out_anc = 1", - ], + { + "Q_anc": "#in_comp_0.b_0", + "B_anc": "1", + "Q": "2*#in_comp_0.b_0 + 1", + "#out_comp_0": "#in_comp_0.b_0", + "#out_anc": "1", + }, ), ), # Special case for when an input port has a constant size @@ -237,14 +237,14 @@ def _dummy_resources(cost_strs): "in", ), ), - (["#in_foo.bar"], ["#in_0 = 1"]), + (["#in_foo.bar"], {"#in_0": "1"}), ), ] @pytest.mark.parametrize("routine, expected_function", TO_SYMBOLIC_FUNCTION_TEST_CASES) def test_to_symbolic_function(routine, expected_function, backend): - expected_function = SymbolicFunction.from_str(*expected_function, backend) + expected_function = SymbolicFunction.assemble(*expected_function, backend) assert to_symbolic_function(routine, backend) == expected_function @@ -265,7 +265,7 @@ def test_to_symbolic_function(routine, expected_function, backend): # Redundant variable in both costs and local_params ( _make_routine( - local_variables=["x = 1"], + local_variables={"x": "1"}, resources=_dummy_resources(["z = x + 1", "x = 1"]), ), "Variable is redundantly defined in local_params and costs.", @@ -288,13 +288,13 @@ def test_to_symbolic_function_errors(routine, expected_error, backend): # Null case ( _make_routine(), - ([], []), + ([], {}), _make_routine(), ), # Input-only case ( _make_routine(), - (["x", "y"], []), + (["x", "y"], {}), _make_routine( input_params=["x", "y"], ), @@ -305,7 +305,7 @@ def test_to_symbolic_function_errors(routine, expected_error, backend): input_params=["x", "y"], ports=_ports_from_reg_sizes({"0": None}, "in"), ), - (["x", "y", "#in_0.z"], []), + (["x", "y", "#in_0.z"], {}), _make_routine( input_params=["x", "y"], ports=_ports_from_reg_sizes({"0": "z"}, "in"), @@ -314,7 +314,7 @@ def test_to_symbolic_function_errors(routine, expected_error, backend): # Output-only case ( _make_routine(), - ([], ["a = 42", "b = 24"]), + ([], {"a": "42", "b": "24"}), _make_routine( resources=_dummy_resources(["a = 42", "b = 24"]), ), @@ -324,7 +324,7 @@ def test_to_symbolic_function_errors(routine, expected_error, backend): _make_routine( ports=_ports_from_reg_sizes({"0": None}, "out"), ), - ([], ["a = 42", "b = 24", "#out_0 = 101"]), + ([], {"a": "42", "b": "24", "#out_0": "101"}), _make_routine( ports=_ports_from_reg_sizes({"0": "101"}, "out"), resources=_dummy_resources(["a = 42", "b = 24"]), @@ -333,7 +333,7 @@ def test_to_symbolic_function_errors(routine, expected_error, backend): # Input and output case ( _make_routine(), - (["x", "y"], ["a = x + y", "b = x - y"]), + (["x", "y"], {"a": "x + y", "b": "x - y"}), _make_routine( input_params=["x", "y"], resources=_dummy_resources(["a = x + y", "b = x - y"]), @@ -348,7 +348,7 @@ def test_to_symbolic_function_errors(routine, expected_error, backend): **_ports_from_reg_sizes({"0": None}, "out"), }, ), - (["x", "y", "#in_0.z"], ["a = x + y", "b = x - y - #in_0.z", "#out_0 = x * y * #in_0.z"]), + (["x", "y", "#in_0.z"], {"a": "x + y", "b": "x - y - #in_0.z", "#out_0": "x * y * #in_0.z"}), _make_routine( input_params=["x", "y"], ports={ @@ -369,7 +369,7 @@ def test_to_symbolic_function_errors(routine, expected_error, backend): "in", ), ), - (["#in_foo.bar"], ["#in_0 = 1"]), + (["#in_foo.bar"], {"#in_0": "1"}), _make_routine( ports=_ports_from_reg_sizes( { @@ -385,7 +385,7 @@ def test_to_symbolic_function_errors(routine, expected_error, backend): @pytest.mark.parametrize("routine, function, expected_routine", UPDATE_ROUTINE_WITH_SYMBOLIC_FUNCTION_TEST_CASES) def test_update_routine_with_symbolic_function(routine, function, expected_routine, backend): - function = SymbolicFunction.from_str(*function, backend) + function = SymbolicFunction.assemble(*function, backend) update_routine_with_symbolic_function(routine, function) assert routine == expected_routine @@ -402,12 +402,12 @@ def test_update_routine_with_symbolic_function(routine, function, expected_routi _make_routine( ports=_ports_from_reg_sizes({"0": None}, "in"), ), - (["x"], ["#in_0 = x"]), + (["x"], {"#in_0": "x"}), "Only constant-sized input register sizes supported in function outputs", ), ], ) def test_update_routine_with_symbolic_function_fails(routine, function, expected_error, backend): - function = SymbolicFunction.from_str(*function, backend) + function = SymbolicFunction.assemble(*function, backend) with pytest.raises(BartiqCompilationError, match=expected_error): update_routine_with_symbolic_function(routine, function) diff --git a/tests/integrations/test_latex.py b/tests/integrations/test_latex.py index 353cbea..b9474d4 100644 --- a/tests/integrations/test_latex.py +++ b/tests/integrations/test_latex.py @@ -93,10 +93,10 @@ Routine( name="root", input_params=["a", "b"], - local_variables=[ - "x_foo = y + a", - "y_bar = b * c", - ], + local_variables={ + "x_foo": "y + a", + "y_bar": "b * c", + }, ), {}, r""" @@ -163,10 +163,10 @@ "d": {"name": "d", "input_params": ["k_2"]}, "e": {"name": "e", "input_params": ["l_3"]}, }, - local_variables=[ - "x_foo = a.i_0 + a", - "y_bar = b * c.j_1", - ], + local_variables={ + "x_foo": "a.i_0 + a", + "y_bar": "b * c.j_1", + }, resources={ "t": {"name": "t", "value": 0, "type": "additive"}, }, @@ -196,10 +196,10 @@ ( Routine( name="root", - local_variables=[ - "a=1+2", - "b = 3+4", - ], + local_variables={ + "a": "1+2", + "b": "3+4", + }, resources={ "c": {"name": "c", "value": "a + b", "type": "additive"}, "d": {"name": "d", "value": "a-b", "type": "additive"}, diff --git a/tests/test_verification.py b/tests/test_verification.py index 59bbb5a..a099cf7 100644 --- a/tests/test_verification.py +++ b/tests/test_verification.py @@ -60,7 +60,7 @@ def test_verify_uncompiled_routine(routine): name="root", input_params=["N"], resources={"X": {"name": "X", "value": "a +", "type": "other"}}, - local_variables=["X=a*"], + local_variables={"X": "a*"}, ports={"in_0": {"name": "in_0", "direction": "input", "size": "#"}}, type=None, ),