Skip to content

Commit

Permalink
Make local_variables a dictionary instead of a list of strings
Browse files Browse the repository at this point in the history
  • Loading branch information
dexter2206 committed Jul 3, 2024
1 parent c0ee3ea commit 1c93f0e
Show file tree
Hide file tree
Showing 13 changed files with 231 additions and 245 deletions.
2 changes: 1 addition & 1 deletion src/bartiq/_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ class Routine(BaseModel):
connections: list[Connection] = Field(default_factory=list)
resources: dict[str, Resource] = Field(default_factory=dict)
input_params: Sequence[Symbol] = Field(default_factory=list)
local_variables: list[str] = Field(default_factory=list)
local_variables: dict[str, str] = Field(default_factory=dict)
linked_params: dict[Symbol, list[tuple[str, Symbol]]] = Field(default_factory=dict)
meta: Optional[dict[str, Any]] = Field(default_factory=dict)

Expand Down
2 changes: 1 addition & 1 deletion src/bartiq/compilation/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def _compile_routine_with_functions(

routine.symbolic_function = None
for subroutine in routine.walk():
subroutine.local_variables = []
subroutine.local_variables = {}

return routine

Expand Down
45 changes: 19 additions & 26 deletions src/bartiq/compilation/_symbolic_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def __eq__(self, other: object) -> bool:
# def from_str(cls: Type[T], ...) -> T[T_expr]
# But TypeVars cannot have arguments.
@classmethod
def from_str(
cls, inputs: list[str], outputs: list[str], backend: SymbolicBackend[T_expr]
def assemble(
cls, inputs: list[str], outputs: dict[str, str], backend: SymbolicBackend[T_expr]
) -> SymbolicFunction[T_expr]:
"""Creates a SymbolicFunction instance from lists of easy-to-write strings.
Expand All @@ -100,9 +100,12 @@ def from_str(
"""
return cls(_parse_input_expressions(inputs), parse_output_expressions(outputs, backend))

def to_str(self) -> tuple[list[str], list[str]]:
def to_str(self) -> tuple[list[str], dict[str, str]]:
"""Serialises the SymbolicFunction to a string (in the format required by ``from_str``)."""
return (_serialize_variables(self.inputs), _serialize_variables(self.outputs))
return (
[var.symbol for var in self.inputs.values()],
{var.symbol: str(var.expression) for var in self.outputs.values()},
)

def __repr__(self) -> str:
inputs = list(self._inputs.values())
Expand Down Expand Up @@ -160,7 +163,7 @@ def parse_output_expressions(
output_expressions: list[str], backend: SymbolicBackend[T_expr]
) -> list[DependentVariable[T_expr]]:
"""Parses a list of output expressions to a dictionary mapping output symbols to their expressions."""
return [parse_output_expression(output_expression, backend) for output_expression in output_expressions]
return [DependentVariable(key, backend.as_expression(value), backend) for key, value in output_expressions.items()]


def parse_output_expression(output_expression: str, backend: SymbolicBackend[T_expr]) -> DependentVariable[T_expr]:
Expand Down Expand Up @@ -569,7 +572,7 @@ def _make_cost_variables(
"""Compiles a cost variable, taking into account any local parameters."""
# This allows users to reuse costs in subsequent expressions.
known_params = {local_param.symbol: local_param for local_param in local_params}
costs = _resources_to_cost_expressions(resources)
costs = {resource.name: resource.value for resource in resources}
new_cost_variables = []
for old_output_variable in parse_output_expressions(costs, backend):
# Substitute any local parameters
Expand All @@ -587,13 +590,6 @@ def _make_cost_variables(
return new_cost_variables


def _resources_to_cost_expressions(resources: list[Resource]) -> list[str]:
expressions = []
for resource in resources:
expressions.append(f"{resource.name} = {resource.value}")
return expressions


def _make_output_register_size_variables(
output_ports: dict[str, Port],
local_params: list[DependentVariable[T_expr]],
Expand All @@ -602,14 +598,14 @@ def _make_output_register_size_variables(
"""Compiles an output register size variables, taking into account any local parameters."""
output_register_sizes = {key: port.size for key, port in output_ports.items() if port.size is not None}

output_expression_strs = [
f"{_get_output_name(output)} = {expression_str}" for output, expression_str in output_register_sizes.items()
]
output_expression_map = {
_get_output_name(output): expression_str for output, expression_str in output_register_sizes.items()
}
# Next, substitute in any local params
known_params = {local_param.symbol: local_param for local_param in local_params}
return [
_substitute_local_parameters(output, known_params)
for output in parse_output_expressions(output_expression_strs, backend)
for output in parse_output_expressions(output_expression_map, backend)
]


Expand All @@ -631,16 +627,13 @@ def _make_input_register_size_variables(
"""
input_register_sizes = {key: port.size for key, port in input_ports.items() if port.size is not None}
# First, remove all the #in_ prefixes to get the corresponding output variables
output_expression_strs = [
f"{input} = {expression_str}"
for input, expression_str in _get_non_trivial_input_register_sizes(input_register_sizes).items()
]
output_expression_map = _get_non_trivial_input_register_sizes(input_register_sizes)

# Next, substitute in any local params
known_params = {local_param.symbol: local_param for local_param in local_params}
return [
_substitute_local_parameters(output, known_params)
for output in parse_output_expressions(output_expression_strs, backend)
for output in parse_output_expressions(output_expression_map, backend)
]


Expand All @@ -666,12 +659,12 @@ def _make_input_register_constants(
"""Identifies constant input register sizes and formats them as output variables."""
input_register_sizes = {key: port.size for key, port in input_ports.items() if port.size is not None}

output_expression_strs = [
f"#{inpt} = {expression_str}"
output_expressions = {
f"#{inpt}": expression_str
for inpt, expression_str in input_register_sizes.items()
if is_constant_int(expression_str)
]
return parse_output_expressions(output_expression_strs, backend)
}
return parse_output_expressions(output_expressions, backend)


def _substitute_local_parameters(output_variable, local_params):
Expand Down
9 changes: 4 additions & 5 deletions src/bartiq/integrations/latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from sympy import latex, symbols

from .._routine import Routine
from ..compilation._utilities import split_equation
from ..symbolics.sympy_interpreter import parse_to_sympy


Expand Down Expand Up @@ -82,10 +81,10 @@ def _format_port_sizes(ports, label):

def _format_local_variables(local_variables):
"""Formats routine's local variables to LaTeX."""
lines = []
for variable in local_variables:
assignment, expression = split_equation(variable)
lines.append(f"&{_format_param_math(assignment)} = {_latex_expression(expression)}")
lines = [
f"&{_format_param_math(symbol)} = {_latex_expression(expression)}"
for symbol, expression in local_variables.items()
]
return _format_section_multi_line("Local variables", lines)


Expand Down
17 changes: 7 additions & 10 deletions src/bartiq/symbolics/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,18 @@ def _split_equation(equation: str) -> tuple[str, str]:

def infer_subresources(routine: Routine, backend):
"""Infer what are the resources of a routine's children."""
expressions = [resource.value for resource in routine.resources.values()]
for variable in routine.local_variables:
_, rhs = _split_equation(variable)
expressions.append(rhs)
expressions = [*[resource.value for resource in routine.resources.values()], *routine.local_variables.values()]

# Any path-prefixed variable (i.e. prefixed by a .-separated path) not
# in subresources, but found in the RHS of an expression in either costs,
# local_variables, or output ports.
subresources = []
for expr in expressions:
vars = _extract_input_variables_from_expression(expr, backend)
subresources = [
var
for expr in expressions
for var in _extract_input_variables_from_expression(expr, backend)
# Only consider variables that are subresources (ones that have a "." in the name).
for var in vars:
if "." in var:
subresources.append(var)
if "." in var
]
return sorted(set(subresources))


Expand Down
4 changes: 2 additions & 2 deletions src/bartiq/verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def _verify_expressions_parsable(routine: Routine, backend: SymbolicBackend) ->
]
local_variable_problems = [
_verify_expression(
backend, local_variable, local_variable.split("=")[1], "local_variable", subroutine.absolute_path()
backend, f"{variable} = {expression}", expression, "local_variable", subroutine.absolute_path()
)
for local_variable in subroutine.local_variables
for variable, expression in subroutine.local_variables.items()
]
port_problems = [
_verify_expression(backend, port, port.size, "port size", subroutine.absolute_path())
Expand Down
10 changes: 5 additions & 5 deletions tests/compilation/data/compile_test_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@
type: other
value: {type: str, value: k}
input_params: [x, y]
local_variables: [i = x + y, j = x - y, k = (i + j) / 2 + (i - j) / 2]
local_variables: {"i": "x + y", "j": "x - y", "k": "(i + j) / 2 + (i - j) / 2"}
b:
name: b
type: null
Expand All @@ -658,7 +658,7 @@
type: other
value: {type: str, value: k}
input_params: [x, y]
local_variables: [i = x + y, j = x - y, k = (i + j) / 2 + (i - j) / 2]
local_variables: {"i": "x + y", "j": "x - y", "k": "(i + j) / 2 + (i - j) / 2"}
resources:
z:
name: z
Expand All @@ -677,7 +677,7 @@
type: other
value: {type: str, value: k}
input_params: [x, y]
local_variables: [i = x + y, j = x - y, k = (i + j) / 2 + (i - j) / 2]
local_variables: {"i": "x + y", "j": "x - y", "k": "(i + j) / 2 + (i - j) / 2"}
b:
name: b
type: null
Expand All @@ -687,7 +687,7 @@
type: other
value: {type: str, value: k}
input_params: [x, y]
local_variables: [i = x + y, j = x - y, k = (i + j) / 2 + (i - j) / 2]
local_variables: {"i": "x + y", "j": "x - y", "k": "(i + j) / 2 + (i - j) / 2"}
resources:
z:
name: z
Expand Down Expand Up @@ -827,7 +827,7 @@
name: z
type: other
value: {type: str, value: 3 * M}
local_variables: [M = 2 * N]
local_variables: {"M": "2 * N"}
- name: root
type: null
ports:
Expand Down
4 changes: 2 additions & 2 deletions tests/compilation/data/evaluate_test_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -929,13 +929,13 @@
- - name: root
type: null
input_params: [L]
local_variables: ['N=L/multiplicity(2,L)']
local_variables: {'N': 'L/multiplicity(2,L)'}
resources:
T_gates: {name: T_gates, type: additive, value: 8*ceiling(log_2(N))}
- [L = 10]
- name: root
type: null
local_variables: ['N=L/multiplicity(2,L)']
local_variables: {'N': 'L/multiplicity(2,L)'}
resources:
T_gates: {name: T_gates, type: additive, value: '32'}
# Linked params through multiple generations
Expand Down
10 changes: 5 additions & 5 deletions tests/compilation/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,25 +62,25 @@ def f_3_optional_inputs(a, b=2, c=3):
DEFINED_EXPRESSION_FUNCTIONS_TEST_DATA = [
# Testing function which can be interpreted symbolically
(
SymbolicFunction.from_str(["a", "b"], ["x = f(a) + f(b)"], BACKEND),
SymbolicFunction.assemble(["a", "b"], {"x": "f(a) + f(b)"}, BACKEND),
{"f": f_1_simple},
{"x": "a + b + 2"},
),
# Testing function which cannot be interpreted symbolically due to having a condition
(
SymbolicFunction.from_str(["a"], ["x = f(a) + a"], BACKEND),
SymbolicFunction.assemble(["a"], {"x": "f(a) + a"}, BACKEND),
{"f": f_2_conditional},
{"x": "a + f(a)"},
),
# Testing function with multiple inputs, some with default values ["x = a**2 + 2*a + 3*b"]
(
SymbolicFunction.from_str(["a", "b"], ["x = f(a, b) + f(a, a, a)"], BACKEND),
SymbolicFunction.assemble(["a", "b"], {"x": "f(a, b) + f(a, a, a)"}, BACKEND),
{"f": f_3_optional_inputs},
{"x": "a ^ 2 + 2*a + 3*b"},
),
# Testing nested calls of a simple function
(
SymbolicFunction.from_str(["a"], ["x = f(f(f(a)))"], BACKEND),
SymbolicFunction.assemble(["a"], {"x": "f(f(f(a)))"}, BACKEND),
{"f": f_1_simple},
{"x": "a + 3"},
),
Expand All @@ -106,7 +106,7 @@ def mad_max(a, b):
"function, functions_map, expected_error",
[
(
SymbolicFunction.from_str(["a", "b"], ["x = max(a, b)"], BACKEND),
SymbolicFunction.assemble(["a", "b"], {"x": "max(a, b)"}, BACKEND),
{"max": mad_max},
"Attempted to redefine built-in function max",
)
Expand Down
Loading

0 comments on commit 1c93f0e

Please sign in to comment.