Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C3SX Gate #172

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"tqdm>=4.56.0",
"setuptools>=52.0.0",
"torch>=1.8.0",
"torchdiffeq>=0.2.3",
"torchpack>=0.3.0",
"qiskit==0.38.0",
"matplotlib>=3.3.2",
Expand Down
4 changes: 4 additions & 0 deletions test/operators/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tqdm import tqdm

import qiskit.circuit.library.standard_gates as qiskit_gate
from qiskit.quantum_info import Operator

RND_TIMES = 100

Expand All @@ -21,6 +22,7 @@
{"qiskit": qiskit_gate.SGate, "tq": tq.S},
{"qiskit": qiskit_gate.TGate, "tq": tq.T},
{"qiskit": qiskit_gate.SXGate, "tq": tq.SX},
{"qiskit": qiskit_gate.C3SXGate, "tq": tq.C3SX},
{"qiskit": qiskit_gate.CXGate, "tq": tq.CNOT},
{"qiskit": qiskit_gate.CYGate, "tq": tq.CY},
{"qiskit": qiskit_gate.CZGate, "tq": tq.CZ},
Expand Down Expand Up @@ -79,6 +81,8 @@ def test_op():
if pair["tq"]().name == "SHadamard":
"""Square root of Hadamard is RY(pi/4)"""
qiskit_matrix = qiskit_gate.RYGate(theta=np.pi / 4).to_matrix()
elif pair["tq"]().name == "C3SX":
qiskit_matrix = Operator(pair["qiskit"]())
else:
qiskit_matrix = pair["qiskit"]().to_matrix()
tq_matrix = pair["tq"].matrix.numpy()
Expand Down
67 changes: 67 additions & 0 deletions torchquantum/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"s",
"t",
"sx",
"c3sx",
"cnot",
"cz",
"cy",
Expand Down Expand Up @@ -1100,6 +1101,23 @@ def singleexcitation_matrix(params):

return matrix.squeeze(0)

def c3sx_matrix():
"""Compute unitary matrix for c3sx gate.

Args:
None.

Returns:
torch.Tensor: The computed unitary matrix.

"""
mat = torch.eye(16, dtype=C_DTYPE)
mat[14][14] = (1 + 1j) / 2
mat[14][15] = (1 - 1j) / 2
mat[15][14] = (1 - 1j) / 2
mat[15][15] = (1 + 1j) / 2

return mat

