Skip to content

Commit

Permalink
Add dynamic_one_shot tensorflow support + expanded testing (#5973)
Browse files Browse the repository at this point in the history
**Context:**
As name says. Gradient workflows no longer raise errors after the merge
of #5791 , but their correctness is yet to be verified.

**Description of the Change:**
* Updated casting rules in `dynamic_one_shot`'s processing function for
tensorflow.
* For the changes to be fully integrated, the way that the interface is
passed around when calling a QNode needed to be changed, so the
following changes were made:
* `QNode` has updated behaviour for how `mcm_config` is used during
execution. In `QNode._execution_component`, a copy of
`self.execute_kwargs["mcm_config"]` is the source of truth, and in
`qml.execute`, `config.mcm_config` is the source of truth.
* Added a private `pad-invalid-samples` `postselect_mode`. The
`postselect_mode` is switched to this automatically in `qml.execute` if
executing with jax and shots and `postselect_mode == "hw-like"`. This
way we standardize how the MCM transforms determine if jax is being
used.
  * Updates to `capture` module to accommodate the above changes.

**Benefits:**
* `dynamic_one_shot` doesn't cast to interfaces inside the ML boundary
* `dynamic_one_shot` works with tensorflow
* Expanded tests

**Possible Drawbacks:**

**Related GitHub Issues:**
Fixes #5736, #5710 

Duplicate of #5861 which was closed due to release branch merge stuff.

---------

Co-authored-by: Jay Soni <[email protected]>
Co-authored-by: Astral Cai <[email protected]>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Yushao Chen (Jerry) <[email protected]>
Co-authored-by: Christina Lee <[email protected]>
Co-authored-by: Thomas R. Bromley <[email protected]>
Co-authored-by: soranjh <[email protected]>
Co-authored-by: Pietropaolo Frisoni <[email protected]>
Co-authored-by: Ahmed Darwish <[email protected]>
Co-authored-by: Utkarsh <[email protected]>
Co-authored-by: David Wierichs <[email protected]>
Co-authored-by: Christina Lee <[email protected]>
Co-authored-by: Mikhail Andrenkov <[email protected]>
Co-authored-by: Diego <[email protected]>
Co-authored-by: Josh Izaac <[email protected]>
Co-authored-by: Diego <[email protected]>
Co-authored-by: Vincent Michaud-Rioux <[email protected]>
Co-authored-by: lillian542 <[email protected]>
Co-authored-by: Jack Brown <[email protected]>
Co-authored-by: Paul Finlay <[email protected]>
Co-authored-by: David Ittah <[email protected]>
Co-authored-by: Cristian Emiliano Godinez Ramirez <[email protected]>
Co-authored-by: Vincent Michaud-Rioux <[email protected]>
  • Loading branch information
1 parent d7c984e commit eb7c349
Show file tree
Hide file tree
Showing 14 changed files with 493 additions and 355 deletions.
2 changes: 1 addition & 1 deletion doc/releases/changelog-0.37.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -929,4 +929,4 @@ Kenya Sakka,
Jay Soni,
Kazuki Tsuoka,
Haochen Paul Wang,
David Wierichs.
David Wierichs.
12 changes: 9 additions & 3 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@

<h3>New features since last release</h3>

* A new method `process_density_matrix` has been added to the `ProbabilityMP` and `DensityMatrixMP`
classes, allowing for more efficient handling of quantum density matrices, particularly with batch
processing support. This method simplifies the calculation of probabilities from quantum states
represented as density matrices.
[(#5830)](https://github.com/PennyLaneAI/pennylane/pull/5830)

* Resolved the bug in `qml.ThermalRelaxationError` where there was a typo from `tq` to `tg`.
[(#5988)](https://github.com/PennyLaneAI/pennylane/issues/5988)

* A new method `process_density_matrix` has been added to the `ProbabilityMP` and `DensityMatrixMP` classes, allowing for more efficient handling of quantum density matrices, particularly with batch processing support. This method simplifies the calculation of probabilities from quantum states represented as density matrices.
[(#5830)](https://github.com/PennyLaneAI/pennylane/pull/5830)

* The `qml.PrepSelPrep` template is added. The template implements a block-encoding of a linear
combination of unitaries.
[(#5756)](https://github.com/PennyLaneAI/pennylane/pull/5756)
Expand Down Expand Up @@ -57,6 +60,9 @@
* `QuantumScript.hash` is now cached, leading to performance improvements.
[(#5919)](https://github.com/PennyLaneAI/pennylane/pull/5919)

* `qml.dynamic_one_shot` now supports circuits using the `"tensorflow"` interface.
[(#5973)](https://github.com/PennyLaneAI/pennylane/pull/5973)

* The representation for `Wires` has now changed to be more copy-paste friendly.
[(#5958)](https://github.com/PennyLaneAI/pennylane/pull/5958)

Expand Down
5 changes: 1 addition & 4 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,10 +532,7 @@ def preprocess(

transform_program.add_transform(validate_device_wires, self.wires, name=self.name)
transform_program.add_transform(
mid_circuit_measurements,
device=self,
mcm_config=config.mcm_config,
interface=config.interface,
mid_circuit_measurements, device=self, mcm_config=config.mcm_config
)
transform_program.add_transform(
decompose,
Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/execution_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __post_init__(self):
None,
):
raise ValueError(f"Invalid mid-circuit measurements method '{self.mcm_method}'.")
if self.postselect_mode not in ("hw-like", "fill-shots", None):
if self.postselect_mode not in ("hw-like", "fill-shots", "pad-invalid-samples", None):
raise ValueError(f"Invalid postselection mode '{self.postselect_mode}'.")


Expand Down
5 changes: 2 additions & 3 deletions pennylane/devices/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,22 +151,21 @@ def mid_circuit_measurements(
tape: qml.tape.QuantumTape,
device,
mcm_config=MCMConfig(),
interface=None,
**kwargs, # pylint: disable=unused-argument
) -> tuple[QuantumTapeBatch, PostprocessingFn]:
"""Provide the transform to handle mid-circuit measurements.
If the tape or device uses finite-shot, use the native implementation (i.e. no transform),
and use the ``qml.defer_measurements`` transform otherwise.
"""

if isinstance(mcm_config, dict):
mcm_config = MCMConfig(**mcm_config)
mcm_method = mcm_config.mcm_method
if mcm_method is None:
mcm_method = "one-shot" if tape.shots else "deferred"

if mcm_method == "one-shot":
return qml.dynamic_one_shot(tape, interface=interface)
return qml.dynamic_one_shot(tape, postselect_mode=mcm_config.postselect_mode)
if mcm_method == "tree-traversal":
return (tape,), null_postprocessing
return qml.defer_measurements(tape, device=device)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/devices/qubit/apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def apply_mid_measure(
axis = wire.toarray()[0]
slices = [slice(None)] * qml.math.ndim(state)
slices[axis] = 0
prob0 = qml.math.norm(state[tuple(slices)]) ** 2
prob0 = qml.math.real(qml.math.norm(state[tuple(slices)])) ** 2

if prng_key is not None:
# pylint: disable=import-outside-toplevel
Expand Down
1 change: 1 addition & 0 deletions pennylane/math/single_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def _take_autograd(tensor, indices, axis=None):
ar.autoray._SUBMODULE_ALIASES["tensorflow", "arctan"] = "tensorflow.math"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "arctan2"] = "tensorflow.math"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "mod"] = "tensorflow.math"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "logical_and"] = "tensorflow.math"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "kron"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "moveaxis"] = "tensorflow.experimental.numpy"
ar.autoray._SUBMODULE_ALIASES["tensorflow", "sinc"] = "tensorflow.experimental.numpy"
Expand Down
2 changes: 1 addition & 1 deletion pennylane/measurements/mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def __mul__(self, other):
return self._transform_bin_op(lambda a, b: a * b, other)

def __rmul__(self, other):
return self._apply(lambda v: other * v)
return self._apply(lambda v: other * qml.math.cast_like(v, other))

def __truediv__(self, other):
return self._transform_bin_op(lambda a, b: a / b, other)
Expand Down
96 changes: 67 additions & 29 deletions pennylane/transforms/dynamic_one_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def func(x, y):

aux_tapes = [init_auxiliary_tape(t) for t in tapes]

interface = kwargs.get("interface", None)
postselect_mode = kwargs.get("postselect_mode", None)

def reshape_data(array):
return qml.math.squeeze(qml.math.vstack(array))
Expand Down Expand Up @@ -161,7 +161,9 @@ def processing_fn(results, has_partitioned_shots=None, batched_results=None):
results = [
reshape_data(tuple(res[i] for res in results)) for i, _ in enumerate(results[0])
]
return parse_native_mid_circuit_measurements(tape, aux_tapes, results, interface=interface)
return parse_native_mid_circuit_measurements(
tape, aux_tapes, results, postselect_mode=postselect_mode
)

return aux_tapes, processing_fn

Expand Down Expand Up @@ -227,7 +229,7 @@ def parse_native_mid_circuit_measurements(
circuit: qml.tape.QuantumScript,
aux_tapes: qml.tape.QuantumScript,
results: TensorLike,
interface=None,
postselect_mode=None,
):
"""Combines, gathers and normalizes the results of native mid-circuit measurement runs.
Expand All @@ -247,20 +249,27 @@ def measurement_with_no_shots(measurement):
else np.nan
)

interface = interface or qml.math.get_deep_interface(circuit.data)
interface = qml.math.get_deep_interface(results)
interface = "numpy" if interface == "builtins" else interface
interface = "tensorflow" if interface == "tf" else interface
active_qjit = qml.compiler.active()

all_mcms = [op for op in aux_tapes[0].operations if is_mcm(op)]
n_mcms = len(all_mcms)
mcm_samples = qml.math.hstack(tuple(res.reshape((-1, 1)) for res in results[-n_mcms:]))
mcm_samples = qml.math.hstack(
tuple(qml.math.reshape(res, (-1, 1)) for res in results[-n_mcms:])
)
mcm_samples = qml.math.array(mcm_samples, like=interface)
# Can't use boolean dtype array with tf, hence why conditionally setting items to 0 or 1
has_postselect = qml.math.array(
[[int(op.postselect is not None) for op in all_mcms]], like=interface
[[op.postselect is not None for op in all_mcms]],
like=interface,
dtype=mcm_samples.dtype,
)
postselect = qml.math.array(
[[0 if op.postselect is None else op.postselect for op in all_mcms]], like=interface
[[0 if op.postselect is None else op.postselect for op in all_mcms]],
like=interface,
dtype=mcm_samples.dtype,
)
is_valid = qml.math.all(mcm_samples * has_postselect == postselect, axis=1)
has_valid = qml.math.any(is_valid)
Expand All @@ -277,9 +286,11 @@ def measurement_with_no_shots(measurement):
if interface != "jax" and m.mv and not has_valid:
meas = measurement_with_no_shots(m)
elif m.mv and active_qjit:
meas = gather_mcm_qjit(m, mcm_samples, is_valid) # pragma: no cover
meas = gather_mcm_qjit(
m, mcm_samples, is_valid, postselect_mode=postselect_mode
) # pragma: no cover
elif m.mv:
meas = gather_mcm(m, mcm_samples, is_valid)
meas = gather_mcm(m, mcm_samples, is_valid, postselect_mode=postselect_mode)
elif interface != "jax" and not has_valid:
meas = measurement_with_no_shots(m)
m_count += 1
Expand All @@ -296,12 +307,15 @@ def measurement_with_no_shots(measurement):
# We return the sum of counts (`result[1]`) weighting by `is_valid`, which is `0` for invalid samples
if isinstance(m, CountsMP):
normalized_meas.append(
(result[0][0], qml.math.sum(result[1] * is_valid.reshape((-1, 1)), axis=0))
(
result[0][0],
qml.math.sum(result[1] * qml.math.reshape(is_valid, (-1, 1)), axis=0),
)
)
m_count += 1
continue
result = qml.math.squeeze(result)
meas = gather_non_mcm(m, result, is_valid)
meas = gather_non_mcm(m, result, is_valid, postselect_mode=postselect_mode)
m_count += 1
if isinstance(m, SampleMP):
meas = qml.math.squeeze(meas)
Expand All @@ -310,7 +324,7 @@ def measurement_with_no_shots(measurement):
return tuple(normalized_meas) if len(normalized_meas) > 1 else normalized_meas[0]


def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover
def gather_mcm_qjit(measurement, samples, is_valid, postselect_mode=None): # pragma: no cover
"""Process MCM measurements when the Catalyst compiler is active.
Args:
Expand All @@ -331,7 +345,7 @@ def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover
raise LookupError("MCM not found")
meas = qml.math.squeeze(meas)
if isinstance(measurement, (CountsMP, ProbabilityMP)):
interface = qml.math.get_deep_interface(is_valid)
interface = qml.math.get_interface(is_valid)
sum_valid = qml.math.sum(is_valid)
count_1 = qml.math.sum(meas * is_valid)
if isinstance(measurement, CountsMP):
Expand All @@ -341,10 +355,10 @@ def gather_mcm_qjit(measurement, samples, is_valid): # pragma: no cover
if isinstance(measurement, ProbabilityMP):
counts = qml.math.array([sum_valid - count_1, count_1], like=interface)
return counts / sum_valid
return gather_non_mcm(measurement, meas, is_valid)
return gather_non_mcm(measurement, meas, is_valid, postselect_mode=postselect_mode)


def gather_non_mcm(measurement, samples, is_valid):
def gather_non_mcm(measurement, samples, is_valid, postselect_mode=None):
"""Combines, gathers and normalizes several measurements with trivial measurement values.
Args:
Expand All @@ -365,25 +379,39 @@ def gather_non_mcm(measurement, samples, is_valid):
if not measurement.all_outcomes:
tmp = Counter({k: v for k, v in tmp.items() if v > 0})
return dict(sorted(tmp.items()))
if isinstance(measurement, ExpectationMP):
return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid)
if isinstance(measurement, ProbabilityMP):
return qml.math.sum(samples * is_valid.reshape((-1, 1)), axis=0) / qml.math.sum(is_valid)

if isinstance(measurement, SampleMP):
is_interface_jax = qml.math.get_deep_interface(is_valid) == "jax"
if is_interface_jax and samples.ndim == 2:
is_valid = is_valid.reshape((-1, 1))
if postselect_mode == "pad-invalid-samples" and samples.ndim == 2:
is_valid = qml.math.reshape(is_valid, (-1, 1))
return (
qml.math.where(is_valid, samples, fill_in_value)
if is_interface_jax
if postselect_mode == "pad-invalid-samples"
else samples[is_valid]
)

if (interface := qml.math.get_interface(is_valid)) == "tensorflow":
# Tensorflow requires arrays that are used for arithmetic with each other to have the
# same dtype. We don't cast if measuring samples as float tf.Tensors cannot be used to
# index other tf.Tensors (is_valid is used to index valid samples).
is_valid = qml.math.cast_like(is_valid, samples)

if isinstance(measurement, ExpectationMP):
return qml.math.sum(samples * is_valid) / qml.math.sum(is_valid)
if isinstance(measurement, ProbabilityMP):
return qml.math.sum(samples * qml.math.reshape(is_valid, (-1, 1)), axis=0) / qml.math.sum(
is_valid
)

# VarianceMP
expval = qml.math.sum(samples * is_valid) / qml.math.sum(is_valid)
if interface == "tensorflow":
# Casting needed for tensorflow
samples = qml.math.cast_like(samples, expval)
is_valid = qml.math.cast_like(is_valid, expval)
return qml.math.sum((samples - expval) ** 2 * is_valid) / qml.math.sum(is_valid)


def gather_mcm(measurement, samples, is_valid):
def gather_mcm(measurement, samples, is_valid, postselect_mode=None):
"""Combines, gathers and normalizes several measurements with non-trivial measurement values.
Args:
Expand All @@ -404,20 +432,30 @@ def gather_mcm(measurement, samples, is_valid):
if isinstance(measurement, ProbabilityMP):
values = [list(m.branches.values()) for m in mv]
values = list(itertools.product(*values))
values = [qml.math.array([v], like=interface) for v in values]
values = [qml.math.array([v], like=interface, dtype=mcm_samples.dtype) for v in values]
# Need to use boolean functions explicitly as Tensorflow does not allow integer math
# on boolean arrays
counts = [
qml.math.sum(qml.math.all(mcm_samples == v, axis=1) * is_valid) for v in values
qml.math.count_nonzero(
qml.math.logical_and(qml.math.all(mcm_samples == v, axis=1), is_valid)
)
for v in values
]
counts = qml.math.array(counts, like=interface)
return counts / qml.math.sum(counts)
if isinstance(measurement, CountsMP):
mcm_samples = [{"".join(str(int(v)) for v in tuple(s)): 1} for s in mcm_samples]
return gather_non_mcm(measurement, mcm_samples, is_valid)
return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_mode=postselect_mode)
mcm_samples = qml.math.ravel(qml.math.array(mv.concretize(samples), like=interface))
if isinstance(measurement, ProbabilityMP):
counts = [qml.math.sum((mcm_samples == v) * is_valid) for v in list(mv.branches.values())]
# Need to use boolean functions explicitly as Tensorflow does not allow integer math
# on boolean arrays
counts = [
qml.math.count_nonzero(qml.math.logical_and((mcm_samples == v), is_valid))
for v in list(mv.branches.values())
]
counts = qml.math.array(counts, like=interface)
return counts / qml.math.sum(counts)
if isinstance(measurement, CountsMP):
mcm_samples = [{float(s): 1} for s in mcm_samples]
return gather_non_mcm(measurement, mcm_samples, is_valid)
return gather_non_mcm(measurement, mcm_samples, is_valid, postselect_mode=postselect_mode)
43 changes: 33 additions & 10 deletions pennylane/workflow/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,28 @@ def _deprecated_arguments_warnings(
return tapes, override_shots, expand_fn, max_expansion, device_batch_transform


def _update_mcm_config(mcm_config: "qml.devices.MCMConfig", interface: str, finite_shots: bool):
"""Helper function to update the mid-circuit measurements configuration based on
execution parameters"""
if interface == "jax-jit" and mcm_config.mcm_method == "deferred":
# This is a current limitation of defer_measurements. "hw-like" behaviour is
# not yet accessible.
if mcm_config.postselect_mode == "hw-like":
raise ValueError(
"Using postselect_mode='hw-like' is not supported with jax-jit when using "
"mcm_method='deferred'."
)
mcm_config.postselect_mode = "fill-shots"

if (
finite_shots
and "jax" in interface
and mcm_config.mcm_method in (None, "one-shot")
and mcm_config.postselect_mode in (None, "hw-like")
):
mcm_config.postselect_mode = "pad-invalid-samples"


def execute(
tapes: QuantumTapeBatch,
device: device_type,
Expand Down Expand Up @@ -697,16 +719,17 @@ def cost_fn(params, x):
)

# Mid-circuit measurement configuration validation
mcm_interface = _get_interface_name(tapes, "auto") if interface is None else interface
if mcm_interface == "jax-jit" and config.mcm_config.mcm_method == "deferred":
# This is a current limitation of defer_measurements. "hw-like" behaviour is
# not yet accessible.
if config.mcm_config.postselect_mode == "hw-like":
raise ValueError(
"Using postselect_mode='hw-like' is not supported with jax-jit when using "
"mcm_method='deferred'."
)
config.mcm_config.postselect_mode = "fill-shots"
mcm_interface = interface or _get_interface_name(tapes, "auto")
finite_shots = (
(
qml.measurements.Shots(device.shots)
if isinstance(device, qml.devices.LegacyDevice)
else device.shots
)
if override_shots is False
else override_shots
)
_update_mcm_config(config.mcm_config, mcm_interface, finite_shots)

is_gradient_transform = isinstance(gradient_fn, qml.transforms.core.TransformDispatcher)
transform_program, inner_transform = _make_transform_programs(
Expand Down
Loading

0 comments on commit eb7c349

Please sign in to comment.