Skip to content

Commit

Permalink
import cirq -> import ops, ...
Browse files Browse the repository at this point in the history
  • Loading branch information
babacry committed Sep 23, 2024
1 parent e90191f commit c6a2d3e
Showing 1 changed file with 63 additions and 56 deletions.
119 changes: 63 additions & 56 deletions cirq-core/cirq/transformers/dynamical_decoupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,47 @@
"""Transformer pass that adds dynamical decoupling operations to a circuit."""

from functools import reduce
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union, TYPE_CHECKING
from itertools import cycle

from cirq import ops, circuits, protocols
from cirq.transformers import transformer_api
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
import cirq
from cirq.protocols import unitary_protocol
from cirq.protocols.has_unitary_protocol import has_unitary
from cirq.protocols.has_stabilizer_effect_protocol import has_stabilizer_effect

import numpy as np

if TYPE_CHECKING:
import cirq

Check warning on line 31 in cirq-core/cirq/transformers/dynamical_decoupling.py

View check run for this annotation

Codecov / codecov/patch

cirq-core/cirq/transformers/dynamical_decoupling.py#L31

Added line #L31 was not covered by tests


def _get_dd_sequence_from_schema_name(schema: str) -> Tuple['cirq.Gate', ...]:
def _get_dd_sequence_from_schema_name(schema: str) -> Tuple[ops.Gate, ...]:
"""Gets dynamical decoupling sequence from a schema name."""
match schema:
case 'DEFAULT':
return (cirq.X, cirq.Y, cirq.X, cirq.Y)
return (ops.X, ops.Y, ops.X, ops.Y)
case 'XX_PAIR':
return (cirq.X, cirq.X)
return (ops.X, ops.X)
case 'X_XINV':
return (cirq.X, cirq.X**-1)
return (ops.X, ops.X**-1)
case 'YY_PAIR':
return (cirq.Y, cirq.Y)
return (ops.Y, ops.Y)
case 'Y_YINV':
return (cirq.Y, cirq.Y**-1)
return (ops.Y, ops.Y**-1)
case _:
raise ValueError('Invalid schema name.')


def _pauli_up_to_global_phase(gate: 'cirq.Gate') -> Union['cirq.Pauli', None]:
for pauli_gate in [cirq.X, cirq.Y, cirq.Z]:
if cirq.equal_up_to_global_phase(gate, pauli_gate):
def _pauli_up_to_global_phase(gate: ops.Gate) -> Union[ops.Pauli, None]:
for pauli_gate in [ops.X, ops.Y, ops.Z]:
if protocols.equal_up_to_global_phase(gate, pauli_gate):
return pauli_gate
return None


