diff --git a/pennylane/devices/execution_config.py b/pennylane/devices/execution_config.py index 7f3866d9e86..0dea17a6387 100644 --- a/pennylane/devices/execution_config.py +++ b/pennylane/devices/execution_config.py @@ -17,6 +17,7 @@ from dataclasses import dataclass, field from typing import Optional, Union +from pennylane.transforms.core import TransformDispatcher from pennylane.workflow import SUPPORTED_INTERFACE_NAMES @@ -86,7 +87,7 @@ class ExecutionConfig: ``True`` indicates to either use the device Jacobian products or fail. """ - gradient_method: Optional[str] = None + gradient_method: Optional[Union[str, TransformDispatcher]] = None """The method used to compute the gradient of the quantum circuit being executed""" gradient_keyword_arguments: Optional[dict] = None @@ -126,6 +127,14 @@ def __post_init__(self): if self.gradient_keyword_arguments is None: self.gradient_keyword_arguments = {} + if not ( + isinstance(self.gradient_method, (str, TransformDispatcher)) + or self.gradient_method is None + ): + raise ValueError( + f"gradient_method must be a str, TransformDispatcher, or None. Got {type(self.gradient_method)} instead." + ) + if isinstance(self.mcm_config, dict): self.mcm_config = MCMConfig(**self.mcm_config) elif not isinstance(self.mcm_config, MCMConfig): diff --git a/tests/devices/experimental/test_execution_config.py b/tests/devices/experimental/test_execution_config.py index 361712c112e..f6019efa716 100644 --- a/tests/devices/experimental/test_execution_config.py +++ b/tests/devices/experimental/test_execution_config.py @@ -18,6 +18,7 @@ import pytest from pennylane.devices.execution_config import ExecutionConfig, MCMConfig +from pennylane.gradients import param_shift def test_default_values(): @@ -90,3 +91,18 @@ def test_mcm_config_invalid_postselect_mode(): option = "foo" with pytest.raises(ValueError, match="Invalid postselection mode"): _ = MCMConfig(postselect_mode=option) + + +@pytest.mark.parametrize("method", ("parameter-shift", None, param_shift)) +def test_valid_gradient_method(method): + """Test valid gradient_method types.""" + config = ExecutionConfig(gradient_method=method) + assert config.gradient_method == method + + +def test_invalid_gradient_method(): + """Test that invalid types for gradient_method raise an error.""" + with pytest.raises( + ValueError, match=r"gradient_method must be a str, TransformDispatcher, or None" + ): + ExecutionConfig(gradient_method=123)