Skip to content

Commit

Permalink
updated the cu gate to be like qiskit's
Browse files Browse the repository at this point in the history
  • Loading branch information
01110011011101010110010001101111 committed Jul 13, 2023
1 parent ce1eca8 commit a1d9e17
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 2 deletions.
1 change: 1 addition & 0 deletions test/operators/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
{"qiskit": qiskit_gate.CU1Gate, "tq": tq.CU1},
# {'qiskit': qiskit_gate.?, 'tq': tq.CU2},
{"qiskit": qiskit_gate.CU3Gate, "tq": tq.CU3},
{"qiskit": qiskit_gate.CUGate, "tq": tq.CU},
{"qiskit": qiskit_gate.ECRGate, "tq": tq.ECR},
]

Expand Down
84 changes: 83 additions & 1 deletion torchquantum/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"cu1",
"cu2",
"cu3",
"cu",
"qubitunitary",
"qubitunitaryfast",
"qubitunitarystrict",
Expand Down Expand Up @@ -959,6 +960,40 @@ def cu3_matrix(params):

return matrix.squeeze(0)

def cu_matrix(params):
"""Compute unitary matrix for CU gate.
Args:
params (torch.Tensor): The rotation angle.
Returns:
torch.Tensor: The computed unitary matrix.
"""
theta = params[:, 0].unsqueeze(dim=-1).type(C_DTYPE)
phi = params[:, 1].unsqueeze(dim=-1).type(C_DTYPE)
lam = params[:, 2].unsqueeze(dim=-1).type(C_DTYPE)
gamma = params[:, 3].unsqueeze(dim=-1).type(C_DTYPE)

co = torch.cos(theta / 2)
si = torch.sin(theta / 2)

matrix = (
torch.tensor(
[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
dtype=C_DTYPE,
device=params.device,
)
.unsqueeze(0)
.repeat(phi.shape[0], 1, 1)
)

matrix[:, 2, 2] = co * torch.exp(1j * gamma)
matrix[:, 2, 3] = -si * torch.exp(1j * (lam + gamma))
matrix[:, 3, 2] = si * torch.exp(1j * (phi + gamma))
matrix[:, 3, 3] = co * torch.exp(1j * (phi + lam + gamma))

return matrix.squeeze(0)

def qubitunitary_matrix(params):
"""Compute unitary matrix for Qubitunitary gate.
Expand Down Expand Up @@ -1190,6 +1225,7 @@ def singleexcitation_matrix(params):
"cu1": cu1_matrix,
"cu2": cu2_matrix,
"cu3": cu3_matrix,
"cu": cu_matrix,
"qubitunitary": qubitunitary_matrix,
"qubitunitaryfast": qubitunitaryfast_matrix,
"qubitunitarystrict": qubitunitarystrict_matrix,
Expand Down Expand Up @@ -2891,6 +2927,53 @@ def cu3(
)


def cu(
q_device,
wires,
params=None,
n_wires=None,
static=False,
parent_graph=None,
inverse=False,
comp_method="bmm",
):
"""Perform the cu gate.
Args:
q_device (tq.QuantumDevice): The QuantumDevice.
wires (Union[List[int], int]): Which qubit(s) to apply the gate.
params (torch.Tensor, optional): Parameters (if any) of the gate.
Default to None.
n_wires (int, optional): Number of qubits the gate is applied to.
Default to None.
static (bool, optional): Whether use static mode computation.
Default to False.
parent_graph (tq.QuantumGraph, optional): Parent QuantumGraph of
current operation. Default to None.
inverse (bool, optional): Whether inverse the gate. Default to False.
comp_method (bool, optional): Use 'bmm' or 'einsum' method to perform
matrix vector multiplication. Default to 'bmm'.
Returns:
None.
"""
name = "cu"
mat = mat_dict[name]
gate_wrapper(
name=name,
mat=mat,
method=comp_method,
q_device=q_device,
wires=wires,
params=params,
n_wires=n_wires,
static=static,
parent_graph=parent_graph,
inverse=inverse,
)


def qubitunitary(
q_device,
wires,
Expand Down Expand Up @@ -3234,7 +3317,6 @@ def ecr(
ccnot = toffoli
ccx = toffoli
u = u3
cu = cu3
p = phaseshift
cp = cu1
cr = cu1
Expand Down
14 changes: 13 additions & 1 deletion torchquantum/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"CU1",
"CU2",
"CU3",
"CU",
"QubitUnitary",
"QubitUnitaryFast",
"TrainableUnitary",
Expand Down Expand Up @@ -145,6 +146,7 @@ class Operator(tq.QuantumModule):
"CU1",
"CU2",
"CU3",
"CU",
"QubitUnitary",
"QubitUnitaryFast",
"TrainableUnitary",
Expand Down Expand Up @@ -1102,6 +1104,16 @@ class CU3(Operation, metaclass=ABCMeta):
def _matrix(cls, params):
return tqf.cu3_matrix(params)

class CU(Operation, metaclass=ABCMeta):
"""Class for Controlled U gate (4-parameter two-qubit gate)."""

num_params = 4
num_wires = 2
func = staticmethod(tqf.cu)

@classmethod
def _matrix(cls, params):
return tqf.cu_matrix(params)

class QubitUnitary(Operation, metaclass=ABCMeta):
"""Class for controlled Qubit Unitary gate."""
Expand Down Expand Up @@ -1376,7 +1388,7 @@ def _matrix(cls, params):
"cphase": CU1,
"cu2": CU2,
"cu3": CU3,
"cu": CU3,
"cu": CU,
"qubitunitary": QubitUnitary,
"qubitunitarystrict": QubitUnitaryFast,
"qubitunitaryfast": QubitUnitaryFast,
Expand Down

0 comments on commit a1d9e17

Please sign in to comment.