def _validate_dd_sequence(dd_sequence: Tuple['cirq.Gate', ...]) -> None:
def _validate_dd_sequence(dd_sequence: Tuple[ops.Gate, ...]) -> None:
"""Validates a given dynamical decoupling sequence.
Args:
Expand All @@ -65,19 +72,19 @@ def _validate_dd_sequence(dd_sequence: Tuple['cirq.Gate', ...]) -> None:
'Dynamical decoupling sequence should only contain gates that are essentially'
' Pauli gates.'
)
matrices = [cirq.unitary(gate) for gate in dd_sequence]
matrices = [unitary_protocol.unitary(gate) for gate in dd_sequence]
product = reduce(np.matmul, matrices)

if not cirq.equal_up_to_global_phase(product, np.eye(2)):
if not protocols.equal_up_to_global_phase(product, np.eye(2)):
raise ValueError(
'Invalid dynamical decoupling sequence. Expect sequence production equals'
f' identity up to a global phase, got {product}.'.replace('\n', ' ')
)


def _parse_dd_sequence(
schema: Union[str, Tuple['cirq.Gate', ...]]
) -> Tuple[Tuple['cirq.Gate', ...], Dict['cirq.Gate', 'cirq.Pauli']]:
schema: Union[str, Tuple[ops.Gate, ...]]
) -> Tuple[Tuple[ops.Gate, ...], Dict[ops.Gate, ops.Pauli]]:
"""Parses and returns dynamical decoupling sequence and its associated pauli map from schema."""
dd_sequence = None
if isinstance(schema, str):
Expand All @@ -87,33 +94,33 @@ def _parse_dd_sequence(
dd_sequence = schema

# Map Gate to Puali gate. This is necessary as dd sequence might contain gates like X^-1.
pauli_map: Dict['cirq.Gate', 'cirq.Pauli'] = {}
pauli_map: Dict[ops.Gate, ops.Pauli] = {}
for gate in dd_sequence:
pauli_gate = _pauli_up_to_global_phase(gate)
if pauli_gate is not None:
pauli_map[gate] = pauli_gate
for gate in [cirq.X, cirq.Y, cirq.Z]:
for gate in [ops.X, ops.Y, ops.Z]:
pauli_map[gate] = gate

return (dd_sequence, pauli_map)


def _is_single_qubit_operation(operation: 'cirq.Operation') -> bool:
def _is_single_qubit_operation(operation: ops.Operation) -> bool:
return len(operation.qubits) == 1


def _is_single_qubit_gate_moment(moment: 'cirq.Moment') -> bool:
return all([_is_single_qubit_operation(op) for op in moment])
def _is_single_qubit_gate_moment(moment: circuits.Moment) -> bool:
return all(_is_single_qubit_operation(op) for op in moment)


def _is_clifford_op(op: 'cirq.Operation') -> bool:
return cirq.has_stabilizer_effect(op) and cirq.has_unitary(op)
def _is_clifford_op(op: ops.Operation) -> bool:
return has_unitary(op) and has_stabilizer_effect(op)


def _calc_busy_moment_range_of_each_qubit(
circuit: 'cirq.FrozenCircuit',
) -> Dict['cirq.Qid', list[int]]:
busy_moment_range_by_qubit: Dict['cirq.Qid', list[int]] = {
circuit: circuits.FrozenCircuit,
) -> Dict[ops.Qid, list[int]]:
busy_moment_range_by_qubit: Dict[ops.Qid, list[int]] = {
q: [len(circuit), -1] for q in circuit.all_qubits()
}
for moment_id, moment in enumerate(circuit):
Expand All @@ -123,71 +130,71 @@ def _calc_busy_moment_range_of_each_qubit(
return busy_moment_range_by_qubit


def _is_insertable_moment(moment: 'cirq.Moment', single_qubit_gate_moments_only: bool) -> bool:
def _is_insertable_moment(moment: circuits.Moment, single_qubit_gate_moments_only: bool) -> bool:
return not single_qubit_gate_moments_only or _is_single_qubit_gate_moment(moment)


def _merge_single_qubit_ops_to_phxz(
q: 'cirq.Qid', ops: Tuple['cirq.Operation', ...]
) -> 'cirq.Operation':
q: ops.Qid, operations: Tuple[ops.Operation, ...]
) -> ops.Operation:
"""Merges [op1, op2, ...] and returns an equivalent op"""
if len(ops) == 1:
return ops[0]
matrices = [cirq.unitary(op) for op in reversed(ops)]
if len(operations) == 1:
return operations[0]
matrices = [unitary_protocol.unitary(op) for op in reversed(operations)]
product = reduce(np.matmul, matrices)
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(product) or cirq.I
gate = single_qubit_decompositions.single_qubit_matrix_to_phxz(product) or ops.I
return gate.on(q)


def _try_merge_single_qubit_ops_of_two_moments(
m1: 'cirq.Moment', m2: 'cirq.Moment'
) -> Tuple['cirq.Moment', ...]:
m1: circuits.Moment, m2: circuits.Moment
) -> Tuple[circuits.Moment, ...]:
"""Merge single qubit ops of 2 moments if possible, returns 2 moments otherwise."""
for q in m1.qubits & m2.qubits:
op1 = m1.operation_at(q)
op2 = m2.operation_at(q)
if any(
not (_is_single_qubit_operation(op) and cirq.has_unitary(op))
not (_is_single_qubit_operation(op) and has_unitary(op))
for op in [op1, op2]
if op is not None
):
return (m1, m2)
ops: set['cirq.Operation'] = set()
merged_ops: set[ops.Operation] = set()
# Merge all operators on q to a single op.
for q in m1.qubits | m2.qubits:
# ops_on_q may contain 1 op or 2 ops.
ops_on_q = [op for op in [m.operation_at(q) for m in [m1, m2]] if op is not None]
ops.add(_merge_single_qubit_ops_to_phxz(q, tuple(ops_on_q)))
return (cirq.Moment(ops),)
merged_ops.add(_merge_single_qubit_ops_to_phxz(q, tuple(ops_on_q)))
return (circuits.Moment(merged_ops),)


def _calc_pulled_through(
moment: 'cirq.Moment', input_pauli_ops: 'cirq.PauliString'
) -> 'cirq.PauliString':
moment: circuits.Moment, input_pauli_ops: ops.PauliString
) -> ops.PauliString:
"""Calculates the pulled_through such that circuit(input_puali_ops, moment.clifford_ops) is
equivalent to circuit(moment.clifford_ops, pulled_through).
"""
clifford_ops_in_moment: list['cirq.Operation'] = [
clifford_ops_in_moment: list[ops.Operation] = [
op for op in moment.operations if _is_clifford_op(op)
]
return input_pauli_ops.after(clifford_ops_in_moment)


def _get_stop_qubits(moment: 'cirq.Moment') -> set['cirq.Qid']:
stop_pulling_through_qubits: set['cirq.Qid'] = set()
def _get_stop_qubits(moment: circuits.Moment) -> set[ops.Qid]:
stop_pulling_through_qubits: set[ops.Qid] = set()
for op in moment:
if (not _is_clifford_op(op) and not _is_single_qubit_operation(op)) or not cirq.has_unitary(
if (not _is_clifford_op(op) and not _is_single_qubit_operation(op)) or not has_unitary(
op
): # multi-qubit clifford op or non-mergable op.
stop_pulling_through_qubits.update(op.qubits)
return stop_pulling_through_qubits


def _need_merge_pulled_through(op_at_q: 'cirq.Operation', is_at_last_busy_moment: bool) -> bool:
def _need_merge_pulled_through(op_at_q: ops.Operation, is_at_last_busy_moment: bool) -> bool:
"""With a pulling through puali gate before op_at_q, need to merge with the
pauli in the conditions below."""
# The op must be mergable and single-qubit
if not (_is_single_qubit_operation(op_at_q) and cirq.has_unitary(op_at_q)):
if not (_is_single_qubit_operation(op_at_q) and has_unitary(op_at_q)):
return False
# Either non-Clifford or at the last busy moment
return is_at_last_busy_moment or not _is_clifford_op(op_at_q)
Expand All @@ -198,7 +205,7 @@ def add_dynamical_decoupling(
circuit: 'cirq.AbstractCircuit',
*,
context: Optional['cirq.TransformerContext'] = None,
schema: Union[str, Tuple['cirq.Gate', ...]] = 'DEFAULT',
schema: Union[str, Tuple[ops.Gate, ...]] = 'DEFAULT',
single_qubit_gate_moments_only: bool = True,
) -> 'cirq.Circuit':
"""Adds dynamical decoupling gate operations to a given circuit.
Expand All @@ -222,14 +229,14 @@ def add_dynamical_decoupling(
busy_moment_range_by_qubit = _calc_busy_moment_range_of_each_qubit(orig_circuit)

# Stores all the moments of the output circuit chronically
transformed_moments: list['cirq.Moment'] = []
transformed_moments: list[circuits.Moment] = []
# A PauliString stores the result of 'pulling' Pauli gates past each operations
# right before the current moment.
pulled_through: 'cirq.PauliString' = cirq.PauliString()
pulled_through: ops.PauliString = ops.PauliString()
# Iterator of gate to be used in dd sequence for each qubit.
dd_iter_by_qubits = {q: cycle(base_dd_sequence) for q in circuit.all_qubits()}

def _update_pulled_through(q: 'cirq.Qid', insert_gate: 'cirq.Gate') -> 'cirq.Operation':
def _update_pulled_through(q: ops.Qid, insert_gate: ops.Gate) -> ops.Operation:
nonlocal pulled_through, pauli_map
pulled_through *= pauli_map[insert_gate].on(q)
return insert_gate.on(q)
Expand All @@ -254,7 +261,7 @@ def _update_pulled_through(q: 'cirq.Qid', insert_gate: 'cirq.Gate') -> 'cirq.Ope
# In detail: stop pulling through for multi-qubit non-Clifford ops or gates without
# unitary representation (e.g., measure gates). If there are remaining pulled through ops,
# insert into a new moment before current moment.
stop_pulling_through_qubits: set['cirq.Qid'] = _get_stop_qubits(moment)
stop_pulling_through_qubits: set[ops.Qid] = _get_stop_qubits(moment)
new_moment_ops = []
for q in stop_pulling_through_qubits:
# Insert the remaining pulled_through
Expand All @@ -270,7 +277,7 @@ def _update_pulled_through(q: 'cirq.Qid', insert_gate: 'cirq.Gate') -> 'cirq.Ope
if busy_moment_range_by_qubit[q][0] < moment_id <= busy_moment_range_by_qubit[q][1]:
new_moment_ops.append(_update_pulled_through(q, next(dd_iter_by_qubits[q])))
moments_to_be_appended = _try_merge_single_qubit_ops_of_two_moments(
transformed_moments.pop(), cirq.Moment(new_moment_ops)
transformed_moments.pop(), circuits.Moment(new_moment_ops)
)
transformed_moments.extend(moments_to_be_appended)

Expand Down Expand Up @@ -304,7 +311,7 @@ def _update_pulled_through(q: 'cirq.Qid', insert_gate: 'cirq.Gate') -> 'cirq.Ope
updated_moment_ops.add(updated_op)

if updated_moment_ops:
updated_moment = cirq.Moment(updated_moment_ops)
updated_moment = circuits.Moment(updated_moment_ops)
transformed_moments.append(updated_moment)

# Step 3, update pulled through.
Expand All @@ -318,8 +325,8 @@ def _update_pulled_through(q: 'cirq.Qid', insert_gate: 'cirq.Gate') -> 'cirq.Ope
if ending_moment_ops:
transformed_moments.extend(
_try_merge_single_qubit_ops_of_two_moments(
transformed_moments.pop(), cirq.Moment(ending_moment_ops)
transformed_moments.pop(), circuits.Moment(ending_moment_ops)
)
)

return cirq.Circuit(transformed_moments)
return circuits.Circuit(transformed_moments)

0 comments on commit c6a2d3e

Please sign in to comment.