diff --git a/cirq-core/cirq/transformers/noise_adding.py b/cirq-core/cirq/transformers/noise_adding.py index 1e5e68c1d7c..d80e4669a9f 100644 --- a/cirq-core/cirq/transformers/noise_adding.py +++ b/cirq-core/cirq/transformers/noise_adding.py @@ -57,6 +57,7 @@ def __init__( + "sorted qubit pairs to floats" # pragma: no cover ) # pragma: no cover self.p = p + self.p_func = lambda _: p if isinstance(p, (int, float)) else lambda pair: p.get(pair, 0.0) self.target_gate = target_gate def __call__( @@ -77,7 +78,6 @@ def __call__( """ if rng is None: rng = np.random.default_rng() - p = self.p target_gate = self.target_gate # add random Pauli gates with probability p after each of the specified gate @@ -93,11 +93,8 @@ def __call__( } added_moment_ops = [] for pair in target_pairs: - if isinstance(p, float): - p_i = p - elif isinstance(p, Mapping): - pair_sorted_tuple = (pair[0], pair[1]) - p_i = p[pair_sorted_tuple] + pair_sorted_tuple = (pair[0], pair[1]) + p_i = self.p_func(pair_sorted_tuple) apply = rng.choice([True, False], p=[p_i, 1 - p_i]) if apply: choices = [