Skip to content

Commit

Permalink
feat: Device-specific supported observables (#276)
Browse files Browse the repository at this point in the history
  • Loading branch information
speller26 authored Sep 11, 2024
1 parent 7038db2 commit 4bb65ce
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 11 deletions.
8 changes: 3 additions & 5 deletions src/braket/pennylane_plugin/braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@

from braket.pennylane_plugin.translation import (
get_adjoint_gradient_result_type,
supported_observables,
supported_operations,
translate_operation,
translate_result,
Expand Down Expand Up @@ -156,6 +157,7 @@ def __init__(
self._parametrize_differentiable = parametrize_differentiable
self._run_kwargs = run_kwargs
self._supported_ops = supported_operations(self._device, verbatim=verbatim)
self._supported_obs = supported_observables(self._device, self.shots)
self._check_supported_result_types()
self._verbatim = verbatim

Expand All @@ -174,11 +176,7 @@ def operations(self) -> frozenset[str]:

@property
def observables(self) -> frozenset[str]:
base_observables = frozenset(super().observables)
# Amazon Braket only supports scalar multiplication and addition when shots==0
if not self.shots:
return base_observables.union({"Hamiltonian", "LinearCombination"})
return base_observables
return self._supported_obs

@property
def circuit(self) -> Circuit:
Expand Down
25 changes: 24 additions & 1 deletion src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
StateVector,
Variance,
)
from braket.device_schema import DeviceActionType
from braket.devices import Device
from braket.pulse import ArbitraryWaveform, ConstantWaveform, PulseSequence
from braket.tasks import GateModelQuantumTaskResult
Expand Down Expand Up @@ -96,6 +97,16 @@
}


_BRAKET_TO_PENNYLANE_OBSERVABLES = {
"x": frozenset({"PauliX"}),
"y": frozenset({"PauliY"}),
"z": frozenset({"PauliZ"}),
"h": frozenset({"Hadamard"}),
"hermitian": frozenset({"Hermitian", "Projector"}),
"i": frozenset({"Identity"}),
}


def supported_operations(device: Device, verbatim: bool = False) -> frozenset[str]:
"""Returns the operations supported by the plugin based upon the device.
Expand All @@ -111,7 +122,7 @@ def supported_operations(device: Device, verbatim: bool = False) -> frozenset[st
properties = (
device.properties.paradigm
if verbatim
else device.properties.action["braket.ir.openqasm.program"]
else device.properties.action[DeviceActionType.OPENQASM]
)
except AttributeError:
raise AttributeError("Device needs to have properties defined.")
Expand Down Expand Up @@ -514,6 +525,18 @@ def waveform(dt):
return gates.PulseGate(pulse_sequence, qubit_count=len(op.wires))


def supported_observables(device: Device, shots: int) -> frozenset[str]:
action = device.properties.action[DeviceActionType.OPENQASM]
braket_observables = set.union(
*[set(r.observables) for r in action.supportedResultTypes if r.observables]
)
supported = frozenset.union(
*[_BRAKET_TO_PENNYLANE_OBSERVABLES[braket_obs] for braket_obs in braket_observables],
)
supported |= {"Prod", "SProd"}
return supported if shots else supported | {"Sum", "Hamiltonian", "LinearCombination"}


def get_adjoint_gradient_result_type(
observable: Observable,
targets: Union[list[int], list[list[int]]],
Expand Down
16 changes: 13 additions & 3 deletions test/unit_tests/device_property_jsons.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
{"name": "StateVector", "observables": None, "minShots": 0, "maxShots": 0},
{
"name": "AdjointGradient",
"observables": ["x", "y", "z", "h", "i"],
"observables": ["x", "y", "z", "h", "i", "hermitian"],
"minShots": 0,
"maxShots": 0,
},
Expand All @@ -43,7 +43,12 @@
"version": ["1"],
"supportedOperations": ["rx", "ry", "h", "cy", "cnot", "unitary"],
"supportedResultTypes": [
{"name": "StateVector", "observables": None, "minShots": 0, "maxShots": 0},
{
"name": "StateVector",
"observables": ["x", "y", "z"],
"minShots": 0,
"maxShots": 0,
},
],
}
)
Expand All @@ -56,7 +61,12 @@
"version": ["1"],
"supportedOperations": ["rx", "ry", "h", "cy", "cnot", "unitary"],
"supportedResultTypes": [
{"name": "StateVector", "observables": None, "minShots": 0, "maxShots": 0},
{
"name": "StateVector",
"observables": ["x", "y", "z"],
"minShots": 0,
"maxShots": 0,
},
],
"supportedPragmas": [
"braket_noise_bit_flip",
Expand Down
2 changes: 1 addition & 1 deletion test/unit_tests/test_braket_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,7 +1935,7 @@ def properties(self) -> GateModelSimulatorDeviceCapabilities:
"supportedResultTypes": [
{
"name": "resultType1",
"observables": ["observable1"],
"observables": ["z"],
"minShots": 2,
"maxShots": 4,
}
Expand Down
2 changes: 1 addition & 1 deletion test/unit_tests/test_shadow_expval.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def properties(self) -> GateModelSimulatorDeviceCapabilities:
"supportedResultTypes": [
{
"name": "resultType1",
"observables": ["observable1"],
"observables": ["z"],
"minShots": 2,
"maxShots": 4,
}
Expand Down

0 comments on commit 4bb65ce

Please sign in to comment.