Skip to content

Commit

Permalink
Allow mid-circuit stateprep.
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentmr committed Sep 6, 2023
1 parent 8eed6f2 commit 9453a02
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 32 deletions.
5 changes: 5 additions & 0 deletions pennylane_lightning/lightning_qubit/lightning_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,11 @@ def apply(self, operations, rotations=None, **kwargs):
self._apply_basis_state(operations[0].parameters[0], operations[0].wires)
operations = operations[1:]

if any(isinstance(op, (BasisState, StatePrep)) for op in operations):
tape = qml.tape.QuantumTape(ops=operations, measurements=[])
tape = qml.tape.expand_tape_state_prep(tape, skip_first=False, force_decompose=True)
operations = tape._ops

for operation in operations:
if isinstance(operation, (StatePrep, BasisState)):
raise DeviceError(
Expand Down
111 changes: 79 additions & 32 deletions tests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,38 +454,38 @@ def test_apply_operation_preserve_pointer_two_wires_with_parameters(

assert pointer_before == pointer_after

@pytest.mark.parametrize("stateprep", [qml.QubitStateVector, qml.StatePrep])
def test_apply_errors_qubit_state_vector(self, stateprep, qubit_device):
"""Test that apply fails for incorrect state preparation, and > 2 qubit gates"""
dev = qubit_device(wires=2)
with pytest.raises(ValueError, match="Sum of amplitudes-squared does not equal one."):
dev.apply([stateprep(np.array([1, -1]), wires=[0])])

with pytest.raises(
DeviceError,
match=f"Operation {stateprep(np.array([1, 0]), wires=[0]).name} cannot be used after other Operations have already been applied ",
):
dev.reset()
dev.apply([qml.RZ(0.5, wires=[0]), stateprep(np.array([0, 1, 0, 0]), wires=[0, 1])])

def test_apply_errors_basis_state(self, qubit_device):
dev = qubit_device(wires=2)
with pytest.raises(
ValueError, match="BasisState parameter must consist of 0 or 1 integers."
):
dev.apply([qml.BasisState(np.array([-0.2, 4.2]), wires=[0, 1])])

with pytest.raises(
ValueError, match="BasisState parameter and wires must be of equal length."
):
dev.apply([qml.BasisState(np.array([0, 1]), wires=[0])])

with pytest.raises(
DeviceError,
match="Operation BasisState cannot be used after other Operations have already been applied ",
):
dev.reset()
dev.apply([qml.RZ(0.5, wires=[0]), qml.BasisState(np.array([1, 1]), wires=[0, 1])])
# @pytest.mark.parametrize("stateprep", [qml.QubitStateVector, qml.StatePrep])
# def test_apply_errors_qubit_state_vector(self, stateprep, qubit_device):
# """Test that apply fails for incorrect state preparation, and > 2 qubit gates"""
# dev = qubit_device(wires=2)
# with pytest.raises(ValueError, match="Sum of amplitudes-squared does not equal one."):
# dev.apply([stateprep(np.array([1, -1]), wires=[0])])

# with pytest.raises(
# DeviceError,
# match=f"Operation {stateprep(np.array([1, 0]), wires=[0]).name} cannot be used after other Operations have already been applied ",
# ):
# dev.reset()
# dev.apply([qml.RZ(0.5, wires=[0]), stateprep(np.array([0, 1, 0, 0]), wires=[0, 1])])

# def test_apply_errors_basis_state(self, qubit_device):
# dev = qubit_device(wires=2)
# with pytest.raises(
# ValueError, match="BasisState parameter must consist of 0 or 1 integers."
# ):
# dev.apply([qml.BasisState(np.array([-0.2, 4.2]), wires=[0, 1])])

# with pytest.raises(
# ValueError, match="BasisState parameter and wires must be of equal length."
# ):
# dev.apply([qml.BasisState(np.array([0, 1]), wires=[0])])

# with pytest.raises(
# DeviceError,
# match="Operation BasisState cannot be used after other Operations have already been applied ",
# ):
# dev.reset()
# dev.apply([qml.RZ(0.5, wires=[0]), qml.BasisState(np.array([1, 1]), wires=[0, 1])])


class TestExpval:
Expand Down Expand Up @@ -1469,3 +1469,50 @@ def test_warning():
"""Tests if a warning is raised when lightning device binaries are not available"""
with pytest.warns(UserWarning, match="Pre-compiled binaries for " + device_name):
qml.device(device_name, wires=1)


@pytest.mark.parametrize(
"op",
[
qml.BasisState([0, 0], wires=[0, 1]),
qml.QubitStateVector([0, 1, 0, 0], wires=[0, 1]),
qml.StatePrep([0, 1, 0, 0], wires=[0, 1]),
],
)
@pytest.mark.parametrize("theta, phi", list(zip(THETA, PHI)))
def test_circuit_with_stateprep(op, theta, phi, tol):
"""Test mid-circuit StatePrep"""
n_qubits = 5
n_wires = 2
dev_def = qml.device("default.qubit", wires=n_qubits)
dev = qml.device(device_name, wires=n_qubits)
m = 2**n_wires
U = np.random.rand(m, m) + 1j * np.random.rand(m, m)
U, _ = np.linalg.qr(U)
init_state = np.random.rand(2**n_qubits) + 1j * np.random.rand(2**n_qubits)
init_state /= np.sqrt(np.dot(np.conj(init_state), init_state))

prep = [qml.StatePrep(init_state, wires=range(n_qubits))]
ops = [
qml.RY(theta, wires=[0]),
qml.RY(phi, wires=[1]),
qml.CNOT(wires=[0, 1]),
op,
qml.QubitUnitary(U, wires=range(2, 2 + 2 * n_wires, 2)),
]
measurements = [qml.state()]
tape = qml.tape.QuantumTape(ops=ops, measurements=measurements, prep=prep)
assert np.allclose(dev.execute(tape), dev_def.execute(tape), tol)

def circuit():
qml.StatePrep(init_state, wires=range(n_qubits))
qml.RY(theta, wires=[0])
qml.RY(phi, wires=[1])
qml.CNOT(wires=[0, 1])
op
qml.QubitUnitary(U, wires=range(2, 2 + 2 * n_wires, 2))
return qml.state()

circ = qml.QNode(circuit, dev)
circ_def = qml.QNode(circuit, dev_def)
assert np.allclose(circ(), circ_def(), tol)

0 comments on commit 9453a02

Please sign in to comment.