Skip to content

Commit

Permalink
Avoid getting NumPy dtypes in printed (string) scalar values
Browse files Browse the repository at this point in the history
As a consequence of [NEP
51](https://numpy.org/neps/nep-0051-scalar-representation.html#nep51),
the string representation of scalar numbers changed in NumPy 2 to
include type information. This affected printing Cirq circuit
diagrams: instead seeing numbers like 1.5, you would see
`np.float64(1.5)` and similar.

The solution is to avoid getting the repr output of NumPy scalars
directly, and instead doing `.item()` on them before passing them
to `format()` or other string-producing functions.
  • Loading branch information
mhucka committed Sep 19, 2024
1 parent d83b1c2 commit 0f52831
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
3 changes: 3 additions & 0 deletions cirq-core/cirq/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ def _print(self, expr, **kwargs):
if hasattr(value, "__qualname__"):
return f"{value.__module__}.{value.__qualname__}"

if isinstance(value, np.number):
return repr(value.item())

return repr(value)


Expand Down
22 changes: 14 additions & 8 deletions cirq-core/cirq/ops/fsim_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import cmath
import math
from typing import AbstractSet, Any, Dict, Iterator, Optional, Tuple
from typing import AbstractSet, Any, Dict, Iterator, Optional, Tuple, Union

import numpy as np
import sympy
Expand Down Expand Up @@ -52,6 +52,12 @@ def _half_pi_mod_pi(param: 'cirq.TParamVal') -> bool:
return param in (-np.pi / 2, np.pi / 2, -sympy.pi / 2, sympy.pi / 2)


def _plainvalue(value: Union[int, float, complex, np.number]) -> Union[int, float, complex]:
"""Returns a plain Python number if the given value is a NumPy number.
Used to avoid a change in repr behavior introduced in NumPy 2."""
return value.item() if isinstance(value, np.number) else value


@value.value_equality(approximate=True)
class FSimGate(gate_features.InterchangeableQubitsGate, raw_types.Gate):
r"""Fermionic simulation gate.
Expand Down Expand Up @@ -196,8 +202,8 @@ def _decompose_(self, qubits) -> Iterator['cirq.OP_TREE']:
yield cirq.CZ(a, b) ** (-self.phi / np.pi)

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs') -> Tuple[str, ...]:
t = args.format_radians(self.theta)
p = args.format_radians(self.phi)
t = args.format_radians(_plainvalue(self.theta))
p = args.format_radians(_plainvalue(self.phi))
return f'FSim({t}, {p})', f'FSim({t}, {p})'

def __pow__(self, power) -> 'FSimGate':
Expand Down Expand Up @@ -476,11 +482,11 @@ def to_exponent(angle_rads: 'cirq.TParamVal') -> 'cirq.TParamVal':
yield cirq.Z(q1) ** to_exponent(after[1])

def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs') -> Tuple[str, ...]:
theta = args.format_radians(self.theta)
zeta = args.format_radians(self.zeta)
chi = args.format_radians(self.chi)
gamma = args.format_radians(self.gamma)
phi = args.format_radians(self.phi)
theta = args.format_radians(_plainvalue(self.theta))
zeta = args.format_radians(_plainvalue(self.zeta))
chi = args.format_radians(_plainvalue(self.chi))
gamma = args.format_radians(_plainvalue(self.gamma))
phi = args.format_radians(_plainvalue(self.phi))
return (
f'PhFSim({theta}, {zeta}, {chi}, {gamma}, {phi})',
f'PhFSim({theta}, {zeta}, {chi}, {gamma}, {phi})',
Expand Down
2 changes: 2 additions & 0 deletions cirq-core/cirq/protocols/circuit_diagram_info_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ def __repr__(self) -> str:
def format_real(self, val: Union[sympy.Basic, int, float]) -> str:
if isinstance(val, sympy.Basic):
return str(val)
if isinstance(val, np.number):
val = val.item()
if val == int(val):
return str(int(val))
if self.precision is None:
Expand Down

0 comments on commit 0f52831

Please sign in to comment.