Skip to content

Commit

Permalink
!2323 merge from r0.10
Browse files Browse the repository at this point in the history
Merge pull request !2323 from donghufeng/r0.10
  • Loading branch information
donghufeng authored and gitee-org committed Apr 22, 2024
2 parents 4747ee6 + 0564501 commit 48b41bb
Show file tree
Hide file tree
Showing 25 changed files with 842 additions and 62 deletions.
10 changes: 5 additions & 5 deletions ccsrc/lib/simulator/stabilizer/stabilizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ int CalcG(size_t x1, size_t z1, size_t x2, size_t z2) {
void StabilizerTableau::RowSum(size_t h, size_t i) {
int r0 = 2 * (phase.GetBit(h) + phase.GetBit(i));
for (size_t j = 0; j < n_qubits; ++j) {
r0 = CalcG(GetElement(i, j), GetElement(i, j + n_qubits), GetElement(h, j), GetElement(h, j + n_qubits));
r0 += CalcG(GetElement(i, j), GetElement(i, j + n_qubits), GetElement(h, j), GetElement(h, j + n_qubits));
table[j].SetBit(h, table[j].GetBit(h) ^ table[j].GetBit(i));
table[j + n_qubits].SetBit(h, table[j + n_qubits].GetBit(h) ^ table[j + n_qubits].GetBit(i));
}
Expand Down Expand Up @@ -247,10 +247,10 @@ size_t StabilizerTableau::ApplyMeasurement(size_t a) {
if (GetElement(i, a) == 1) {
int r0 = 2 * (tail.GetBit(2 * n_qubits) + phase.GetBit(i + n_qubits));
for (size_t j = 0; j < n_qubits; ++j) {
r0 = CalcG(GetElement(i + n_qubits, j), GetElement(i + n_qubits, j + n_qubits), tail.GetBit(j),
tail.GetBit(j + n_qubits));
tail.SetBit(j, tail.GetBit(j) ^ table[j].GetBit(i));
tail.SetBit(j + n_qubits, tail.GetBit(j + n_qubits) ^ table[j + n_qubits].GetBit(i));
r0 += CalcG(GetElement(i + n_qubits, j), GetElement(i + n_qubits, j + n_qubits), tail.GetBit(j),
tail.GetBit(j + n_qubits));
tail.SetBit(j, tail.GetBit(j) ^ table[j].GetBit(i + n_qubits));
tail.SetBit(j + n_qubits, tail.GetBit(j + n_qubits) ^ table[j + n_qubits].GetBit(i + n_qubits));
}
tail.SetBit(2 * n_qubits, (((r0 % 4) + 4) % 4) / 2);
}
Expand Down
2 changes: 1 addition & 1 deletion ccsrc/python/simulator/lib/_mq_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ PYBIND11_MODULE(_mq_vector, module) {
#ifndef __CUDACC__
using namespace mindquantum::stabilizer; // NOLINT
pybind11::class_<StabilizerTableau>(stabilizer, "StabilizerTableau")
.def(pybind11::init<size_t>())
.def(pybind11::init<size_t, unsigned>(), "n_qubits"_a, "seed"_a = 42)
.def("copy", [](const StabilizerTableau& s) { return s; })
.def("tableau_to_string", &StabilizerTableau::TableauToString)
.def("stabilizer_to_string", &StabilizerTableau::StabilizerToString)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
mindquantum.algorithm.library.mat_to_op
=======================================================

.. py:function:: mat_to_op(mat, little_endian: bool = True)
将一个基于qubit的矩阵表示转换为对应的泡利算符表示。默认以小端头表示输出QubitOperator。

参数:
- **mat** - 基于qubit的矩阵表示。
- **little_endian** - 是否使用小端头表示(默认为True,即小端头表示)。如果为True,则表示最高位Qubit为最左边的位(即小端头表示),否则表示最高位Qubit为最右边的位(即大端头表示)

返回:
:class:`~.core.QubitOperator`, 对应的泡利算符表示的QubitOperator。
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
mindquantum.algorithm.library.qudit_symmetric_decoding
========================================================

.. py:function:: qudit_symmetric_decoding(qubit: np.ndarray, n_qubits: int = 1)
对称性解码,将qubit对称态或矩阵解码成qudit态或矩阵。

.. math::
\begin{align}
\ket{00\cdots00}&\to\ket{0} \\[.5ex]
\frac{\ket{0\cdots01}+\ket{0\cdots010}+\ket{10\cdots0}}{\sqrt{d-1}}&\to\ket{1} \\
\frac{\ket{0\cdots011}+\ket{0\cdots0101}+\ket{110\cdots0}}{\sqrt{d-1}}&\to\ket{2} \\
\vdots&\qquad\vdots \\[.5ex]
\ket{11\cdots11}&\to\ket{d-1}
\end{align}
参数:
- **qubit** (np.ndarray) - 需要解码的qubit对称态或矩阵,qubit态或矩阵需满足对称性。
- **n_qubits** (int) - qubit对称态或矩阵的量子比特数。默认值:``1``。

返回:
np.ndarray,对称性解码后的qudit态或矩阵。
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
mindquantum.algorithm.library.qudit_symmetric_encoding
========================================================

.. py:function:: qudit_symmetric_encoding(qudit: np.ndarray, n_qudits: int = 1)
对称性编码,将qudit态或矩阵编码成qubit对称态或矩阵。

.. math::
\begin{align}
\ket{0}&\to\ket{00\cdots00} \\[.5ex]
\ket{1}&\to\frac{\ket{0\cdots01}+\ket{0\cdots010}+\ket{10\cdots0}}{\sqrt{d-1}} \\
\ket{2}&\to\frac{\ket{0\cdots011}+\ket{0\cdots0101}+\ket{110\cdots0}}{\sqrt{d-1}} \\
\vdots&\qquad\vdots \\[.5ex]
\ket{d-1}&\to\ket{11\cdots11}
\end{align}
参数:
- **qudit** (np.ndarray) - 需要编码的qudit态或矩阵。
- **n_qudits** (int) - qudit态或矩阵的量子位个数。默认值:``1``。

返回:
np.ndarray,对称性编码后的qubit对称态或矩阵。
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
mindquantum.algorithm.library.qutrit_symmetric_ansatz
=======================================================

.. py:function:: qutrit_symmetric_ansatz(gate: UnivMathGate, basis: str = "zyz", with_phase: bool = False)
构造一个保持任意qutrit门编码对称性的qubit ansatz。

参考文献:
`Synthesis of multivalued quantum logic circuits by elementary gates <https://journals.aps.org/pra/abstract/10.1103/PhysRevA.87.012325>`_,
`Optimal synthesis of multivalued quantum circuits <https://journals.aps.org/pra/abstract/10.1103/PhysRevA.92.062317>`_。

参数:
- **gate** (:class:`~.core.gates.UnivMathGate`) - 由qutrit门编码而来的qubit门。
- **basis** (str) - 分解的基,可以是 ``"zyz"`` 或者 ``"u3"`` 中的一个。默认值: ``"zyz"``。
- **with_phase** (bool) - 是否将全局相位以 :class:`~.core.gates.GlobalPhase` 的形式作用在量子线路上。默认值: ``False``。

返回:
:class:`~.core.circuit.Circuit`,保持qutrit编码对称性的qubit ansatz。
4 changes: 4 additions & 0 deletions docs/api_python/algorithm/mindquantum.algorithm.library.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ MindQuantum常用算法模块。
mindquantum.algorithm.library.general_ghz_state
mindquantum.algorithm.library.general_w_state
mindquantum.algorithm.library.qft
mindquantum.algorithm.library.qudit_symmetric_decoding
mindquantum.algorithm.library.qudit_symmetric_encoding
mindquantum.algorithm.library.qutrit_symmetric_ansatz
mindquantum.algorithm.library.mat_to_op
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ mindquantum.algorithm.nisq.Max2SATAnsatz

.. math::
U(\beta, \gamma) = e^{-\beta_pH_b}e^{-\gamma_pH_c}
\cdots e^{-\beta_0H_b}e^{-\gamma_0H_c}H^{\otimes n}
U(\beta, \gamma) = e^{-i\beta_pH_b}e^{-i\frac{\gamma_p}{2}H_c}
\cdots e^{-i\beta_0H_b}e^{-i\frac{\gamma_0}{2}H_c}H^{\otimes n}
.. math::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ mindquantum.algorithm.nisq.MaxCutAnsatz

.. math::
U(\beta, \gamma) = e^{-\beta_pH_b}e^{-\gamma_pH_c}
\cdots e^{-\beta_0H_b}e^{-\gamma_0H_c}H^{\otimes n}
U(\beta, \gamma) = e^{-i\beta_pH_b}e^{-i\frac{\gamma_p}{2}H_c}
\cdots e^{-i\beta_0H_b}e^{-i\frac{\gamma_0}{2}H_c}H^{\otimes n}
.. math::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ def kak_decompose(gate: QuantumGate, return_u3: bool = True) -> Circuit:
(q_left, q_right), (dr, di) = utils.simult_svd(ur, ui)
d = dr + 1j * di

_, a0, a1 = kron_factor_4x4_to_2x2s(M @ q_left @ M_DAG)
_, b0, b1 = kron_factor_4x4_to_2x2s(M @ q_right.T @ M_DAG)
_, a1, a0 = kron_factor_4x4_to_2x2s(M @ q_left @ M_DAG)
_, b1, b0 = kron_factor_4x4_to_2x2s(M @ q_right.T @ M_DAG)

k = linalg.inv(A) @ np.angle(np.diag(d))
h1, h2, h3 = -k[1:]
Expand Down
1 change: 0 additions & 1 deletion mindquantum/algorithm/compiler/decompose/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def params_zyz(mat: np.ndarray):
coe = linalg.det(mat) ** (-0.5)
alpha = -np.angle(coe)
v = coe * mat
v = v.round(10)
theta = 2 * atan2(abs(v[1, 0]), abs(v[0, 0]))
phi_lam_sum = 2 * np.angle(v[1, 1])
phi_lam_diff = 2 * np.angle(v[1, 0])
Expand Down
4 changes: 2 additions & 2 deletions mindquantum/algorithm/error_mitigation/random_benchmarking.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def generate_single_qubit_rb_circ(length: int, seed: int = None) -> Circuit:
Returns:
:class:`~.core.circuit.Circuit`, the single qubit randomized benchmarking circuit, the quantum
state of this circuit is zero state.
state of this circuit is zero state.
Examples:
>>> import numpy as np
Expand Down Expand Up @@ -140,7 +140,7 @@ def generate_double_qubits_rb_circ(length: int, seed: int = None) -> Circuit:
Returns:
:class:`~.core.circuit.Circuit`, the double qubit randomized benchmarking circuit, the quantum state of
this circuit is zero state.
this circuit is zero state.
Examples:
>>> import numpy as np
Expand Down
6 changes: 5 additions & 1 deletion mindquantum/algorithm/library/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from .general_ghz_state import general_ghz_state
from .general_w_state import general_w_state
from .quantum_fourier import qft
from .qudit_mapping import qudit_symmetric_decoding, qudit_symmetric_encoding, qutrit_symmetric_ansatz

__all__ = ['qft', 'amplitude_encoder', 'general_w_state', 'general_ghz_state', 'bitphaseflip_operator']
__all__ = [
'qft', 'amplitude_encoder', 'general_w_state', 'general_ghz_state', 'bitphaseflip_operator',
'qudit_symmetric_decoding', 'qudit_symmetric_encoding', 'qutrit_symmetric_ansatz'
]

__all__.sort()
Loading

0 comments on commit 48b41bb

Please sign in to comment.