Skip to content

Commit

Permalink
test: add test for simulator
Browse files Browse the repository at this point in the history
  • Loading branch information
beizhansl committed Nov 1, 2023
1 parent 11af910 commit dedd046
Show file tree
Hide file tree
Showing 3 changed files with 474 additions and 0 deletions.
46 changes: 46 additions & 0 deletions tests/quafu/simulator/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# (C) Copyright 2023 Beijing Academy of Quantum Information Sciences
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

class BaseTest:
def assertDictAlmostEqual(self, dict1, dict2, delta=None, places=None, default_value=-1):
"""
Assert two dictionaries with numeric values are almost equal.
Args:
dict1 (dict): a dictionary.
dict2 (dict): a dictionary.
delta (number): threshold for comparison (defaults to 1e-8).
msg (str): return a custom message on failure.
places (int): number of decimal places for comparison.
default_value (number): default value for missing keys.
"""
def valid_comparison(value):
"""compare value to delta, within places accuracy"""
if places is not None:
return round(value, places) == 0
else:
return value < delta

# Check arguments.
if dict1 == dict2:
return
if places is None and delta is None:
delta = delta or 1e-8

# Compare all keys in both dicts, populating error_msg.
for key in set(dict1.keys()) | set(dict2.keys()):
val1 = dict1.get(key, default_value)
val2 = dict2.get(key, default_value)
if not valid_comparison(abs(val1 - val2)):
raise Exception("Dict not equal")
259 changes: 259 additions & 0 deletions tests/quafu/simulator/basis_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
# (C) Copyright 2023 Beijing Academy of Quantum Information Sciences
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import pytest
from quafu import QuantumCircuit
from quafu import simulate
from base import BaseTest
import unittest
import numpy as np

class BellCircuits:
"""Container for reference circuits used by the tests."""

@staticmethod
def bell_measure_atlast():
"""Return a Bell circuit."""
qc = QuantumCircuit(2, 2)
qc.h(0)
qc.cx(0, 1)
qc.measure([0,1])
return qc

@staticmethod
def bell_measure_normal():
"""Return a Bell circuit."""
qc = QuantumCircuit(3, 2)
qc.h(0)
qc.cx(0, 1)
qc.measure([0,1])
qc.h(2)
return qc

@staticmethod
def bell_no_measure():
"""Return a Bell circuit."""
qc = QuantumCircuit(2)
qc.h(0)
qc.cx(0, 1)

return qc

class BasicCircuits:
"""Container for reference circuits used by the tests."""

@staticmethod
def singleQgate_measure_atlast():
qc = QuantumCircuit(2, 2)
qc.x(0)
qc.x(1)
qc.measure([0,1])
return qc

@staticmethod
def singleQgate_no_measure():
qc = QuantumCircuit(2)
qc.x(0)
qc.x(1)
return qc

@staticmethod
def singleQgate_measure_normal():
qc = QuantumCircuit(2)
qc.x(0)
qc.measure([0], [0])
qc.x(1)
qc.measure([1], [1])
return qc

@staticmethod
def multiQgate_measure_atlast():
qc = QuantumCircuit(2, 2)
qc.x(0)
qc.cx(0,1)
qc.measure([0,1])
return qc

@staticmethod
def multiQgate_no_measure():
qc = QuantumCircuit(2)
qc.x(0)
qc.cx(0,1)
return qc

@staticmethod
def multiQgate_measure_normal():
qc = QuantumCircuit(2)
qc.x(0)
qc.measure([0], [0])
qc.cx(0,1)
qc.measure([1], [1])
return qc

@staticmethod
def any_cbit_measure():
qc = QuantumCircuit(4,4)
qc.x(0)
qc.x(1)
qc.measure([1,2], [1,0])
qc.measure([3,0], [2,3])
return qc

@staticmethod
def after_measure():
qc = QuantumCircuit(2,22)
qc.h(0)
qc.cx(0,1)
qc.measure([0], [0])
qc.measure([1], [1])
qc.reset([0,1])
return qc




class TestSimulatorBasis(BaseTest):
"""Test C++ simulator"""
circuit = None
assertEqual = unittest.TestCase.assertEqual
assertAlmostEqual = unittest.TestCase.assertAlmostEqual
assertDictEqual = unittest.TestCase.assertDictEqual
assertListEqual = unittest.TestCase.assertListEqual
assertTrue = unittest.TestCase.assertTrue

@pytest.mark.skipif(
sys.platform == "darwin", reason="Avoid error on MacOS arm arch."
)
def test_simulate(self):
self.circuit = BellCircuits.bell_no_measure()
result = simulate(qc=self.circuit)
probs = result.probabilities
count = result.count
self.assertAlmostEqual(probs[0], 1/2)
self.assertAlmostEqual(probs[1], 0)
self.assertAlmostEqual(probs[2], 0)
self.assertAlmostEqual(probs[3], 1/2)
self.assertDictAlmostEqual(count, {})

