From e628968f8567858922a04dc838a94096adf401f2 Mon Sep 17 00:00:00 2001 From: Shuli Shu <08cnbj@gmail.com> Date: Wed, 12 Jun 2024 17:43:23 +0000 Subject: [PATCH] update arg name to align with default.tensor --- pennylane_lightning/lightning_tensor/_tensornet.py | 6 +++--- pennylane_lightning/lightning_tensor/lightning_tensor.py | 2 +- tests/lightning_tensor/test_measurements_class.py | 2 +- tests/lightning_tensor/test_tensornet_class.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pennylane_lightning/lightning_tensor/_tensornet.py b/pennylane_lightning/lightning_tensor/_tensornet.py index bda160f3fd..5b5a28170e 100644 --- a/pennylane_lightning/lightning_tensor/_tensornet.py +++ b/pennylane_lightning/lightning_tensor/_tensornet.py @@ -51,7 +51,7 @@ def __init__( self, num_wires, method: str = "mps", - dtype=np.complex128, + c_dtype=np.complex128, max_bond_dim: int = 128, cutoff: float = 0, cutoff_mode: str = "abs", @@ -62,7 +62,7 @@ def __init__( self._method = method self._cutoff = cutoff self._cutoff_mode = cutoff_mode - self._dtype = dtype + self._c_dtype = c_dtype if device_name != "lightning.tensor": raise DeviceError(f'The device name "{device_name}" is not a valid option.') @@ -73,7 +73,7 @@ def __init__( @property def dtype(self): """Returns the tensor network data type.""" - return self._dtype + return self._c_dtype @property def device_name(self): diff --git a/pennylane_lightning/lightning_tensor/lightning_tensor.py b/pennylane_lightning/lightning_tensor/lightning_tensor.py index 7887c08af6..b0db8f84b4 100644 --- a/pennylane_lightning/lightning_tensor/lightning_tensor.py +++ b/pennylane_lightning/lightning_tensor/lightning_tensor.py @@ -192,7 +192,7 @@ class LightningTensor(Device): Defaults to ``None`` if not specified. method (str): Supported method. Currently, only ``mps`` is supported. c_dtype: Datatypes for the tensor representation. Must be one of - ``np.complex64`` or ``np.complex128``. Default is ``np.complex128``. + ``numpy.complex64`` or ``numpy.complex128``. Default is ``numpy.complex128``. Keyword Args: max_bond_dim (int): The maximum bond dimension to be used in the MPS simulation. Default is 128. The accuracy of the wavefunction representation comes with a memory tradeoff which can be diff --git a/tests/lightning_tensor/test_measurements_class.py b/tests/lightning_tensor/test_measurements_class.py index 8149725534..6d9574aa83 100644 --- a/tests/lightning_tensor/test_measurements_class.py +++ b/tests/lightning_tensor/test_measurements_class.py @@ -40,7 +40,7 @@ ) def lightning_tn(request): """Fixture for creating a LightningTensorNet object.""" - return LightningTensorNet(num_wires=5, max_bond_dim=128, dtype=request.param) + return LightningTensorNet(num_wires=5, max_bond_dim=128, c_dtype=request.param) class TestMeasurementFunction: diff --git a/tests/lightning_tensor/test_tensornet_class.py b/tests/lightning_tensor/test_tensornet_class.py index 769ba82c14..9e2113fb2d 100644 --- a/tests/lightning_tensor/test_tensornet_class.py +++ b/tests/lightning_tensor/test_tensornet_class.py @@ -39,7 +39,7 @@ @pytest.mark.parametrize("device_name", ["lightning.tensor"]) def test_device_name_and_init(num_wires, bondDims, dtype, device_name): """Test the class initialization and returned properties.""" - tensornet = LightningTensorNet(num_wires, bondDims, dtype=dtype, device_name=device_name) + tensornet = LightningTensorNet(num_wires, bondDims, c_dtype=dtype, device_name=device_name) assert tensornet.dtype == dtype assert tensornet.device_name == device_name assert tensornet.num_wires == num_wires