diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0beab3f..1aed87e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,4 +41,3 @@ repos: - id: conventional-pre-commit stages: [commit-msg] args: [] - diff --git a/docs/tutorials/01_basic_example.ipynb b/docs/tutorials/01_basic_example.ipynb index 870f74e..5a91177 100644 --- a/docs/tutorials/01_basic_example.ipynb +++ b/docs/tutorials/01_basic_example.ipynb @@ -66,13 +66,19 @@ }, { "cell_type": "code", + "execution_count": 1, "id": "6e2460c7-8ec3-4157-8b92-3d8ac5ff300f", +<<<<<<< HEAD "metadata": { "ExecuteTime": { "end_time": "2024-06-24T20:58:18.268334Z", "start_time": "2024-06-24T20:58:18.259199Z" } }, +======= + "metadata": {}, + "outputs": [], +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "source": [ "my_algorithm = {\n", " \"name\": \"my_algorithm\",\n", @@ -82,9 +88,13 @@ " {\"name\": \"out\", \"direction\": \"output\", \"size\": None},\n", " ],\n", "}" +<<<<<<< HEAD ], "outputs": [], "execution_count": 31 +======= + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -104,13 +114,19 @@ }, { "cell_type": "code", + "execution_count": 2, "id": "591b4a88-af8c-41ec-bc07-ca0bbe576c70", +<<<<<<< HEAD "metadata": { "ExecuteTime": { "end_time": "2024-06-24T20:58:18.403637Z", "start_time": "2024-06-24T20:58:18.391328Z" } }, +======= + "metadata": {}, + "outputs": [], +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "source": [ "routine_a = {\n", " \"name\": \"A\",\n", @@ -120,9 +136,13 @@ " {\"name\": \"out\", \"direction\": \"output\", \"size\": \"2*n_a\"},\n", " ],\n", "}" +<<<<<<< HEAD ], "outputs": [], "execution_count": 32 +======= + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -134,26 +154,36 @@ }, { "cell_type": "code", + "execution_count": 3, "id": "76a0ed1a-0acb-489e-b2fc-40ce956b8f49", +<<<<<<< HEAD "metadata": { "ExecuteTime": { "end_time": "2024-06-24T20:58:18.438587Z", "start_time": "2024-06-24T20:58:18.425845Z" } }, +======= + "metadata": {}, + "outputs": [], +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "source": [ "routine_b = {\n", " \"name\": \"B\",\n", " \"type\": None,\n", " \"ports\": [\n", " {\"name\": \"in\", \"direction\": \"input\", \"size\": \"n_b\"},\n", - " # \"y\" will be defined in the next step\n", + " # \"y\" will be defined in the next step\n", " {\"name\": \"out\", \"direction\": \"output\", \"size\": \"n_b + y\"},\n", " ],\n", "}" +<<<<<<< HEAD ], "outputs": [], "execution_count": 33 +======= + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -174,24 +204,39 @@ }, { "cell_type": "code", + "execution_count": 4, "id": "697c5cdb-abb8-4437-97e8-d3a1e4de1ca2", +<<<<<<< HEAD "metadata": { "ExecuteTime": { "end_time": "2024-06-24T20:58:18.530086Z", "start_time": "2024-06-24T20:58:18.521327Z" } }, +======= + "metadata": {}, + "outputs": [], +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "source": [ "# Define T-gate counts for routine a\n", "routine_a[\"input_params\"] = [\"x\"]\n", - "routine_a[\"resources\"] = [{\"name\": \"T_gates\", \"type\": \"additive\", \"value\": \"2*n_a + x\"}]\n", + "routine_a[\"resources\"] = [\n", + " {\"name\": \"T_gates\", \"type\": \"additive\", \"value\": \"2*n_a + x\"}\n", + "]\n", "\n", "# Define T-gate counts for routine b\n", "routine_b[\"input_params\"] = [\"y\"]\n", +<<<<<<< HEAD "routine_b[\"resources\"] = [{\"name\": \"T_gates\", \"type\": \"additive\", \"value\": \"n_b*ceil(log_2(n_b)) * y\"}]" ], "outputs": [], "execution_count": 34 +======= + "routine_b[\"resources\"] = [\n", + " {\"name\": \"T_gates\", \"type\": \"additive\", \"value\": \"n_b*ceil(log_2(n_b)) * y\"}\n", + "]" + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -211,13 +256,19 @@ }, { "cell_type": "code", + "execution_count": 5, "id": "3cc1dcef-a151-401b-a818-3be783aa68f9", +<<<<<<< HEAD "metadata": { "ExecuteTime": { "end_time": "2024-06-24T20:58:18.603781Z", "start_time": "2024-06-24T20:58:18.587847Z" } }, +======= + "metadata": {}, + "outputs": [], +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "source": [ "my_algorithm[\"children\"] = [routine_a, routine_b]\n", "my_algorithm[\"connections\"] = [\n", @@ -227,9 +278,13 @@ "]\n", "my_algorithm[\"input_params\"] = [\"z\"]\n", "my_algorithm[\"linked_params\"] = [{\"source\": \"z\", \"targets\": [\"A.x\", \"B.y\"]}]" +<<<<<<< HEAD ], "outputs": [], "execution_count": 35 +======= + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -252,7 +307,9 @@ }, { "cell_type": "code", + "execution_count": 6, "id": "ac4be2b0-f1be-46db-a42d-9f204727db7b", +<<<<<<< HEAD "metadata": { "ExecuteTime": { "end_time": "2024-06-24T20:58:18.643939Z", @@ -344,6 +401,24 @@ }, { "cell_type": "markdown", +======= + "metadata": {}, + "outputs": [], + "source": [ + "my_algorithm_qref = {'version': 'v1', 'program': my_algorithm}" + ] + }, + { + "cell_type": "markdown", + "id": "5fe2e5a6-0487-44d1-8071-3f692ee845ba", + "metadata": {}, + "source": [ + "Now we can translate our algorithm into a proper `bartiq` routine and see what's the total cost of `my_algorithm`." + ] + }, + { + "cell_type": "markdown", +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "id": "ea57757c-2757-4f81-93ec-b070a3b945fc", "metadata": {}, "source": [ @@ -355,7 +430,7 @@ "id": "5d4c9a76-ab98-4eac-9233-2c5ad62bb770", "metadata": {}, "source": [ - " Below you can find depiction of `my_algorithm`.\n", + "Below you can find depiction of `my_algorithm`.\n", "![title](../images/basic_uncompiled.png)\n", "\n", "We can create `bartiq` `Routine` from `QREF` definition by simply running:" @@ -363,20 +438,29 @@ }, { "cell_type": "code", + "execution_count": 7, "id": "f3bfa36d-b208-4488-abd6-8b020ef6cffc", +<<<<<<< HEAD "metadata": { "ExecuteTime": { "end_time": "2024-06-24T20:58:19.598730Z", "start_time": "2024-06-24T20:58:19.589741Z" } }, +======= + "metadata": {}, + "outputs": [], +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "source": [ "from bartiq.integrations import qref_to_bartiq\n", - "\n", "uncompiled_routine = qref_to_bartiq(my_algorithm_qref)" +<<<<<<< HEAD ], "outputs": [], "execution_count": 38 +======= + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -390,7 +474,9 @@ }, { "cell_type": "code", + "execution_count": 8, "id": "ef6e5a99-9537-4a9a-aea2-ddf4ec28c5df", +<<<<<<< HEAD "metadata": { "ExecuteTime": { "end_time": "2024-06-24T20:58:19.615088Z", @@ -400,6 +486,9 @@ "source": [ "uncompiled_routine.children[\"A\"].resources" ], +======= + "metadata": {}, +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "outputs": [ { "data": { @@ -407,12 +496,22 @@ "{'T_gates': }" ] }, +<<<<<<< HEAD "execution_count": 39, +======= + "execution_count": 8, +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "metadata": {}, "output_type": "execute_result" } ], +<<<<<<< HEAD "execution_count": 39 +======= + "source": [ + "uncompiled_routine.children[\"A\"].resources" + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -426,7 +525,9 @@ }, { "cell_type": "code", + "execution_count": 9, "id": "71860050-361e-478e-b6c8-495c9b0bb6fe", +<<<<<<< HEAD "metadata": { "ExecuteTime": { "end_time": "2024-06-24T20:58:19.641159Z", @@ -436,19 +537,32 @@ "source": [ "uncompiled_routine.ports[\"out\"]" ], +======= + "metadata": {}, +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "outputs": [ { "data": { "text/plain": [ - "Port(my_algorithm.#out, size=None, output)" + "Port(name='out', parent=, direction='output', size=None, meta={})" ] }, +<<<<<<< HEAD "execution_count": 40, +======= + "execution_count": 9, +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "metadata": {}, "output_type": "execute_result" } ], +<<<<<<< HEAD "execution_count": 40 +======= + "source": [ + "uncompiled_routine.ports[\"out\"]" + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -460,7 +574,9 @@ }, { "cell_type": "code", + "execution_count": 10, "id": "da6566d1-f14b-4531-8029-4134143f785f", +<<<<<<< HEAD "metadata": { "ExecuteTime": { "end_time": "2024-06-24T20:58:19.665379Z", @@ -470,6 +586,9 @@ "source": [ "uncompiled_routine.resources" ], +======= + "metadata": {}, +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "outputs": [ { "data": { @@ -477,12 +596,22 @@ "{}" ] }, +<<<<<<< HEAD "execution_count": 41, +======= + "execution_count": 10, +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "metadata": {}, "output_type": "execute_result" } ], +<<<<<<< HEAD "execution_count": 41 +======= + "source": [ + "uncompiled_routine.resources" + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -499,20 +628,29 @@ }, { "cell_type": "code", + "execution_count": 11, "id": "0163c4b5-16a6-4510-9210-dc84f6711c61", +<<<<<<< HEAD "metadata": { "ExecuteTime": { "end_time": "2024-06-24T20:58:19.716932Z", "start_time": "2024-06-24T20:58:19.667445Z" } }, +======= + "metadata": {}, + "outputs": [], +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "source": [ "from bartiq import compile_routine\n", - "\n", "compiled_routine = compile_routine(uncompiled_routine)" +<<<<<<< HEAD ], "outputs": [], "execution_count": 42 +======= + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -524,7 +662,9 @@ }, { "cell_type": "code", + "execution_count": 12, "id": "e017c894-5203-4958-9f0f-f85904a1ad0b", +<<<<<<< HEAD "metadata": { "ExecuteTime": { "end_time": "2024-06-24T20:58:19.733877Z", @@ -536,6 +676,9 @@ "print(\"Output size:\", compiled_routine.ports[\"out\"].size)\n", "print(\"Total T gates:\", compiled_routine.resources[\"T_gates\"].value)" ], +======= + "metadata": {}, +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "outputs": [ { "name": "stdout", @@ -547,7 +690,15 @@ ] } ], +<<<<<<< HEAD "execution_count": 43 +======= + "source": [ + "print(\"T gates for A:\", compiled_routine.children[\"A\"].resources[\"T_gates\"].value)\n", + "print(\"Output size:\", compiled_routine.ports[\"out\"].size)\n", + "print(\"Total T gates:\", compiled_routine.resources[\"T_gates\"].value)" + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -575,7 +726,9 @@ }, { "cell_type": "code", + "execution_count": 13, "id": "4c232bc3-77bb-4bd8-a587-b6591d539c41", +<<<<<<< HEAD "metadata": { "ExecuteTime": { "end_time": "2024-06-24T20:58:19.846940Z", @@ -598,6 +751,9 @@ "\n", "print(\"Total T gates:\", evaluated_routine.resources[\"T_gates\"].value)" ], +======= + "metadata": {}, +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) "outputs": [ { "name": "stdout", @@ -614,7 +770,26 @@ ] } ], +<<<<<<< HEAD "execution_count": 44 +======= + "source": [ + "from bartiq import evaluate\n", + "\n", + "print(\"Different values of n:\")\n", + "for n in range(6, 16, 2):\n", + " assignments = [f\"n={n}\"]\n", + " evaluated_routine = evaluate(compiled_routine, assignments)\n", + " print(f\"n = {n}, total T gates:\", evaluated_routine.resources[\"T_gates\"].value)\n", + "\n", + "z=5\n", + "assignments = [f\"n={n}\", f\"z={z}\"]\n", + "evaluated_routine = evaluate(compiled_routine, assignments)\n", + "print(f\"For n={n}, z={z}\")\n", + "\n", + "print(\"Total T gates:\", evaluated_routine.resources[\"T_gates\"].value)" + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -627,7 +802,6 @@ "- How to construct a simple algorithm in `bartiq`\n", "- How to compile an estimate\n", "- How to evaluate an estimate\n", - "- How to use the `qref` visualization tool to visualize an algorithm\n", "\n", "In the next tutorial we'll cover how to implement a more complex algorithm from a paper." ] diff --git a/docs/tutorials/02_alias_sampling_basic.ipynb b/docs/tutorials/02_alias_sampling_basic.ipynb index 940bf19..01a9745 100644 --- a/docs/tutorials/02_alias_sampling_basic.ipynb +++ b/docs/tutorials/02_alias_sampling_basic.ipynb @@ -199,10 +199,15 @@ " ],\n", " \"resources\": [{\"name\": \"T_gates\", \"type\": \"additive\", \"value\": \"4*mu-4\"}],\n", " \"input_params\": [\"mu\"],\n", +<<<<<<< HEAD "}" ], "outputs": [], "execution_count": 44 +======= + "}\n" + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "code", @@ -232,10 +237,15 @@ " {\"source\": \"In_target_1\", \"target\": \"out_target_1\"},\n", " ],\n", " \"input_params\": [\"X\"],\n", +<<<<<<< HEAD "}" ], "outputs": [], "execution_count": 45 +======= + "}\n" + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -320,6 +330,7 @@ } }, "source": [ +<<<<<<< HEAD "alias_sampling_qref = {\"version\": \"v1\", \"program\": alias_sampling_dict}" ], "outputs": [], @@ -389,6 +400,10 @@ "cell_type": "markdown", "source": "As we expected, the diagram displays five subroutines which are `usp`, `qrom`, `compare`, `had`, `swap`, and their hierarchical connections. Everything appears to be in order! Let's proceed with the compilation.\n", "id": "2ee860fa5c451aad" +======= + "alias_sampling_qref = {'version': 'v1', 'program': alias_sampling_dict}" + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "code", @@ -401,7 +416,6 @@ }, "source": [ "from bartiq.integrations import qref_to_bartiq\n", - "\n", "uncompiled_routine = qref_to_bartiq(alias_sampling_qref)" ], "outputs": [], @@ -418,7 +432,6 @@ }, "source": [ "from bartiq import compile_routine\n", - "\n", "compiled_routine = compile_routine(uncompiled_routine)" ], "outputs": [], @@ -501,7 +514,17 @@ ] } ], +<<<<<<< HEAD "execution_count": 52 +======= + "source": [ + "from bartiq import evaluate\n", + "assignments = {\"L=120\", \"mu=8\"}\n", + "evaluated_routine = evaluate(compiled_routine, assignments)\n", + "for resource in evaluated_routine.resources.values():\n", + " print(f\"{resource.name}: {resource.value}\")" + ] +>>>>>>> parent of 9ab018a (Add example of qref visualization in tutorials.) }, { "cell_type": "markdown", @@ -562,11 +585,9 @@ "source": [ "import math\n", "\n", - "\n", "def big_O(x):\n", " return math.ceil(x)\n", "\n", - "\n", "functions_map = {\"O\": big_O}\n", "evaluated_routine = evaluate(compiled_routine, assignments, functions_map=functions_map)\n", "for resource in evaluated_routine.resources.values():\n", @@ -620,7 +641,6 @@ }, "source": [ "from bartiq.integrations import explore_routine\n", - "\n", "explore_routine(evaluated_routine)" ], "outputs": [ @@ -657,7 +677,6 @@ "source": [ "from bartiq.integrations import routine_to_latex\n", "from IPython.display import Math\n", - "\n", "Math(routine_to_latex(evaluated_routine))" ], "outputs": [ diff --git a/src/bartiq/_routine.py b/src/bartiq/_routine.py index 69e2f3b..d0ac999 100644 --- a/src/bartiq/_routine.py +++ b/src/bartiq/_routine.py @@ -174,11 +174,7 @@ class BaseModel(_BaseModel): which is needed for handling sympy symbols. """ - model_config = { - "arbitrary_types_allowed": True, - "use_enum_values": True, - "extra": "forbid", - } + model_config = {"arbitrary_types_allowed": True, "use_enum_values": True, "extra": "forbid"} class Routine(BaseModel): @@ -283,11 +279,7 @@ def _validate_connections(cls, v, values) -> list[Connection]: ( connection if isinstance(connection, Connection) - else _parse_connection_dict( - connection, - values.data.get("children", {}), - values.data.get("ports", {}), - ) + else _parse_connection_dict(connection, values.data.get("children", {}), values.data.get("ports", {})) ) for connection in v ] diff --git a/src/bartiq/compilation/_compile.py b/src/bartiq/compilation/_compile.py index a4f2c35..f31aae5 100644 --- a/src/bartiq/compilation/_compile.py +++ b/src/bartiq/compilation/_compile.py @@ -101,9 +101,7 @@ def _compile_routine( def _add_function_to_routine( - routine: Routine, - global_functions: Optional[list[str]], - backend: SymbolicBackend[T_expr], + routine: Routine, global_functions: Optional[list[str]], backend: SymbolicBackend[T_expr] ) -> RoutineWithFunction[T_expr]: """Converts each routine to a symbolic function.""" routine_with_functions = RoutineWithFunction.from_routine(routine) @@ -186,9 +184,7 @@ def _pull_in_input_register_size_params( def _pull_in_input_register_size_param( - function: SymbolicFunction[T_expr], - input_port: Port, - backend: SymbolicBackend[T_expr], + function: SymbolicFunction[T_expr], input_port: Port, backend: SymbolicBackend[T_expr] ) -> SymbolicFunction[T_expr]: """Renames a leaf's input register size to the associated high-level register size.""" source_port = get_port_source(input_port) @@ -216,19 +212,13 @@ def _pull_in_input_register_size_param( if is_constant_int(root_input_register_size): assert isinstance(root_input_register_size, (int, str)) new_function = set_input_port_size_to_constant_value( - function, - input_port.absolute_path(exclude_root_name=True), - int(root_input_register_size), - backend, + function, input_port.absolute_path(exclude_root_name=True), int(root_input_register_size), backend ) return new_function # If the root input is of variable size, then we will rename the parameter with the root parameter elif is_single_parameter(root_input_register_size): - root_param = join_paths( - source_port.absolute_path(exclude_root_name=True), - str(root_input_register_size), - ) + root_param = join_paths(source_port.absolute_path(exclude_root_name=True), str(root_input_register_size)) param = str(input_port.size) leaf_param = join_paths(input_port.absolute_path(exclude_root_name=True), param) if is_constant_int(param): @@ -641,9 +631,7 @@ def _infer_missing_register_sizes( assert port_endpoint.size # To satisfy typechecker new_output_symbol = f"#{port.name}" new_outputs[new_output_symbol] = DependentVariable( - new_output_symbol, - backend.as_expression(port_endpoint.size), - backend=backend, + new_output_symbol, backend.as_expression(port_endpoint.size), backend=backend ) return SymbolicFunction(function.inputs, new_outputs) diff --git a/src/bartiq/compilation/_evaluate.py b/src/bartiq/compilation/_evaluate.py index a03d97d..16a19a1 100644 --- a/src/bartiq/compilation/_evaluate.py +++ b/src/bartiq/compilation/_evaluate.py @@ -61,12 +61,7 @@ class _RegisterSizeAssignment: @overload -def evaluate( - routine: Routine, - assignments: list[str], - *, - functions_map: Optional[FunctionsMap] = None, -) -> Routine: +def evaluate(routine: Routine, assignments: list[str], *, functions_map: Optional[FunctionsMap] = None) -> Routine: pass # pragma: no cover @@ -94,12 +89,7 @@ def evaluate(routine, assignments, *, backend=sympy_backend, functions_map=None) Returns: A new estimate with variables assigned to the desired values. """ - return _evaluate( - routine=routine, - assignments=assignments, - backend=backend, - functions_map=functions_map, - ) + return _evaluate(routine=routine, assignments=assignments, backend=backend, functions_map=functions_map) def _evaluate( @@ -197,10 +187,7 @@ def _evaluate_over_assignment( routine_downstream_register_size_assignments = _propagate_forward_constant_output_register_sizes( evaluated_routine ) - for ( - path, - downstream_assignments, - ) in routine_downstream_register_size_assignments.items(): + for path, downstream_assignments in routine_downstream_register_size_assignments.items(): register_sizes[path].extend(downstream_assignments) assert not register_sizes, f"Shouldn't have any more register sizes left to evaluate; found {register_sizes}" diff --git a/src/bartiq/compilation/_symbolic_function.py b/src/bartiq/compilation/_symbolic_function.py index 6bff800..6c5d502 100644 --- a/src/bartiq/compilation/_symbolic_function.py +++ b/src/bartiq/compilation/_symbolic_function.py @@ -125,9 +125,7 @@ def _verify_no_repeated_variable_symbols(variables: list[TVar]) -> None: raise BartiqCompilationError(f"Variable list contains repeated symbol; found {variables}") -def compile_functions( - functions: list[SymbolicFunction[T_expr]], -) -> SymbolicFunction[T_expr]: +def compile_functions(functions: list[SymbolicFunction[T_expr]]) -> SymbolicFunction[T_expr]: """Compiles a series of functions into a single function. The compiled function is the function produced when function inputs and outputs sharing the same name are diff --git a/src/bartiq/precompilation/_core.py b/src/bartiq/precompilation/_core.py index a1694b1..a5188a3 100644 --- a/src/bartiq/precompilation/_core.py +++ b/src/bartiq/precompilation/_core.py @@ -29,9 +29,7 @@ def precompile( - routine: Routine, - backend: SymbolicBackend, - precompilation_stages: Optional[list[PrecompilationStage]] = None, + routine: Routine, backend: SymbolicBackend, precompilation_stages: Optional[list[PrecompilationStage]] = None ) -> Routine: """A precompilation stage that transforms a routine prior to estimate compilation. diff --git a/src/bartiq/precompilation/stages.py b/src/bartiq/precompilation/stages.py index 1cb91df..7bf570e 100644 --- a/src/bartiq/precompilation/stages.py +++ b/src/bartiq/precompilation/stages.py @@ -179,24 +179,13 @@ def add_passthrough_placeholders(self, routine: Routine, _backend: SymbolicBacke # as otherwise serializing this using `exclude_unset` will still consider this field # as unset. This is a prime example why mutability might be problematic in hard to predict ways. # routine.children[new_routine.name] = new_routine # <- this causes problems - routine.children = { - **routine.children, - new_routine.name: new_routine, - } + routine.children = {**routine.children, new_routine.name: new_routine} connections_to_remove.append(i) connections_to_add.append( - Connection( - source=connection.source, - target=new_routine.ports["in_0"], - parent=routine, - ) + Connection(source=connection.source, target=new_routine.ports["in_0"], parent=routine) ) connections_to_add.append( - Connection( - source=new_routine.ports["out_0"], - target=connection.target, - parent=routine, - ) + Connection(source=new_routine.ports["out_0"], target=connection.target, parent=routine) ) self.index += 1 diff --git a/src/bartiq/symbolics/variables.py b/src/bartiq/symbolics/variables.py index 435250e..0cd2ba5 100644 --- a/src/bartiq/symbolics/variables.py +++ b/src/bartiq/symbolics/variables.py @@ -181,9 +181,7 @@ def _evaluate_expression(self) -> T_expr: ) in self.expression_functions.items(): if expression_function_callable: evaluated_expression = self.backend.define_function( - evaluated_expression, - expression_function_name, - expression_function_callable, + evaluated_expression, expression_function_name, expression_function_callable ) return evaluated_expression @@ -265,11 +263,7 @@ def substitute(self, variable: str, expression: str | Number) -> Self: else: new_expression_variables[symbol] = IndependentVariable(symbol) - return replace( - self, - expression=new_expression, - expression_variables=new_expression_variables, - ) + return replace(self, expression=new_expression, expression_variables=new_expression_variables) def substitute_series(self, substitution_map: dict[str, str | Number]) -> Self: """Applies a series of substitutions.""" @@ -290,11 +284,7 @@ def rename_function(self, old_function: str, new_function: str) -> Self: **old_expression_functions, new_function: expression_function, } - return replace( - self, - expression=new_expression, - expression_functions=new_expression_functions, - ) + return replace(self, expression=new_expression, expression_functions=new_expression_functions) else: return self diff --git a/tests/compilation/test_compile.py b/tests/compilation/test_compile.py index b984044..ab4b638 100644 --- a/tests/compilation/test_compile.py +++ b/tests/compilation/test_compile.py @@ -86,18 +86,12 @@ def f_3_optional_inputs(a, b=2, c=3): ] -@pytest.mark.parametrize( - "function, functions_map, expected_output_expressions", - DEFINED_EXPRESSION_FUNCTIONS_TEST_DATA, -) +@pytest.mark.parametrize("function, functions_map, expected_output_expressions", DEFINED_EXPRESSION_FUNCTIONS_TEST_DATA) def test_defined_expression_functions(function, functions_map, expected_output_expressions): new_function = define_expression_functions(function=function, functions_map=functions_map) assert function.inputs == new_function.inputs for output_symbol, output_variable in new_function.outputs.items(): - for ( - expression_function_name, - expression_function_callable, - ) in functions_map.items(): + for expression_function_name, expression_function_callable in functions_map.items(): assert output_variable.expression_functions[expression_function_name] == expression_function_callable new_evaluated_expression = output_variable.evaluated_expression assert expected_output_expressions[output_symbol] == new_evaluated_expression @@ -171,13 +165,7 @@ def test_compiling_correctly_propagates_global_functions(): "b": { "name": "b", "type": "dummy", - "resources": { - "X": { - "name": "X", - "value": "my_f(my_f(1), 4, 5) + 3", - "type": "other", - } - }, + "resources": {"X": {"name": "X", "value": "my_f(my_f(1), 4, 5) + 3", "type": "other"}}, }, }, ), @@ -209,11 +197,7 @@ def test_compiling_correctly_propagates_global_functions(): ), {"b.my_f": f_2_conditional}, [], - [ - (None, "X", "2*N + b.my_f(N) + 3"), - ("a", "X", "2*N"), - ("b", "X", "b.my_f(N) + 3"), - ], + [(None, "X", "2*N + b.my_f(N) + 3"), ("a", "X", "2*N"), ("b", "X", "b.my_f(N) + 3")], ), ( Routine( @@ -283,13 +267,7 @@ def test_compile_can_use_arbitrary_functions(routine, functions_map, global_func "b": { "name": "b", "type": "dummy", - "ports": { - "in_0": { - "name": "in_0", - "direction": "input", - "size": None, - } - }, + "ports": {"in_0": {"name": "in_0", "direction": "input", "size": None}}, } }, "connections": [{"source": "in_0", "target": "b.in_0"}], @@ -344,13 +322,7 @@ def test_compile_can_use_arbitrary_functions(routine, functions_map, global_func "name": "a", "type": "dummy", "input_params": ["M", "N"], - "ports": { - "out_foo": { - "name": "out_foo", - "direction": "output", - "size": "M + N", - } - }, + "ports": {"out_foo": {"name": "out_foo", "direction": "output", "size": "M + N"}}, }, "b": { "name": "b", @@ -383,11 +355,7 @@ def test_compile_can_use_arbitrary_functions(routine, functions_map, global_func "name": "c", "type": "dummy", "ports": { - "out_0": { - "name": "out_0", - "direction": "output", - "size": "2*N", - }, + "out_0": {"name": "out_0", "direction": "output", "size": "2*N"}, "in_0": {"name": "in_0", "direction": "input", "size": "N"}, "in_1": {"name": "in_1", "direction": "input", "size": "N"}, }, diff --git a/tests/compilation/test_core.py b/tests/compilation/test_core.py index 273369d..590a30a 100644 --- a/tests/compilation/test_core.py +++ b/tests/compilation/test_core.py @@ -129,29 +129,13 @@ def test_SymbolicFunction_equality(function_1, function_2, backend): ERRORS_TEST_CASES = [ # Output references unknown variables - ( - ([], ["b = a"]), - BartiqCompilationError, - "Expressions must not contain unknown variables", - ), + (([], ["b = a"]), BartiqCompilationError, "Expressions must not contain unknown variables"), # No duplicate inputs - ( - (["a", "a"], []), - BartiqCompilationError, - "Variable list contains repeated symbol", - ), + ((["a", "a"], []), BartiqCompilationError, "Variable list contains repeated symbol"), # No duplicate outputs - ( - ([], ["a = 0", "a = 1"]), - BartiqCompilationError, - "Variable list contains repeated symbol", - ), + (([], ["a = 0", "a = 1"]), BartiqCompilationError, "Variable list contains repeated symbol"), # Outputs cannot share names with inputs - ( - (["a"], ["a = 0"]), - BartiqCompilationError, - "Outputs must not reuse input symbols", - ), + ((["a"], ["a = 0"]), BartiqCompilationError, "Outputs must not reuse input symbols"), ] @@ -443,10 +427,7 @@ def test_rename_inputs_and_outputs(function, variable_map, expected_results, bac ] -@pytest.mark.parametrize( - "function, variable_map, expected_error", - RENAME_INPUTS_AND_OUTPUTS_ERRORS_TEST_CASES, -) +@pytest.mark.parametrize("function, variable_map, expected_error", RENAME_INPUTS_AND_OUTPUTS_ERRORS_TEST_CASES) def test_rename_inputs_and_outputs_errors(function, variable_map, expected_error, backend): function = SymbolicFunction.from_str(*function, backend) diff --git a/tests/compilation/test_evaluate.py b/tests/compilation/test_evaluate.py index 26282e2..9edb8b7 100644 --- a/tests/compilation/test_evaluate.py +++ b/tests/compilation/test_evaluate.py @@ -42,11 +42,7 @@ def test_evaluate(input_dict, assignments, expected_dict, backend): [ (routine_with_passthrough(), ["N=10"], {"out_0": "10"}), (routine_with_passthrough(a_out_size="N+2"), ["N=10"], {"out_0": "12"}), - ( - routine_with_two_passthroughs(), - ["N=10", "M=7"], - {"out_0": "10", "out_1": "7"}, - ), + (routine_with_two_passthroughs(), ["N=10", "M=7"], {"out_0": "10", "out_1": "7"}), ], ) def test_passthroughs(op, assignments, expected_sizes, backend): @@ -78,10 +74,7 @@ def custom_function(a, b): "X": { "name": "X", "type": "other", - "value": { - "type": "str", - "value": "2*N + a.unknown_fun(1)", - }, + "value": {"type": "str", "value": "2*N + a.unknown_fun(1)"}, } }, "input_params": ["N"], @@ -124,10 +117,7 @@ def custom_function(a, b): "X": { "name": "X", "type": "other", - "value": { - "type": "str", - "value": "a.unknown_fun(1) + 10", - }, + "value": {"type": "str", "value": "a.unknown_fun(1) + 10"}, } }, }, diff --git a/tests/compilation/test_symbolic_function.py b/tests/compilation/test_symbolic_function.py index a58b054..7eb6697 100644 --- a/tests/compilation/test_symbolic_function.py +++ b/tests/compilation/test_symbolic_function.py @@ -54,10 +54,7 @@ def _dummy_resources(cost_strs): (_make_routine(), ([], [])), # Simple case with no register sizes ( - _make_routine( - input_params=["a", "b"], - resources=_dummy_resources(["x = a + b", "y = a - b"]), - ), + _make_routine(input_params=["a", "b"], resources=_dummy_resources(["x = a + b", "y = a - b"])), (["a", "b"], ["x = a + b", "y = a - b"]), ), # No register sizes, but including local parameters @@ -113,31 +110,12 @@ def _dummy_resources(cost_strs): ( _make_routine( ports={ - **_ports_from_reg_sizes( - { - "0": "A", - "1": "A", - "2": "B", - "3": "C", - "4": "B", - "5": "A", - "6": "C", - }, - "in", - ), + **_ports_from_reg_sizes({"0": "A", "1": "A", "2": "B", "3": "C", "4": "B", "5": "A", "6": "C"}, "in"), **_ports_from_reg_sizes({"0": "A + B + 2*C"}, "out"), } ), ( - [ - "#in_0.A", - "#in_1.A", - "#in_2.B", - "#in_3.C", - "#in_4.B", - "#in_5.A", - "#in_6.C", - ], + ["#in_0.A", "#in_1.A", "#in_2.B", "#in_3.C", "#in_4.B", "#in_5.A", "#in_6.C"], ["#out_0 = #in_0.A + #in_2.B + 2*#in_3.C"], ), ), @@ -370,10 +348,7 @@ def test_to_symbolic_function_errors(routine, expected_error, backend): **_ports_from_reg_sizes({"0": None}, "out"), }, ), - ( - ["x", "y", "#in_0.z"], - ["a = x + y", "b = x - y - #in_0.z", "#out_0 = x * y * #in_0.z"], - ), + (["x", "y", "#in_0.z"], ["a = x + y", "b = x - y - #in_0.z", "#out_0 = x * y * #in_0.z"]), _make_routine( input_params=["x", "y"], ports={ @@ -408,10 +383,7 @@ def test_to_symbolic_function_errors(routine, expected_error, backend): ] -@pytest.mark.parametrize( - "routine, function, expected_routine", - UPDATE_ROUTINE_WITH_SYMBOLIC_FUNCTION_TEST_CASES, -) +@pytest.mark.parametrize("routine, function, expected_routine", UPDATE_ROUTINE_WITH_SYMBOLIC_FUNCTION_TEST_CASES) def test_update_routine_with_symbolic_function(routine, function, expected_routine, backend): function = SymbolicFunction.from_str(*function, backend) diff --git a/tests/compilation/test_utilities.py b/tests/compilation/test_utilities.py index e53154e..7c9d34c 100644 --- a/tests/compilation/test_utilities.py +++ b/tests/compilation/test_utilities.py @@ -128,10 +128,7 @@ def test_split_equation(equation, expected_lhs, expected_rhs): # No equals ("foo", "Equations must contain a single equals sign; found foo"), # Too many equals - ( - "foo=bar=baz", - "Equations must contain a single equals sign; found foo=bar=baz", - ), + ("foo=bar=baz", "Equations must contain a single equals sign; found foo=bar=baz"), # Bad LHS ("=a", "Equations must have both a left- and right-hand side; found =a"), # Bad RHS diff --git a/tests/integrations/test_qref_integration.py b/tests/integrations/test_qref_integration.py index 8aaad69..e8cdcbc 100644 --- a/tests/integrations/test_qref_integration.py +++ b/tests/integrations/test_qref_integration.py @@ -121,8 +121,6 @@ def test_converting_qref_v1_object_to_routine_give_correct_output(example_routin assert qref_to_bartiq(example_serialized_qref_v1_object) == example_routine -def test_conversion_from_bartiq_to_qref_raises_an_error_if_version_is_unsupported( - example_routine, -): +def test_conversion_from_bartiq_to_qref_raises_an_error_if_version_is_unsupported(example_routine): with pytest.raises(ValueError): bartiq_to_qref(example_routine, version="v3") diff --git a/tests/precompilation/test_precompile.py b/tests/precompilation/test_precompile.py index b7a6d7c..4fc200c 100644 --- a/tests/precompilation/test_precompile.py +++ b/tests/precompilation/test_precompile.py @@ -732,9 +732,7 @@ def test_precompile_adds_additive_resources(input_dict, precompilation_stages, e def test_precompile_handles_wildcards(input_dict, expected_resource, backend): input_routine = Routine(**input_dict) precompiled_routine = precompile( - input_routine, - precompilation_stages=[unroll_wildcarded_resources], - backend=backend, + input_routine, precompilation_stages=[unroll_wildcarded_resources], backend=backend ) assert precompiled_routine.resources[expected_resource[0]].value == expected_resource[1] diff --git a/tests/routine/test_routine.py b/tests/routine/test_routine.py index b5c8eb6..9629751 100644 --- a/tests/routine/test_routine.py +++ b/tests/routine/test_routine.py @@ -186,10 +186,7 @@ def test_parent_is_visited_after_children_are_visited(self): "a": _dummy_routine_dict("a"), "b": { **_dummy_routine_dict("b"), - "children": { - "c": _dummy_routine_dict("c"), - "d": _dummy_routine_dict("d"), - }, + "children": {"c": _dummy_routine_dict("c"), "d": _dummy_routine_dict("d")}, }, }, ) @@ -223,9 +220,7 @@ def test_linearly_ordered_children_are_always_enumerated_in_topological_order(se assert visited_names == ["d", "c", "a", "b", "root"] - def test_each_chain_of_linearly_connected_children_is_enumerated_in_topological_order( - self, - ): + def test_each_chain_of_linearly_connected_children_is_enumerated_in_topological_order(self): root = Routine( **_dummy_routine_dict("root"), children={name: _dummy_routine_dict(name) for name in ("a", "b", "c", "d")}, @@ -325,11 +320,7 @@ def test_walk_over_routine_with_multiple_connections(self): "ports": { "in_0": {"name": "in_0", "direction": "input", "size": 1}, "in_1": {"name": "in_1", "direction": "input", "size": 1}, - "out_0": { - "name": "out_0", - "direction": "output", - "size": 1, - }, + "out_0": {"name": "out_0", "direction": "output", "size": 1}, }, } }, diff --git a/tests/symbolics/test_sympy_interpreter.py b/tests/symbolics/test_sympy_interpreter.py index d8f1788..419a581 100644 --- a/tests/symbolics/test_sympy_interpreter.py +++ b/tests/symbolics/test_sympy_interpreter.py @@ -218,10 +218,7 @@ def add_routine_path(symbol): # Mmmmmm, three-tiered pi ("Pi * pi * PI", Symbol("Pi") * Symbol("pi") * Pi), # Ignore subscripts - ( - "N_x + N_y + N_z + N", - Symbol("N_x") + Symbol("N_y") + Symbol("N_z") + Symbol("N"), - ), + ("N_x + N_y + N_z + N", Symbol("N_x") + Symbol("N_y") + Symbol("N_z") + Symbol("N")), # Can use all letters of the English and Greek alphabets (with and without path prefixes) as Symbol *make_alphabet_test_cases(use="symbol"), # Can use all letters of the English and Greek alphabets (with and without path prefixes) as functions diff --git a/tests/test_routing.py b/tests/test_routing.py index b88bbb0..ae98717 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -88,11 +88,7 @@ def _nested_routine(): type=None, ports={ "in_0": {"name": "in_0", "direction": "input", "size": "N"}, - "out_0": { - "name": "out_0", - "direction": "output", - "size": "N", - }, + "out_0": {"name": "out_0", "direction": "output", "size": "N"}, }, ), }, @@ -114,11 +110,7 @@ def _nested_routine(): type=None, ports={ "in_0": {"name": "in_0", "direction": "input", "size": "N"}, - "out_0": { - "name": "out_0", - "direction": "output", - "size": "N", - }, + "out_0": {"name": "out_0", "direction": "output", "size": "N"}, }, ), }, diff --git a/tests/utilities.py b/tests/utilities.py index 46a198c..27c2284 100644 --- a/tests/utilities.py +++ b/tests/utilities.py @@ -30,11 +30,7 @@ def routine_with_passthrough(a_out_size="N"): type=None, ports={ "in_0": {"name": "in_0", "direction": "input", "size": "N"}, - "out_0": { - "name": "out_0", - "direction": "output", - "size": f"{a_out_size}", - }, + "out_0": {"name": "out_0", "direction": "output", "size": f"{a_out_size}"}, }, ), "b": Routine(