From c6a2d3e45c4b62cade8e9d090f857361e612a7d0 Mon Sep 17 00:00:00 2001 From: Renyi Chen Date: Mon, 23 Sep 2024 13:03:50 -0700 Subject: [PATCH] import cirq -> import ops, ... --- .../cirq/transformers/dynamical_decoupling.py | 119 +++++++++--------- 1 file changed, 63 insertions(+), 56 deletions(-) diff --git a/cirq-core/cirq/transformers/dynamical_decoupling.py b/cirq-core/cirq/transformers/dynamical_decoupling.py index 728eeca4ca6..31e7704f257 100644 --- a/cirq-core/cirq/transformers/dynamical_decoupling.py +++ b/cirq-core/cirq/transformers/dynamical_decoupling.py @@ -15,40 +15,47 @@ """Transformer pass that adds dynamical decoupling operations to a circuit.""" from functools import reduce -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union, TYPE_CHECKING from itertools import cycle +from cirq import ops, circuits, protocols from cirq.transformers import transformer_api from cirq.transformers.analytical_decompositions import single_qubit_decompositions -import cirq +from cirq.protocols import unitary_protocol +from cirq.protocols.has_unitary_protocol import has_unitary +from cirq.protocols.has_stabilizer_effect_protocol import has_stabilizer_effect + import numpy as np +if TYPE_CHECKING: + import cirq + -def _get_dd_sequence_from_schema_name(schema: str) -> Tuple['cirq.Gate', ...]: +def _get_dd_sequence_from_schema_name(schema: str) -> Tuple[ops.Gate, ...]: """Gets dynamical decoupling sequence from a schema name.""" match schema: case 'DEFAULT': - return (cirq.X, cirq.Y, cirq.X, cirq.Y) + return (ops.X, ops.Y, ops.X, ops.Y) case 'XX_PAIR': - return (cirq.X, cirq.X) + return (ops.X, ops.X) case 'X_XINV': - return (cirq.X, cirq.X**-1) + return (ops.X, ops.X**-1) case 'YY_PAIR': - return (cirq.Y, cirq.Y) + return (ops.Y, ops.Y) case 'Y_YINV': - return (cirq.Y, cirq.Y**-1) + return (ops.Y, ops.Y**-1) case _: raise ValueError('Invalid schema name.') -def _pauli_up_to_global_phase(gate: 'cirq.Gate') -> Union['cirq.Pauli', None]: - for pauli_gate in [cirq.X, cirq.Y, cirq.Z]: - if cirq.equal_up_to_global_phase(gate, pauli_gate): +def _pauli_up_to_global_phase(gate: ops.Gate) -> Union[ops.Pauli, None]: + for pauli_gate in [ops.X, ops.Y, ops.Z]: + if protocols.equal_up_to_global_phase(gate, pauli_gate): return pauli_gate return None -def _validate_dd_sequence(dd_sequence: Tuple['cirq.Gate', ...]) -> None: +def _validate_dd_sequence(dd_sequence: Tuple[ops.Gate, ...]) -> None: """Validates a given dynamical decoupling sequence. Args: @@ -65,10 +72,10 @@ def _validate_dd_sequence(dd_sequence: Tuple['cirq.Gate', ...]) -> None: 'Dynamical decoupling sequence should only contain gates that are essentially' ' Pauli gates.' ) - matrices = [cirq.unitary(gate) for gate in dd_sequence] + matrices = [unitary_protocol.unitary(gate) for gate in dd_sequence] product = reduce(np.matmul, matrices) - if not cirq.equal_up_to_global_phase(product, np.eye(2)): + if not protocols.equal_up_to_global_phase(product, np.eye(2)): raise ValueError( 'Invalid dynamical decoupling sequence. Expect sequence production equals' f' identity up to a global phase, got {product}.'.replace('\n', ' ') @@ -76,8 +83,8 @@ def _validate_dd_sequence(dd_sequence: Tuple['cirq.Gate', ...]) -> None: def _parse_dd_sequence( - schema: Union[str, Tuple['cirq.Gate', ...]] -) -> Tuple[Tuple['cirq.Gate', ...], Dict['cirq.Gate', 'cirq.Pauli']]: + schema: Union[str, Tuple[ops.Gate, ...]] +) -> Tuple[Tuple[ops.Gate, ...], Dict[ops.Gate, ops.Pauli]]: """Parses and returns dynamical decoupling sequence and its associated pauli map from schema.""" dd_sequence = None if isinstance(schema, str): @@ -87,33 +94,33 @@ def _parse_dd_sequence( dd_sequence = schema # Map Gate to Puali gate. This is necessary as dd sequence might contain gates like X^-1. - pauli_map: Dict['cirq.Gate', 'cirq.Pauli'] = {} + pauli_map: Dict[ops.Gate, ops.Pauli] = {} for gate in dd_sequence: pauli_gate = _pauli_up_to_global_phase(gate) if pauli_gate is not None: pauli_map[gate] = pauli_gate - for gate in [cirq.X, cirq.Y, cirq.Z]: + for gate in [ops.X, ops.Y, ops.Z]: pauli_map[gate] = gate return (dd_sequence, pauli_map) -def _is_single_qubit_operation(operation: 'cirq.Operation') -> bool: +def _is_single_qubit_operation(operation: ops.Operation) -> bool: return len(operation.qubits) == 1 -def _is_single_qubit_gate_moment(moment: 'cirq.Moment') -> bool: - return all([_is_single_qubit_operation(op) for op in moment]) +def _is_single_qubit_gate_moment(moment: circuits.Moment) -> bool: + return all(_is_single_qubit_operation(op) for op in moment) -def _is_clifford_op(op: 'cirq.Operation') -> bool: - return cirq.has_stabilizer_effect(op) and cirq.has_unitary(op) +def _is_clifford_op(op: ops.Operation) -> bool: + return has_unitary(op) and has_stabilizer_effect(op) def _calc_busy_moment_range_of_each_qubit( - circuit: 'cirq.FrozenCircuit', -) -> Dict['cirq.Qid', list[int]]: - busy_moment_range_by_qubit: Dict['cirq.Qid', list[int]] = { + circuit: circuits.FrozenCircuit, +) -> Dict[ops.Qid, list[int]]: + busy_moment_range_by_qubit: Dict[ops.Qid, list[int]] = { q: [len(circuit), -1] for q in circuit.all_qubits() } for moment_id, moment in enumerate(circuit): @@ -123,71 +130,71 @@ def _calc_busy_moment_range_of_each_qubit( return busy_moment_range_by_qubit -def _is_insertable_moment(moment: 'cirq.Moment', single_qubit_gate_moments_only: bool) -> bool: +def _is_insertable_moment(moment: circuits.Moment, single_qubit_gate_moments_only: bool) -> bool: return not single_qubit_gate_moments_only or _is_single_qubit_gate_moment(moment) def _merge_single_qubit_ops_to_phxz( - q: 'cirq.Qid', ops: Tuple['cirq.Operation', ...] -) -> 'cirq.Operation': + q: ops.Qid, operations: Tuple[ops.Operation, ...] +) -> ops.Operation: """Merges [op1, op2, ...] and returns an equivalent op""" - if len(ops) == 1: - return ops[0] - matrices = [cirq.unitary(op) for op in reversed(ops)] + if len(operations) == 1: + return operations[0] + matrices = [unitary_protocol.unitary(op) for op in reversed(operations)] product = reduce(np.matmul, matrices) - gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(product) or cirq.I + gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(product) or ops.I return gate.on(q) def _try_merge_single_qubit_ops_of_two_moments( - m1: 'cirq.Moment', m2: 'cirq.Moment' -) -> Tuple['cirq.Moment', ...]: + m1: circuits.Moment, m2: circuits.Moment +) -> Tuple[circuits.Moment, ...]: """Merge single qubit ops of 2 moments if possible, returns 2 moments otherwise.""" for q in m1.qubits & m2.qubits: op1 = m1.operation_at(q) op2 = m2.operation_at(q) if any( - not (_is_single_qubit_operation(op) and cirq.has_unitary(op)) + not (_is_single_qubit_operation(op) and has_unitary(op)) for op in [op1, op2] if op is not None ): return (m1, m2) - ops: set['cirq.Operation'] = set() + merged_ops: set[ops.Operation] = set() # Merge all operators on q to a single op. for q in m1.qubits | m2.qubits: # ops_on_q may contain 1 op or 2 ops. ops_on_q = [op for op in [m.operation_at(q) for m in [m1, m2]] if op is not None] - ops.add(_merge_single_qubit_ops_to_phxz(q, tuple(ops_on_q))) - return (cirq.Moment(ops),) + merged_ops.add(_merge_single_qubit_ops_to_phxz(q, tuple(ops_on_q))) + return (circuits.Moment(merged_ops),) def _calc_pulled_through( - moment: 'cirq.Moment', input_pauli_ops: 'cirq.PauliString' -) -> 'cirq.PauliString': + moment: circuits.Moment, input_pauli_ops: ops.PauliString +) -> ops.PauliString: """Calculates the pulled_through such that circuit(input_puali_ops, moment.clifford_ops) is equivalent to circuit(moment.clifford_ops, pulled_through). """ - clifford_ops_in_moment: list['cirq.Operation'] = [ + clifford_ops_in_moment: list[ops.Operation] = [ op for op in moment.operations if _is_clifford_op(op) ] return input_pauli_ops.after(clifford_ops_in_moment) -def _get_stop_qubits(moment: 'cirq.Moment') -> set['cirq.Qid']: - stop_pulling_through_qubits: set['cirq.Qid'] = set() +def _get_stop_qubits(moment: circuits.Moment) -> set[ops.Qid]: + stop_pulling_through_qubits: set[ops.Qid] = set() for op in moment: - if (not _is_clifford_op(op) and not _is_single_qubit_operation(op)) or not cirq.has_unitary( + if (not _is_clifford_op(op) and not _is_single_qubit_operation(op)) or not has_unitary( op ): # multi-qubit clifford op or non-mergable op. stop_pulling_through_qubits.update(op.qubits) return stop_pulling_through_qubits -def _need_merge_pulled_through(op_at_q: 'cirq.Operation', is_at_last_busy_moment: bool) -> bool: +def _need_merge_pulled_through(op_at_q: ops.Operation, is_at_last_busy_moment: bool) -> bool: """With a pulling through puali gate before op_at_q, need to merge with the pauli in the conditions below.""" # The op must be mergable and single-qubit - if not (_is_single_qubit_operation(op_at_q) and cirq.has_unitary(op_at_q)): + if not (_is_single_qubit_operation(op_at_q) and has_unitary(op_at_q)): return False # Either non-Clifford or at the last busy moment return is_at_last_busy_moment or not _is_clifford_op(op_at_q) @@ -198,7 +205,7 @@ def add_dynamical_decoupling( circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None, - schema: Union[str, Tuple['cirq.Gate', ...]] = 'DEFAULT', + schema: Union[str, Tuple[ops.Gate, ...]] = 'DEFAULT', single_qubit_gate_moments_only: bool = True, ) -> 'cirq.Circuit': """Adds dynamical decoupling gate operations to a given circuit. @@ -222,14 +229,14 @@ def add_dynamical_decoupling( busy_moment_range_by_qubit = _calc_busy_moment_range_of_each_qubit(orig_circuit) # Stores all the moments of the output circuit chronically - transformed_moments: list['cirq.Moment'] = [] + transformed_moments: list[circuits.Moment] = [] # A PauliString stores the result of 'pulling' Pauli gates past each operations # right before the current moment. - pulled_through: 'cirq.PauliString' = cirq.PauliString() + pulled_through: ops.PauliString = ops.PauliString() # Iterator of gate to be used in dd sequence for each qubit. dd_iter_by_qubits = {q: cycle(base_dd_sequence) for q in circuit.all_qubits()} - def _update_pulled_through(q: 'cirq.Qid', insert_gate: 'cirq.Gate') -> 'cirq.Operation': + def _update_pulled_through(q: ops.Qid, insert_gate: ops.Gate) -> ops.Operation: nonlocal pulled_through, pauli_map pulled_through *= pauli_map[insert_gate].on(q) return insert_gate.on(q) @@ -254,7 +261,7 @@ def _update_pulled_through(q: 'cirq.Qid', insert_gate: 'cirq.Gate') -> 'cirq.Ope # In detail: stop pulling through for multi-qubit non-Clifford ops or gates without # unitary representation (e.g., measure gates). If there are remaining pulled through ops, # insert into a new moment before current moment. - stop_pulling_through_qubits: set['cirq.Qid'] = _get_stop_qubits(moment) + stop_pulling_through_qubits: set[ops.Qid] = _get_stop_qubits(moment) new_moment_ops = [] for q in stop_pulling_through_qubits: # Insert the remaining pulled_through @@ -270,7 +277,7 @@ def _update_pulled_through(q: 'cirq.Qid', insert_gate: 'cirq.Gate') -> 'cirq.Ope if busy_moment_range_by_qubit[q][0] < moment_id <= busy_moment_range_by_qubit[q][1]: new_moment_ops.append(_update_pulled_through(q, next(dd_iter_by_qubits[q]))) moments_to_be_appended = _try_merge_single_qubit_ops_of_two_moments( - transformed_moments.pop(), cirq.Moment(new_moment_ops) + transformed_moments.pop(), circuits.Moment(new_moment_ops) ) transformed_moments.extend(moments_to_be_appended) @@ -304,7 +311,7 @@ def _update_pulled_through(q: 'cirq.Qid', insert_gate: 'cirq.Gate') -> 'cirq.Ope updated_moment_ops.add(updated_op) if updated_moment_ops: - updated_moment = cirq.Moment(updated_moment_ops) + updated_moment = circuits.Moment(updated_moment_ops) transformed_moments.append(updated_moment) # Step 3, update pulled through. @@ -318,8 +325,8 @@ def _update_pulled_through(q: 'cirq.Qid', insert_gate: 'cirq.Gate') -> 'cirq.Ope if ending_moment_ops: transformed_moments.extend( _try_merge_single_qubit_ops_of_two_moments( - transformed_moments.pop(), cirq.Moment(ending_moment_ops) + transformed_moments.pop(), circuits.Moment(ending_moment_ops) ) ) - return cirq.Circuit(transformed_moments) + return circuits.Circuit(transformed_moments)