Skip to content

Commit

Permalink
add support for _contextual optimisation on local emulator (#303)
Browse files Browse the repository at this point in the history
  • Loading branch information
cqc-melf authored Mar 28, 2024
1 parent c82e1ae commit dc5e336
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 10 deletions.
26 changes: 20 additions & 6 deletions pytket/extensions/qiskit/backends/aer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import itertools
from collections import defaultdict
from dataclasses import dataclass
import json
from logging import warning
from typing import (
Dict,
Expand Down Expand Up @@ -66,6 +67,7 @@
)
from pytket.utils.operators import QubitPauliOperator
from pytket.utils.results import KwargTypes
from pytket.utils import prepare_circuit

from .ibm_utils import _STATUS_MAP, _batch_circuits
from .._metadata import __extension_version__
Expand Down Expand Up @@ -133,7 +135,7 @@ def required_predicates(self) -> List[Predicate]:

@property
def _result_id_type(self) -> _ResultIdTuple:
return (str, int)
return (str, int, int, str)

@property
def backend_info(self) -> BackendInfo:
Expand Down Expand Up @@ -243,6 +245,8 @@ def process_circuits(
See :py:meth:`pytket.backends.Backend.process_circuits`.
Supported kwargs: `seed`, `postprocess`.
"""
postprocess = kwargs.get("postprocess", False)

circuits = list(circuits)
n_shots_list = Backend._get_n_shots_as_list(
n_shots,
Expand All @@ -268,14 +272,22 @@ def process_circuits(
replace_implicit_swaps = self.supports_state or self.supports_unitary

for (n_shots, batch), indices in zip(circuit_batches, batch_order):
qcs = []
qcs, ppcirc_strs, tkc_qubits_count = [], [], []
for tkc in batch:
qc = tk_to_qiskit(tkc, replace_implicit_swaps)
if postprocess:
c0, ppcirc = prepare_circuit(tkc, allow_classical=False)
ppcirc_rep = ppcirc.to_dict()
else:
c0, ppcirc_rep = tkc, None

qc = tk_to_qiskit(c0, replace_implicit_swaps)
if self.supports_state:
qc.save_state()
elif self.supports_unitary:
qc.save_unitary()
qcs.append(qc)
tkc_qubits_count.append(c0.n_qubits)
ppcirc_strs.append(json.dumps(ppcirc_rep))

if self._needs_transpile:
qcs = transpile(qcs, self._qiskit_backend)
Expand All @@ -291,7 +303,7 @@ def process_circuits(
seed += 1
jobid = job.job_id()
for i, ind in enumerate(indices):
handle = ResultHandle(jobid, i)
handle = ResultHandle(jobid, i, tkc_qubits_count[i], ppcirc_strs[i])
handle_list[ind] = handle
self._cache[handle] = {"job": job}
return cast(List[ResultHandle], handle_list)
Expand All @@ -312,7 +324,7 @@ def get_result(self, handle: ResultHandle, **kwargs: KwargTypes) -> BackendResul
try:
return super().get_result(handle)
except CircuitNotRunError:
jobid, _ = handle
jobid, _, qubit_n, ppc = handle
try:
job: "AerJob" = self._cache[handle]["job"]
except KeyError:
Expand All @@ -321,7 +333,9 @@ def get_result(self, handle: ResultHandle, **kwargs: KwargTypes) -> BackendResul
res = job.result()
backresults = qiskit_result_to_backendresult(res)
for circ_index, backres in enumerate(backresults):
self._cache[ResultHandle(jobid, circ_index)]["result"] = backres
self._cache[ResultHandle(jobid, circ_index, qubit_n, ppc)][
"result"
] = backres

return cast(BackendResult, self._cache[handle]["result"])

Expand Down
3 changes: 2 additions & 1 deletion pytket/extensions/qiskit/backends/ibmq_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class IBMQEmulatorBackend(Backend):

_supports_shots = False
_supports_counts = True
_supports_contextual_optimisation = False
_supports_contextual_optimisation = True
_persistent_handles = False
_supports_expectation = False

Expand Down Expand Up @@ -116,6 +116,7 @@ def process_circuits(
See :py:meth:`pytket.backends.Backend.process_circuits`.
Supported kwargs: `seed`, `postprocess`.
"""

if valid_check:
self._ibmq._check_all_circuits(circuits)
return self._aer.process_circuits(
Expand Down
22 changes: 19 additions & 3 deletions tests/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def test_aer_result_handle() -> None:

handles = b.process_circuits([c, c.copy()], n_shots=2)

ids, indices = zip(*(han for han in handles))
ids, indices, _, _ = zip(*(han for han in handles))

assert all(isinstance(idval, str) for idval in ids)
assert indices == (0, 1)
Expand All @@ -725,7 +725,7 @@ def test_aer_result_handle() -> None:
errorinfo.value
)

wronghandle = ResultHandle("asdf", 3)
wronghandle = ResultHandle("asdf", 3, 0, "jsonstr")

with pytest.raises(CircuitNotRunError) as errorinfoCirc:
_ = b.get_result(wronghandle)
Expand Down Expand Up @@ -1114,6 +1114,23 @@ def test_postprocess() -> None:
b.cancel(h)


@pytest.mark.flaky(reruns=3, reruns_delay=10)
@pytest.mark.skipif(skip_remote_tests, reason=REASON)
def test_postprocess_emu(brisbane_emulator_backend: IBMQEmulatorBackend) -> None:
assert brisbane_emulator_backend.supports_contextual_optimisation
c = Circuit(2, 2)
c.X(0).X(1).measure_all()
c = brisbane_emulator_backend.get_compiled_circuit(c)
h = brisbane_emulator_backend.process_circuit(c, n_shots=10, postprocess=True)
ppcirc = Circuit.from_dict(json.loads(cast(str, h[3])))
ppcmds = ppcirc.get_commands()
assert len(ppcmds) > 0
assert all(ppcmd.op.type == OpType.ClassicalTransform for ppcmd in ppcmds)
r = brisbane_emulator_backend.get_result(h)
counts = r.get_counts()
assert sum(counts.values()) == 10


@pytest.mark.skipif(skip_remote_tests, reason=REASON)
def test_available_devices(ibm_provider: IBMProvider) -> None:
backend_info_list = IBMQBackend.available_devices(instance="ibm-q/open/main")
Expand Down Expand Up @@ -1416,7 +1433,6 @@ def test_ibmq_local_emulator(
brisbane_emulator_backend: IBMQEmulatorBackend,
) -> None:
b = brisbane_emulator_backend
assert not b.supports_contextual_optimisation
circ = Circuit(2).H(0).CX(0, 1).measure_all()
circ1 = b.get_compiled_circuit(circ)
h = b.process_circuit(circ1, n_shots=100)
Expand Down

0 comments on commit dc5e336

Please sign in to comment.