Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize memory peak for _preprocess_state_vector in LightningTensor #943

Merged
merged 15 commits into from
Oct 15, 2024
Merged
2 changes: 1 addition & 1 deletion pennylane_lightning/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.39.0-dev42"
__version__ = "0.39.0-dev43"
33 changes: 23 additions & 10 deletions pennylane_lightning/lightning_tensor/_tensornet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
except ImportError:
pass

from itertools import product

import numpy as np
import pennylane as qml
from pennylane import BasisState, DeviceError, StatePrep
Expand Down Expand Up @@ -223,20 +221,35 @@
if len(device_wires) == self._num_wires and Wires(sorted(device_wires)) == device_wires:
return np.reshape(state, output_shape).ravel(order="C")

# generate basis states on subset of qubits via the cartesian product
basis_states = np.array(list(product([0, 1], repeat=len(device_wires))))
local_dev_wires = np.array(device_wires.tolist())[::-1]

Check warning on line 224 in pennylane_lightning/lightning_tensor/_tensornet.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_tensor/_tensornet.py#L224

Added line #L224 was not covered by tests
LuisAlfredoNu marked this conversation as resolved.
Show resolved Hide resolved
mlxd marked this conversation as resolved.
Show resolved Hide resolved

# generate basis states on subset of qubits via broadcasting
base = np.tile([0, 1], 2 ** (len(local_dev_wires) - 1)).astype(dtype=np.int64)
indexes = np.zeros(2 ** (len(local_dev_wires)), dtype=np.int64)

Check warning on line 228 in pennylane_lightning/lightning_tensor/_tensornet.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_tensor/_tensornet.py#L227-L228

Added lines #L227 - L228 were not covered by tests

max_dev_wire = self._num_wires - 1

Check warning on line 230 in pennylane_lightning/lightning_tensor/_tensornet.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_tensor/_tensornet.py#L230

Added line #L230 was not covered by tests

# get basis states to alter on full set of qubits
unravelled_indices = np.zeros((2 ** len(device_wires), self._num_wires), dtype=int)
unravelled_indices[:, device_wires] = basis_states
for i, wire in enumerate(local_dev_wires):

Check warning on line 233 in pennylane_lightning/lightning_tensor/_tensornet.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_tensor/_tensornet.py#L233

Added line #L233 was not covered by tests

# get indices for which the state is changed to input state vector elements
indexes += base * 2 ** (max_dev_wire - wire)

Check warning on line 236 in pennylane_lightning/lightning_tensor/_tensornet.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_tensor/_tensornet.py#L236

Added line #L236 was not covered by tests

if i == len(local_dev_wires) - 1:
continue

Check warning on line 239 in pennylane_lightning/lightning_tensor/_tensornet.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_tensor/_tensornet.py#L238-L239

Added lines #L238 - L239 were not covered by tests

two_n = 2 ** (i + 1)
base = base.reshape(-1, two_n * 2)
swaper_A = two_n // 2
swaper_B = swaper_A + two_n

Check warning on line 244 in pennylane_lightning/lightning_tensor/_tensornet.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_tensor/_tensornet.py#L241-L244

Added lines #L241 - L244 were not covered by tests
LuisAlfredoNu marked this conversation as resolved.
Show resolved Hide resolved
mlxd marked this conversation as resolved.
Show resolved Hide resolved

# get indices for which the state is changed to input state vector elements
ravelled_indices = np.ravel_multi_index(unravelled_indices.T, [2] * self._num_wires)
base[:, swaper_A:swaper_B] = base[:, swaper_A:swaper_B][:, ::-1]
base = base.reshape(-1)

Check warning on line 247 in pennylane_lightning/lightning_tensor/_tensornet.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_tensor/_tensornet.py#L246-L247

Added lines #L246 - L247 were not covered by tests

# get full state vector to be factorized into MPS
full_state = np.zeros(2**self._num_wires, dtype=self.dtype)
for i, value in enumerate(state):
full_state[ravelled_indices[i]] = value
full_state[indexes[i]] = value

Check warning on line 252 in pennylane_lightning/lightning_tensor/_tensornet.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_tensor/_tensornet.py#L252

Added line #L252 was not covered by tests
return np.reshape(full_state, output_shape).ravel(order="C")

def _apply_state_vector(self, state, device_wires: Wires):
Expand Down Expand Up @@ -285,7 +298,7 @@
None
"""
# TODO: Discuss if public interface for max_mpo_bond_dim argument
max_mpo_bond_dim = 2 ** len(wires) # Exact SVD decomposition for MPO
max_mpo_bond_dim = self._max_bond_dim

Check warning on line 301 in pennylane_lightning/lightning_tensor/_tensornet.py

View check run for this annotation

Codecov / codecov/patch

pennylane_lightning/lightning_tensor/_tensornet.py#L301

Added line #L301 was not covered by tests
mlxd marked this conversation as resolved.
Show resolved Hide resolved

# Get sorted wires and MPO site tensor
mpos, sorted_wires = gate_matrix_decompose(
Expand Down
Loading