Skip to content

Commit

Permalink
Disable caching of bda_mapper and re-enable its test case
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Jan 26, 2024
1 parent 87aab3e commit 90888ad
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 43 deletions.
10 changes: 5 additions & 5 deletions africanus/averaging/bda_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import numba
from numba.experimental import jitclass
import numba.types
from numba import types

from africanus.constants import c as lightspeed
from africanus.util.numba import (
Expand Down Expand Up @@ -64,7 +64,7 @@ def max_chan_width(ref_freq, fractional_bandwidth):
"nchan", "flag"])


class Binner(object):
class Binner:
def __init__(self, row_start, row_end,
max_lm, decorrelation, time_bin_secs,
max_chan_freq):
Expand Down Expand Up @@ -305,7 +305,7 @@ def bda_mapper_impl(time, interval, ant1, ant2, uvw,
min_nchan=1):
return NotImplementedError

@overload(bda_mapper_impl, jit_options=JIT_OPTIONS)
@overload(bda_mapper_impl, jit_options={"nogil": True})
def nb_bda_mapper(time, interval, ant1, ant2, uvw,
chan_width, chan_freq,
max_uvw_dist,
Expand All @@ -316,7 +316,7 @@ def nb_bda_mapper(time, interval, ant1, ant2, uvw,
min_nchan=1):
have_time_bin_secs = not is_numba_type_none(time_bin_secs)

Omitted = numba.types.misc.Omitted
Omitted = types.misc.Omitted

decorr_type = (numba.typeof(decorrelation.value)
if isinstance(decorrelation, Omitted)
Expand Down Expand Up @@ -347,7 +347,7 @@ def nb_bda_mapper(time, interval, ant1, ant2, uvw,
('max_chan_freq', chan_freq.dtype),
('max_uvw_dist', max_uvw_dist)]

JitBinner = st(spec)(Binner)
JitBinner = jitclass(spec)(Binner)

def impl(time, interval, ant1, ant2, uvw,
chan_width, chan_freq,
Expand Down
76 changes: 38 additions & 38 deletions africanus/averaging/tests/test_bda_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from africanus.averaging.bda_mapping import bda_mapper, Binner

from africanus.util.numba import njit

@pytest.fixture(scope="session", params=[4096])
def nchan(request):
Expand Down Expand Up @@ -172,43 +172,43 @@ def synthesized_uvw(ants, time, phase_dir, auto_correlations):
return ant1, ant2, uvw


# @pytest.mark.parametrize("decorrelation", [0.95])
# @pytest.mark.parametrize("min_nchan", [1])
# def test_bda_mapper(time, synthesized_uvw, interval,
# chan_freq, chan_width,
# decorrelation, min_nchan):
# time = np.unique(time)
# ant1, ant2, uvw = synthesized_uvw

# nbl = ant1.shape[0]
# ntime = time.shape[0]

# time = np.repeat(time, nbl)
# interval = np.repeat(interval, nbl)
# ant1 = np.tile(ant1, ntime)
# ant2 = np.tile(ant2, ntime)
# flag_row = np.zeros(time.shape[0], dtype=np.int8)

# max_uvw_dist = np.sqrt(np.sum(uvw**2, axis=1)).max()

# row_meta = bda_mapper(time, interval, ant1, ant2, uvw, # noqa :F841
# chan_width, chan_freq,
# max_uvw_dist,
# flag_row=flag_row,
# max_fov=3.0,
# decorrelation=decorrelation,
# min_nchan=min_nchan)

# offsets = np.unique(row_meta.map[np.arange(time.shape[0]), 0])
# assert_array_equal(offsets, row_meta.offsets[:-1])
# assert row_meta.map.max() + 1 == row_meta.offsets[-1]

# # NUM_CHAN divides number of channels exactly
# num_chan = np.diff(row_meta.offsets)
# _, remainder = np.divmod(chan_width.shape[0], num_chan)
# assert np.all(remainder == 0)
# decorr_cw = chan_width.sum() / num_chan
# assert_array_equal(decorr_cw, row_meta.decorr_chan_width)
@pytest.mark.parametrize("decorrelation", [0.95])
@pytest.mark.parametrize("min_nchan", [1])
def test_bda_mapper(time, synthesized_uvw, interval,
chan_freq, chan_width,
decorrelation, min_nchan):
time = np.unique(time)
ant1, ant2, uvw = synthesized_uvw

nbl = ant1.shape[0]
ntime = time.shape[0]

time = np.repeat(time, nbl)
interval = np.repeat(interval, nbl)
ant1 = np.tile(ant1, ntime)
ant2 = np.tile(ant2, ntime)
flag_row = np.zeros(time.shape[0], dtype=np.int8)

max_uvw_dist = np.sqrt(np.sum(uvw**2, axis=1)).max()

row_meta = bda_mapper(time, interval, ant1, ant2, uvw, # noqa :F841
chan_width, chan_freq,
max_uvw_dist,
flag_row=flag_row,
max_fov=3.0,
decorrelation=decorrelation,
min_nchan=min_nchan)

offsets = np.unique(row_meta.map[np.arange(time.shape[0]), 0])
assert_array_equal(offsets, row_meta.offsets[:-1])
assert row_meta.map.max() + 1 == row_meta.offsets[-1]

# NUM_CHAN divides number of channels exactly
num_chan = np.diff(row_meta.offsets)
_, remainder = np.divmod(chan_width.shape[0], num_chan)
assert np.all(remainder == 0)
decorr_cw = chan_width.sum() / num_chan
assert_array_equal(decorr_cw, row_meta.decorr_chan_width)


def test_bda_binner(time, interval, synthesized_uvw,
Expand Down

0 comments on commit 90888ad

Please sign in to comment.