Skip to content

Commit

Permalink
update arg name to align with default.tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
multiphaseCFD committed Jun 12, 2024
1 parent 974e0cf commit e628968
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions pennylane_lightning/lightning_tensor/_tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
self,
num_wires,
method: str = "mps",
dtype=np.complex128,
c_dtype=np.complex128,
max_bond_dim: int = 128,
cutoff: float = 0,
cutoff_mode: str = "abs",
Expand All @@ -62,7 +62,7 @@ def __init__(
self._method = method
self._cutoff = cutoff
self._cutoff_mode = cutoff_mode
self._dtype = dtype
self._c_dtype = c_dtype

Check warning on line 65 in pennylane_lightning/lightning_tensor/_tensornet.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_tensor/_tensornet.py#L60-L65

Added lines #L60 - L65 were not covered by tests

if device_name != "lightning.tensor":
raise DeviceError(f'The device name "{device_name}" is not a valid option.')

Check warning on line 68 in pennylane_lightning/lightning_tensor/_tensornet.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_tensor/_tensornet.py#L67-L68

Added lines #L67 - L68 were not covered by tests
Expand All @@ -73,7 +73,7 @@ def __init__(
@property
def dtype(self):
"""Returns the tensor network data type."""
return self._dtype
return self._c_dtype

Check warning on line 76 in pennylane_lightning/lightning_tensor/_tensornet.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_tensor/_tensornet.py#L76

Added line #L76 was not covered by tests

@property
def device_name(self):
Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/lightning_tensor/lightning_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class LightningTensor(Device):
Defaults to ``None`` if not specified.
method (str): Supported method. Currently, only ``mps`` is supported.
c_dtype: Datatypes for the tensor representation. Must be one of
``np.complex64`` or ``np.complex128``. Default is ``np.complex128``.
``numpy.complex64`` or ``numpy.complex128``. Default is ``numpy.complex128``.
Keyword Args:
max_bond_dim (int): The maximum bond dimension to be used in the MPS simulation. Default is 128.
The accuracy of the wavefunction representation comes with a memory tradeoff which can be
Expand Down
2 changes: 1 addition & 1 deletion tests/lightning_tensor/test_measurements_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
)
def lightning_tn(request):
"""Fixture for creating a LightningTensorNet object."""
return LightningTensorNet(num_wires=5, max_bond_dim=128, dtype=request.param)
return LightningTensorNet(num_wires=5, max_bond_dim=128, c_dtype=request.param)


class TestMeasurementFunction:
Expand Down
2 changes: 1 addition & 1 deletion tests/lightning_tensor/test_tensornet_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
@pytest.mark.parametrize("device_name", ["lightning.tensor"])
def test_device_name_and_init(num_wires, bondDims, dtype, device_name):
"""Test the class initialization and returned properties."""
tensornet = LightningTensorNet(num_wires, bondDims, dtype=dtype, device_name=device_name)
tensornet = LightningTensorNet(num_wires, bondDims, c_dtype=dtype, device_name=device_name)
assert tensornet.dtype == dtype
assert tensornet.device_name == device_name
assert tensornet.num_wires == num_wires
Expand Down

0 comments on commit e628968

Please sign in to comment.