From 4bb65cefd5819b616f0fa322239a8af93647df18 Mon Sep 17 00:00:00 2001 From: Cody Wang Date: Wed, 11 Sep 2024 14:57:24 -0700 Subject: [PATCH] feat: Device-specific supported observables (#276) --- src/braket/pennylane_plugin/braket_device.py | 8 +++---- src/braket/pennylane_plugin/translation.py | 25 +++++++++++++++++++- test/unit_tests/device_property_jsons.py | 16 ++++++++++--- test/unit_tests/test_braket_device.py | 2 +- test/unit_tests/test_shadow_expval.py | 2 +- 5 files changed, 42 insertions(+), 11 deletions(-) diff --git a/src/braket/pennylane_plugin/braket_device.py b/src/braket/pennylane_plugin/braket_device.py index abfcfb4c..fb65de05 100644 --- a/src/braket/pennylane_plugin/braket_device.py +++ b/src/braket/pennylane_plugin/braket_device.py @@ -72,6 +72,7 @@ from braket.pennylane_plugin.translation import ( get_adjoint_gradient_result_type, + supported_observables, supported_operations, translate_operation, translate_result, @@ -156,6 +157,7 @@ def __init__( self._parametrize_differentiable = parametrize_differentiable self._run_kwargs = run_kwargs self._supported_ops = supported_operations(self._device, verbatim=verbatim) + self._supported_obs = supported_observables(self._device, self.shots) self._check_supported_result_types() self._verbatim = verbatim @@ -174,11 +176,7 @@ def operations(self) -> frozenset[str]: @property def observables(self) -> frozenset[str]: - base_observables = frozenset(super().observables) - # Amazon Braket only supports scalar multiplication and addition when shots==0 - if not self.shots: - return base_observables.union({"Hamiltonian", "LinearCombination"}) - return base_observables + return self._supported_obs @property def circuit(self) -> Circuit: diff --git a/src/braket/pennylane_plugin/translation.py b/src/braket/pennylane_plugin/translation.py index 7d1fc2ac..589cecf5 100644 --- a/src/braket/pennylane_plugin/translation.py +++ b/src/braket/pennylane_plugin/translation.py @@ -28,6 +28,7 @@ StateVector, Variance, ) +from braket.device_schema import DeviceActionType from braket.devices import Device from braket.pulse import ArbitraryWaveform, ConstantWaveform, PulseSequence from braket.tasks import GateModelQuantumTaskResult @@ -96,6 +97,16 @@ } +_BRAKET_TO_PENNYLANE_OBSERVABLES = { + "x": frozenset({"PauliX"}), + "y": frozenset({"PauliY"}), + "z": frozenset({"PauliZ"}), + "h": frozenset({"Hadamard"}), + "hermitian": frozenset({"Hermitian", "Projector"}), + "i": frozenset({"Identity"}), +} + + def supported_operations(device: Device, verbatim: bool = False) -> frozenset[str]: """Returns the operations supported by the plugin based upon the device. @@ -111,7 +122,7 @@ def supported_operations(device: Device, verbatim: bool = False) -> frozenset[st properties = ( device.properties.paradigm if verbatim - else device.properties.action["braket.ir.openqasm.program"] + else device.properties.action[DeviceActionType.OPENQASM] ) except AttributeError: raise AttributeError("Device needs to have properties defined.") @@ -514,6 +525,18 @@ def waveform(dt): return gates.PulseGate(pulse_sequence, qubit_count=len(op.wires)) +def supported_observables(device: Device, shots: int) -> frozenset[str]: + action = device.properties.action[DeviceActionType.OPENQASM] + braket_observables = set.union( + *[set(r.observables) for r in action.supportedResultTypes if r.observables] + ) + supported = frozenset.union( + *[_BRAKET_TO_PENNYLANE_OBSERVABLES[braket_obs] for braket_obs in braket_observables], + ) + supported |= {"Prod", "SProd"} + return supported if shots else supported | {"Sum", "Hamiltonian", "LinearCombination"} + + def get_adjoint_gradient_result_type( observable: Observable, targets: Union[list[int], list[list[int]]], diff --git a/test/unit_tests/device_property_jsons.py b/test/unit_tests/device_property_jsons.py index 2138171b..124b1e99 100644 --- a/test/unit_tests/device_property_jsons.py +++ b/test/unit_tests/device_property_jsons.py @@ -27,7 +27,7 @@ {"name": "StateVector", "observables": None, "minShots": 0, "maxShots": 0}, { "name": "AdjointGradient", - "observables": ["x", "y", "z", "h", "i"], + "observables": ["x", "y", "z", "h", "i", "hermitian"], "minShots": 0, "maxShots": 0, }, @@ -43,7 +43,12 @@ "version": ["1"], "supportedOperations": ["rx", "ry", "h", "cy", "cnot", "unitary"], "supportedResultTypes": [ - {"name": "StateVector", "observables": None, "minShots": 0, "maxShots": 0}, + { + "name": "StateVector", + "observables": ["x", "y", "z"], + "minShots": 0, + "maxShots": 0, + }, ], } ) @@ -56,7 +61,12 @@ "version": ["1"], "supportedOperations": ["rx", "ry", "h", "cy", "cnot", "unitary"], "supportedResultTypes": [ - {"name": "StateVector", "observables": None, "minShots": 0, "maxShots": 0}, + { + "name": "StateVector", + "observables": ["x", "y", "z"], + "minShots": 0, + "maxShots": 0, + }, ], "supportedPragmas": [ "braket_noise_bit_flip", diff --git a/test/unit_tests/test_braket_device.py b/test/unit_tests/test_braket_device.py index f62825a1..2b166717 100644 --- a/test/unit_tests/test_braket_device.py +++ b/test/unit_tests/test_braket_device.py @@ -1935,7 +1935,7 @@ def properties(self) -> GateModelSimulatorDeviceCapabilities: "supportedResultTypes": [ { "name": "resultType1", - "observables": ["observable1"], + "observables": ["z"], "minShots": 2, "maxShots": 4, } diff --git a/test/unit_tests/test_shadow_expval.py b/test/unit_tests/test_shadow_expval.py index 2c1f4252..d50393f5 100644 --- a/test/unit_tests/test_shadow_expval.py +++ b/test/unit_tests/test_shadow_expval.py @@ -395,7 +395,7 @@ def properties(self) -> GateModelSimulatorDeviceCapabilities: "supportedResultTypes": [ { "name": "resultType1", - "observables": ["observable1"], + "observables": ["z"], "minShots": 2, "maxShots": 4, }