Skip to content

Commit

Permalink
Use errstate func
Browse files Browse the repository at this point in the history
  • Loading branch information
benjeffery committed Jun 27, 2024
1 parent 03e4091 commit a8cb721
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions python/tests/test_ld_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@


@contextlib.contextmanager
def suppress_division_by_zero_warning():
with np.errstate(invalid="ignore", divide="ignore"):
def suppress_overflow_div0_warning():
with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
yield


Expand Down Expand Up @@ -901,7 +901,7 @@ def r2_summary_func(
D = p_AB - (p_A * p_B)
denom = p_A * p_B * (1 - p_A) * (1 - p_B)

with suppress_division_by_zero_warning():
with suppress_overflow_div0_warning():
result[k] = (D * D) / denom


Expand Down Expand Up @@ -952,7 +952,7 @@ def D_prime_summary_func(
p_B = p_AB + p_aB

D = p_AB - (p_A * p_B)
with suppress_division_by_zero_warning():
with suppress_overflow_div0_warning():
if D >= 0:
result[k] = D / min(p_A * (1 - p_B), (1 - p_A) * p_B)
else:
Expand All @@ -975,7 +975,7 @@ def r_summary_func(
D = p_AB - (p_A * p_B)
denom = p_A * p_B * (1 - p_A) * (1 - p_B)

with suppress_division_by_zero_warning():
with suppress_overflow_div0_warning():
result[k] = D / np.sqrt(denom)


Expand Down Expand Up @@ -1034,7 +1034,7 @@ def pi2_unbiased(
w_Ab = state[1, k]
w_aB = state[2, k]
w_ab = n - (w_AB + w_Ab + w_aB)
with np.errstate(over="ignore", divide="ignore", invalid="ignore"):
with suppress_overflow_div0_warning():
result[k] = (1 / (n * (n - 1) * (n - 2) * (n - 3))) * (
((w_AB + w_Ab) * (w_aB + w_ab) * (w_AB + w_aB) * (w_Ab + w_ab))
- ((w_AB * w_ab) * (w_AB + w_ab + (3 * w_Ab) + (3 * w_aB) - 1))
Expand All @@ -1052,7 +1052,7 @@ def dz_unbiased(
w_Ab = state[1, k]
w_aB = state[2, k]
w_ab = n - (w_AB + w_Ab + w_aB)
with np.errstate(over="ignore", divide="ignore", invalid="ignore"):
with suppress_overflow_div0_warning():
result[k] = (1 / (n * (n - 1) * (n - 2) * (n - 3))) * (
(
((w_AB * w_ab) - (w_Ab * w_aB))
Expand All @@ -1074,7 +1074,7 @@ def d2_unbiased(
w_Ab = state[1, k]
w_aB = state[2, k]
w_ab = n - (w_AB + w_Ab + w_aB)
with np.errstate(over="ignore", divide="ignore", invalid="ignore"):
with suppress_overflow_div0_warning():
result[k] = (1 / (n * (n - 1) * (n - 2) * (n - 3))) * (
((w_aB**2) * (w_Ab - 1) * w_Ab)
+ ((w_ab - 1) * w_ab * (w_AB - 1) * w_AB)
Expand Down

0 comments on commit a8cb721

Please sign in to comment.