@pytest.mark.skipif(
sys.platform == "darwin", reason="Avoid error on MacOS arm arch."
)
def test_measure_atlast_collapse(self):
"""Test final measurement statement"""
self.circuit = BellCircuits.bell_measure_atlast()
result = simulate(qc=self.circuit)
probs = result.probabilities
self.assertAlmostEqual(probs[0], 1/2)
self.assertAlmostEqual(probs[1], 0)
self.assertAlmostEqual(probs[2], 0)
self.assertAlmostEqual(probs[3], 1/2)

@pytest.mark.skipif(
sys.platform == "darwin", reason="Avoid error on MacOS arm arch."
)
def test_measure_normal_collapse(self):
"""Test normal measurement statement"""
self.circuit = BellCircuits.bell_measure_normal()
result = simulate(qc=self.circuit, shots=1)
probs = result.probabilities
diff_00 = np.linalg.norm(np.array([1, 0, 0, 0]) - probs) ** 2
diff_11 = np.linalg.norm(np.array([0, 0, 0, 1]) - probs) ** 2
success = np.allclose([diff_00, diff_11], [0, 2]) or np.allclose([diff_00, diff_11], [2, 0])
# state is 1/sqrt(2)|00> + 1/sqrt(2)|11>, up to a global phase
self.assertTrue(success)

def test_singleQgate_measure_atlast(self):
self.circuit = BasicCircuits.singleQgate_measure_atlast()
result = simulate(qc=self.circuit, shots=1)
probs = result.probabilities
counts = result.count
self.assertAlmostEqual(probs[0], 0)
self.assertAlmostEqual(probs[1], 0)
self.assertAlmostEqual(probs[2], 0)
self.assertAlmostEqual(probs[3], 1)
self.assertDictAlmostEqual(counts, {'11':1})

def test_singleQgate_no_measure(self):
self.circuit = BasicCircuits.singleQgate_no_measure()
result = simulate(qc=self.circuit, shots=1)
probs = result.probabilities
counts = result.count
self.assertAlmostEqual(probs[0], 0)
self.assertAlmostEqual(probs[1], 0)
self.assertAlmostEqual(probs[2], 0)
self.assertAlmostEqual(probs[3], 1)
self.assertDictAlmostEqual(counts, {})

def test_singleQgate_measure_normal(self):
self.circuit = BasicCircuits.singleQgate_measure_normal()
result = simulate(qc=self.circuit, shots=10)
probs = result.probabilities
counts = result.count
self.assertAlmostEqual(probs[0], 0)
self.assertAlmostEqual(probs[1], 0)
self.assertAlmostEqual(probs[2], 0)
self.assertAlmostEqual(probs[3], 1)
self.assertDictAlmostEqual(counts, {'11':10})

def test_multiQgate_measure_atlast(self):
self.circuit = BasicCircuits.multiQgate_measure_atlast()
result = simulate(qc=self.circuit, shots=10)
probs = result.probabilities
counts = result.count
self.assertAlmostEqual(probs[0], 0)
self.assertAlmostEqual(probs[1], 0)
self.assertAlmostEqual(probs[2], 0)
self.assertAlmostEqual(probs[3], 1)
self.assertDictAlmostEqual(counts, {'11':10})

def test_multiQgate_no_measure(self):
self.circuit = BasicCircuits.multiQgate_no_measure()
result = simulate(qc=self.circuit, shots=1)
probs = result.probabilities
counts = result.count
self.assertAlmostEqual(probs[0], 0)
self.assertAlmostEqual(probs[1], 0)
self.assertAlmostEqual(probs[2], 0)
self.assertAlmostEqual(probs[3], 1)
self.assertDictAlmostEqual(counts, {})

def test_multiQgate_measure_normal(self):
self.circuit = BasicCircuits.multiQgate_measure_normal()
result = simulate(qc=self.circuit, shots=10)
probs = result.probabilities
counts = result.count
self.assertAlmostEqual(probs[0], 0)
self.assertAlmostEqual(probs[1], 0)
self.assertAlmostEqual(probs[2], 0)
self.assertAlmostEqual(probs[3], 1)
self.assertDictAlmostEqual(counts, {'11':10})

def test_anycbit_measure(self):
self.circuit = BasicCircuits.any_cbit_measure()
result = simulate(qc=self.circuit, shots=10)
probs = result.probabilities
counts = result.count
print(probs)
self.assertAlmostEqual(probs[5], 1) #0101
self.assertDictAlmostEqual(counts, {'0101':10})

def test_after_measure(self):
self.circuit = BasicCircuits.after_measure()
result = simulate(qc=self.circuit, shots=10)
probs = result.probabilities
diff_00 = np.linalg.norm(np.array([1, 0, 0, 0]) - probs) ** 2
diff_11 = np.linalg.norm(np.array([0, 0, 0, 1]) - probs) ** 2
success = np.allclose([diff_00, diff_11], [0, 2]) or np.allclose([diff_00, diff_11], [2, 0])
self.assertTrue(success)
Loading

0 comments on commit dedd046

Please sign in to comment.