mat_dict = {
"hadamard": torch.tensor(
Expand All @@ -1119,6 +1137,7 @@ def singleexcitation_matrix(params):
"s": torch.tensor([[1, 0], [0, 1j]], dtype=C_DTYPE),
"t": torch.tensor([[1, 0], [0, np.exp(1j * np.pi / 4)]], dtype=C_DTYPE),
"sx": 0.5 * torch.tensor([[1 + 1j, 1 - 1j], [1 - 1j, 1 + 1j]], dtype=C_DTYPE),
"c3sx": c3sx_matrix(),
"cnot": torch.tensor(
[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=C_DTYPE
),
Expand Down Expand Up @@ -1622,6 +1641,53 @@ def sx(
)


def c3sx(
q_device,
wires,
params=None,
n_wires=None,
static=False,
parent_graph=None,
inverse=False,
comp_method="bmm",
):
"""Perform the c3sx gate.

Args:
q_device (tq.QuantumDevice): The QuantumDevice.
wires (Union[List[int], int]): Which qubit(s) to apply the gate.
params (torch.Tensor, optional): Parameters (if any) of the gate.
Default to None.
n_wires (int, optional): Number of qubits the gate is applied to.
Default to None.
static (bool, optional): Whether use static mode computation.
Default to False.
parent_graph (tq.QuantumGraph, optional): Parent QuantumGraph of
current operation. Default to None.
inverse (bool, optional): Whether inverse the gate. Default to False.
comp_method (bool, optional): Use 'bmm' or 'einsum' method to perform
matrix vector multiplication. Default to 'bmm'.

Returns:
None.

"""
name = "c3sx"
mat = mat_dict[name]
gate_wrapper(
name=name,
mat=mat,
method=comp_method,
q_device=q_device,
wires=wires,
params=params,
n_wires=n_wires,
static=static,
parent_graph=parent_graph,
inverse=inverse,
)


def cnot(
q_device,
wires,
Expand Down Expand Up @@ -3252,6 +3318,7 @@ def ecr(
"s": s,
"t": t,
"sx": sx,
"c3sx": c3sx,
"cnot": cnot,
"cz": cz,
"cy": cy,
Expand Down
15 changes: 8 additions & 7 deletions torchquantum/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torchquantum.functional import mat_dict
from torchquantum.operators import op_name_dict
from copy import deepcopy
import matplotlib.pyplot as plt

__all__ = [
"find_observable_groups",
Expand All @@ -32,7 +33,7 @@ def gen_bitstrings(n_wires):
return ["{:0{}b}".format(k, n_wires) for k in range(2**n_wires)]


def measure(qdev, n_shots=1024):
def measure(qdev, n_shots=1024, draw_id=None):
"""Measure the target state and obtain classical bitstream distribution
Args:
q_state: input tq.QuantumDevice
Expand All @@ -58,12 +59,12 @@ def measure(qdev, n_shots=1024):
distri = OrderedDict(sorted(distri.items()))
distri_all.append(distri)

# if draw_id is not None:
# plt.bar(distri_all[draw_id].keys(), distri_all[draw_id].values())
# plt.xticks(rotation="vertical")
# plt.xlabel("bitstring [qubit0, qubit1, ..., qubitN]")
# plt.title("distribution of measured bitstrings")
# plt.show()
if draw_id is not None:
plt.bar(distri_all[draw_id].keys(), distri_all[draw_id].values())
plt.xticks(rotation="vertical")
plt.xlabel("bitstring [qubit0, qubit1, ..., qubitN]")
plt.title("distribution of measured bitstrings")
plt.show()
return distri_all


Expand Down
16 changes: 16 additions & 0 deletions torchquantum/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"S",
"T",
"SX",
"C3SX",
"CNOT",
"CZ",
"CY",
Expand Down Expand Up @@ -111,6 +112,7 @@ class Operator(tq.QuantumModule):
"S",
"T",
"SX",
"C3SX",
"CNOT",
"CZ",
"CY",
Expand Down Expand Up @@ -714,6 +716,19 @@ def _eigvals(cls, params):
return cls.eigvals


class C3SX(Operation, metaclass=ABCMeta):
"""Class for C3SX Gate."""

num_params = 0
num_wires = 4
matrix = mat_dict["c3sx"]
func = staticmethod(tqf.c3sx)

@classmethod
def _matrix(cls, params):
return cls.matrix


class CNOT(Operation, metaclass=ABCMeta):
"""Class for CNOT Gate."""

Expand Down Expand Up @@ -1338,6 +1353,7 @@ def _matrix(cls, params):
"s": S,
"t": T,
"sx": SX,
"c3sx": C3SX,
"cx": CNOT,
"cnot": CNOT,
"cz": CZ,
Expand Down
765 changes: 765 additions & 0 deletions torchquantum/pulse/ISCA_tutorial_pulse.ipynb

Large diffs are not rendered by default.

293 changes: 293 additions & 0 deletions torchquantum/pulse/MESolver_example.ipynb

Large diffs are not rendered by default.

443 changes: 443 additions & 0 deletions torchquantum/pulse/SESolver_example.ipynb

Large diffs are not rendered by default.

234 changes: 234 additions & 0 deletions torchquantum/pulse/Two_qubit_simple_example.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions torchquantum/pulse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .utils import *
from .sesolve import sesolve
from .mesolve import mesolve
# from .smesolve import smesolve
1 change: 1 addition & 0 deletions torchquantum/pulse/hardware/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .hardware import hardware
11 changes: 11 additions & 0 deletions torchquantum/pulse/hardware/hardware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import torch
import numpy as np
import torchquantum as tq
import torchdiffeq




class Hardware(torch.nn.Modele):
def __init__(self,):

1 change: 1 addition & 0 deletions torchquantum/pulse/mesolve/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .mesolve import mesolve
67 changes: 67 additions & 0 deletions torchquantum/pulse/mesolve/mesolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import torch
import math
from ..solver import Solver
from ..utils import *
from torchdiffeq import odeint

def mesolve(
dens0,
H=None,
n_dt=None,
dt=0.22,
*,
L_ops=None,
exp_ops=None,
options=None,
dtype=None,
device=None
):
if options is None:
options = {}

if not 'step_size' in options:
options['step_size'] = 0.001

t_save = torch.tensor(list(range(n_dt)))*dt

args = (H, dens0, t_save, exp_ops, options)

solver = MESolver(*args, L_ops=L_ops)

solver.run()

psi_save, exp_save = solver.y_save, solver.exp_save

return psi_save, exp_save

def _lindblad_helper(L, rho):
Ldag = torch.conj(L)
return L @ rho @ Ldag - 0.5 * Ldag @ L @ rho - 0.5 * rho @ Ldag @ L

def lindbladian(H,rho,L_ops):
if L_ops is None:
return -1j * (H @ rho - rho @ H)

if type(L_ops) is not list:
L_ops = [L_ops]

_dissipator = [_lindblad_helper(L, rho) for L in L_ops]
dissipator = torch.stack(_dissipator)
return -1j * (H @ rho - rho @ H) + dissipator.sum(0)

class MESolver(Solver):

def __init__(self, *args, L_ops):
super().__init__(*args)
self.L_ops = L_ops


def f(self, t, y):
h = self.H(t)
return lindbladian(h,y,self.L_ops)


def run(self):
# self.y_save = odeint(self.f, self.psi0, self.t_save, method='rk4', options=self.options)
self.y_save = odeint(self.f, self.psi0, self.t_save)
self.exp_save = None
Loading
Loading