Skip to content

Commit

Permalink
Decompose protocol must create a context when given None and SimpleQu…
Browse files Browse the repository at this point in the history
…bitManager must add a prefix to its qubits (#6172)
  • Loading branch information
NoureldinYosri authored Jun 30, 2023
1 parent 7b753f6 commit 9849695
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 7 deletions.
18 changes: 12 additions & 6 deletions cirq-core/cirq/ops/qubit_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def qfree(self, qubits: Iterable['cirq.Qid']) -> None:
class _BaseAncillaQid(raw_types.Qid):
id: int
dim: int = 2
prefix: str = ''

def _comparison_key(self) -> int:
return self.id
Expand All @@ -49,39 +50,44 @@ def dimension(self) -> int:

def __repr__(self) -> str:
dim_str = f', dim={self.dim}' if self.dim != 2 else ''
return f"cirq.ops.{type(self).__name__}({self.id}{dim_str})"
prefix_str = f', prefix={self.prefix}' if self.prefix != '' else ''
return f"cirq.ops.{type(self).__name__}({self.id}{dim_str}{prefix_str})"


class CleanQubit(_BaseAncillaQid):
"""An internal qid type that represents a clean ancilla allocation."""

def __str__(self) -> str:
dim_str = f' (d={self.dimension})' if self.dim != 2 else ''
return f"_c({self.id}){dim_str}"
return f"{self.prefix}_c({self.id}){dim_str}"


class BorrowableQubit(_BaseAncillaQid):
"""An internal qid type that represents a dirty ancilla allocation."""

def __str__(self) -> str:
dim_str = f' (d={self.dimension})' if self.dim != 2 else ''
return f"_b({self.id}){dim_str}"
return f"{self.prefix}_b({self.id}){dim_str}"


class SimpleQubitManager(QubitManager):
"""Allocates a new `CleanQubit`/`BorrowableQubit` for every `qalloc`/`qborrow` request."""

def __init__(self):
def __init__(self, prefix: str = ''):
self._clean_id = 0
self._borrow_id = 0
self._prefix = prefix

def qalloc(self, n: int, dim: int = 2) -> List['cirq.Qid']:
self._clean_id += n
return [CleanQubit(i, dim) for i in range(self._clean_id - n, self._clean_id)]
return [CleanQubit(i, dim, self._prefix) for i in range(self._clean_id - n, self._clean_id)]

def qborrow(self, n: int, dim: int = 2) -> List['cirq.Qid']:
self._borrow_id = self._borrow_id + n
return [BorrowableQubit(i, dim) for i in range(self._borrow_id - n, self._borrow_id)]
return [
BorrowableQubit(i, dim, self._prefix)
for i in range(self._borrow_id - n, self._borrow_id)
]

def qfree(self, qubits: Iterable['cirq.Qid']) -> None:
for q in qubits:
Expand Down
10 changes: 10 additions & 0 deletions cirq-core/cirq/protocols/decompose_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import dataclasses
import inspect
from collections import defaultdict
Expand Down Expand Up @@ -49,6 +50,8 @@

DecomposeResult = Union[None, NotImplementedType, 'cirq.OP_TREE']

_CONTEXT_COUNTER = itertools.count() # Use _reset_context_counter() to reset the counter.


@runtime_checkable
class OpDecomposerWithContext(Protocol):
Expand Down Expand Up @@ -299,6 +302,8 @@ def decompose(
"acceptable to keep."
)

if context is None:
context = DecompositionContext(ops.SimpleQubitManager(prefix='_decompose_protocol'))
args = _DecomposeArgs(
context=context,
intercepting_decomposer=intercepting_decomposer,
Expand Down Expand Up @@ -364,6 +369,11 @@ def decompose_once(
TypeError: `val` didn't have a `_decompose_` method (or that method returned
`NotImplemented` or `None`) and `default` wasn't set.
"""
if context is None:
context = DecompositionContext(
ops.SimpleQubitManager(prefix=f'_decompose_protocol_{next(_CONTEXT_COUNTER)}')
)

method = getattr(val, '_decompose_with_context_', None)
decomposed = NotImplemented if method is None else method(*args, **kwargs, context=context)
if decomposed is NotImplemented or None:
Expand Down
42 changes: 41 additions & 1 deletion cirq-core/cirq/protocols/decompose_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Optional
from unittest import mock
import pytest
Expand Down Expand Up @@ -387,7 +388,7 @@ def test_decompose_recursive_dfs(with_context: bool):
circuit = cirq.Circuit(moment)
for val in [gate_op, tagged_op, controlled_op, classically_controlled_op, moment, circuit]:
mock_qm.reset_mock()
_ = cirq.decompose(val)
_ = cirq.decompose(val, context=cirq.DecompositionContext(qubit_manager=mock_qm))
assert mock_qm.method_calls == expected_calls

mock_qm.reset_mock()
Expand All @@ -398,3 +399,42 @@ def test_decompose_recursive_dfs(with_context: bool):
if with_context
else mock_qm.method_calls == expected_calls
)


class G1(cirq.Gate):
def _num_qubits_(self) -> int:
return 1

def _decompose_with_context_(self, qubits, context):
yield cirq.CNOT(qubits[0], context.qubit_manager.qalloc(1)[0])


class G2(cirq.Gate):
def _num_qubits_(self) -> int:
return 1

def _decompose_with_context_(self, qubits, context):
yield G1()(*context.qubit_manager.qalloc(1))


@mock.patch('cirq.protocols.decompose_protocol._CONTEXT_COUNTER', itertools.count())
def test_successive_decompose_once_succeed():
op = G2()(cirq.NamedQubit('q'))
d1 = cirq.decompose_once(op)
d2 = cirq.decompose_once(d1[0])
assert d2 == [
cirq.CNOT(
cirq.ops.CleanQubit(0, prefix='_decompose_protocol_0'),
cirq.ops.CleanQubit(0, prefix='_decompose_protocol_1'),
)
]


def test_decompose_without_context_succeed():
op = G2()(cirq.NamedQubit('q'))
assert cirq.decompose(op, keep=lambda op: op.gate is cirq.CNOT) == [
cirq.CNOT(
cirq.ops.CleanQubit(0, prefix='_decompose_protocol'),
cirq.ops.CleanQubit(1, prefix='_decompose_protocol'),
)
]

0 comments on commit 9849695

Please sign in to comment.