[BUG] Verification is not allowing differentiation of non-differentiable gates #1065

Open
josh146 opened this issue Aug 28, 2024 · 8 comments
Labels
bug Something isn't working

Comments

josh146 (Member) commented Aug 28, 2024

The new gradient verification is overzealous: it won't allow a circuit to be differentiated if it contains an operation that is not supported for differentiation, even when that operation is not the one being differentiated:

import pennylane as qml
import numpy as np
from jax import numpy as jnp

dev = qml.device("lightning.qubit", wires=6)

@qml.qjit
@qml.qnode(dev)
def cost(params):
    qml.BasisState(np.array([1, 1, 0, 0, 0, 0]), wires=range(6))
    qml.DoubleExcitation(params[0], wires=[0, 1, 2, 3])
    qml.DoubleExcitation(params[1], wires=[0, 1, 4, 5])
    return qml.expval(qml.PauliZ(0))

@qml.qjit
def dcost(params):
    return qml.grad(cost)(params)
>>> x = jnp.array([0.2, 0.7])
>>> print("Cost:", cost(x))
Cost: -0.7472525151031233
>>> print("Gradient:", dcost(x))
File ~/miniconda3/lib/python3.10/site-packages/catalyst/device/verification.py:234, in verify_operations.<locals>._op_checker(op, state)
    232 _mcm_op_checker(op)
    233 if grad_method == "adjoint":
--> 234     _adj_diff_op_checker(op)
    235 elif grad_method == "parameter-shift":
    236     _paramshift_op_checker(op)

File ~/miniconda3/lib/python3.10/site-packages/catalyst/device/verification.py:152, in verify_operations.<locals>._adj_diff_op_checker(op)
    148     op_name = op.name
    149 if not qjit_device.qjit_capabilities.native_ops.get(
    150     op_name, EMPTY_PROPERTIES
    151 ).differentiable:
--> 152     raise DifferentiableCompileError(
    153         f"{op.name} is non-differentiable on '{qjit_device.original_device.name}' device"
    154     )

DifferentiableCompileError: BasisState is non-differentiable on 'lightning.qubit' device

Here, verification rejects the circuit even though BasisState is not being differentiated: its state is a constant array, and only the DoubleExcitation gates depend on params.

Previously, this example worked fine, since BasisState was always decomposed into non-parametrized gates (qml.X).

This is currently affecting our VQE + Catalyst demos, which are no longer executable. A temporary workaround I can use is:

qml.BasisState.compute_decomposition(np.array([1, 1, 0, 0, 0, 0]), wires=range(6))

but this is not ideal.

josh146 added the bug label Aug 28, 2024
josh146 (Member, Author) commented Aug 28, 2024

Note that if I set diff_method="parameter-shift", I get a compilation error:

>>> dcost(x)
dcost:13:3: error: 'func.func' op cloned during the gradient pass is not free of quantum ops:

"func.func"() <{function_type = (tensor<6xi64>, tensor<2xf64>, index) -> tensor<?xf64>, sym_name = "cost.qgrad", sym_visibility = "private"}> ({
^bb0(%arg0: tensor<6xi64>, %arg1: tensor<2xf64>, %arg2: index):
  %0 = "arith.constant"() <{value = sparse<15, -1.5707963267948966> : tensor<16xf64>}> : () -> tensor<16xf64>
...

dime10 (Collaborator) commented Aug 28, 2024


@erick-xanadu I think this is related to our interface discussion on your PR: since the new gates don't implement one of the quantum gate interfaces, the gradient passes need to remove them explicitly.
Should be an easy fix.
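To illustrate the point above, here is a toy model (not Catalyst's actual pass) of how a gradient pass that removes ops via an interface check can leave quantum ops behind in the cloned function. The op names and the `is_gate` flag, which stands in for "implements the quantum gate interface", are hypothetical; the fix simply mirrors the suggestion to remove the new ops explicitly.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    name: str      # illustrative op name, e.g. "quantum.custom"
    dialect: str   # "quantum" or a classical dialect such as "arith"
    is_gate: bool  # hypothetical stand-in for "implements the quantum gate interface"


def strip_gates(ops, explicit=frozenset()):
    """Remove ops that implement the gate interface, plus any names listed in `explicit`."""
    return [op for op in ops if not (op.is_gate or op.name in explicit)]


def assert_free_of_quantum_ops(ops):
    """Mimic the post-pass check: the cloned classical function must contain no quantum ops."""
    leftovers = [op.name for op in ops if op.dialect == "quantum"]
    if leftovers:
        raise RuntimeError(f"cloned function is not free of quantum ops: {leftovers}")


cloned = [
    Op("arith.constant", "arith", False),
    Op("quantum.set_basis_state", "quantum", False),  # new op, no gate interface
    Op("quantum.custom", "quantum", True),
]

# Interface-only removal leaves the state-preparation op behind.
try:
    assert_free_of_quantum_ops(strip_gates(cloned))
except RuntimeError as err:
    print(err)

# Removing the new op explicitly makes the check pass.
fixed = strip_gates(cloned, explicit={"quantum.set_basis_state"})
assert_free_of_quantum_ops(fixed)
print([op.name for op in fixed])
```

The explicit list is the quick fix; having the new ops implement a shared interface would avoid maintaining such a list.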

dime10 (Collaborator) commented Aug 28, 2024

The main problem here might be difficult to solve quickly: we need a way to track which gate parameters came from differentiated function arguments.
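A minimal sketch of the kind of tracking described above, assuming a toy straight-line program representation: a forward "activity" pass marks every value derived from a differentiated argument, and only gates whose parameters are active are checked for differentiability. `Assign`, `Gate`, `active_values` and `check_differentiable` are hypothetical names, not Catalyst APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Assign:
    dst: str
    srcs: tuple  # names of the values this one is computed from


@dataclass(frozen=True)
class Gate:
    name: str
    params: tuple  # names of the values used as gate parameters


def active_values(program, diff_args):
    """Forward pass: a value is active if it depends on a differentiated argument."""
    active = set(diff_args)
    for stmt in program:
        if isinstance(stmt, Assign) and any(s in active for s in stmt.srcs):
            active.add(stmt.dst)
    return active


def check_differentiable(program, diff_args, differentiable_gates):
    """Raise only for non-differentiable gates whose parameters are active."""
    active = active_values(program, diff_args)
    for stmt in program:
        if isinstance(stmt, Gate) and stmt.name not in differentiable_gates:
            if any(p in active for p in stmt.params):
                raise ValueError(f"{stmt.name} is non-differentiable")


# Toy model of the circuit in this issue: the basis state is a constant.
program = [
    Assign("basis", ()),            # np.array([1, 1, 0, 0, 0, 0])
    Assign("theta0", ("params",)),  # params[0]
    Assign("theta1", ("params",)),  # params[1]
    Gate("BasisState", ("basis",)),
    Gate("DoubleExcitation", ("theta0",)),
    Gate("DoubleExcitation", ("theta1",)),
]

check_differentiable(program, {"params"}, {"DoubleExcitation"})  # passes
```

If the basis state were computed from `params`, the same call would raise for BasisState, which is the case verification is meant to catch.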

A workaround would be to fully decompose gates that are not supported for differentiation, which we want anyway. I believe that should be fairly quick to implement. While this is inefficient in some cases (unnecessary decomposition), it matches the previous behaviour for StatePrep.
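The decomposition workaround could look roughly like the following sketch, which recursively decomposes every gate outside a differentiable gate set until only supported gates remain. Gates are modelled as `(name, params, wires)` tuples, and `RULES` holds a single hypothetical rule that expands BasisState into PauliX gates on the wires whose bit is 1, matching the earlier behaviour described at the top of this issue. None of these names are Catalyst APIs.

```python
def decompose_until_supported(op, supported, rules):
    """Recursively decompose `op` until every resulting gate is in `supported`."""
    name, params, wires = op
    if name in supported:
        return [op]
    if name not in rules:
        raise ValueError(f"{name} is non-differentiable and has no decomposition")
    result = []
    for sub_op in rules[name](params, wires):
        result.extend(decompose_until_supported(sub_op, supported, rules))
    return result


# Hypothetical rule: BasisState -> PauliX on each wire whose bit is 1.
RULES = {
    "BasisState": lambda params, wires: [
        ("PauliX", (), (w,)) for w, bit in zip(wires, params[0]) if bit == 1
    ],
}
SUPPORTED = {"PauliX", "DoubleExcitation"}

circuit = [
    ("BasisState", ((1, 1, 0, 0, 0, 0),), tuple(range(6))),
    ("DoubleExcitation", (0.2,), (0, 1, 2, 3)),
    ("DoubleExcitation", (0.7,), (0, 1, 4, 5)),
]

expanded = [g for op in circuit for g in decompose_until_supported(op, SUPPORTED, RULES)]
print([name for name, _, _ in expanded])
# ['PauliX', 'PauliX', 'DoubleExcitation', 'DoubleExcitation']
```

This decomposes BasisState even when it is not differentiated, which is the inefficiency mentioned above.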

erick-xanadu (Contributor)

> @erick-xanadu I think this is related to our interface discussion, since the new gates aren't implementing the interface the gradient passes would need to remove them explicitly.

I agree. However, I don't see how we got here, because verification should have caught this error, as it did in the example above.

dime10 (Collaborator) commented Aug 28, 2024

> I agree. However, I don't see how we got here because verification should have caught this error similar to above.

I don't know if verification for parameter-shift has been implemented yet.

lillian542 (Contributor)

> I don't know if verification for parameter-shift has been implemented yet.

A naive verification for parameter-shift is implemented, which just confirms that op.grad_method is in {"A", None} for all operations. The more thorough verification that was discussed hasn't made it in yet.
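For reference, the naive check described above amounts to something like this sketch. `FakeOp` is a hypothetical stand-in for a PennyLane operation, whose `grad_method` is "A" (analytic), "F" (finite differences) or None. Note that an operation with grad_method None passes this check, so it does not establish that the parameter-shift pass can actually handle every operation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FakeOp:
    name: str
    grad_method: Optional[str]  # "A" (analytic), "F" (finite differences) or None


def naive_paramshift_check(ops):
    """Reject only operations whose grad_method is neither "A" nor None."""
    for op in ops:
        if op.grad_method not in {"A", None}:
            raise ValueError(f"{op.name} has grad_method={op.grad_method!r}")


naive_paramshift_check([FakeOp("RX", "A"), FakeOp("NoGradOp", None)])  # passes
```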

josh146 (Member, Author) commented Aug 28, 2024

> @erick-xanadu I think this is related to our interface discussion on your PR, since the new gates aren't implementing one of the quantum gate interfaces, the gradient passes would need to remove them explicitly.
> Should be an easy fix.

If the parameter-shift bug is an easy fix, should I split it into its own issue, separate from the verification discussion?

josh146 (Member, Author) commented Aug 28, 2024

@dime10 @erick-xanadu #1072

(we can now treat these as two separate bugs)

Projects
None yet
Development

No branches or pull requests

4 participants