Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

The ragged SPW bonanza #424

Merged
merged 12 commits into from
May 12, 2021
63 changes: 51 additions & 12 deletions cubical/data_handler/ms_data_handler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# CubiCal: a radio interferometric calibration suite
# (c) 2017 Rhodes University & Jonathan S. Kenyon
# http://github.com/ratt-ru/CubiCal
Expand Down Expand Up @@ -317,7 +318,8 @@ def __init__(self, ms_name, data_column, output_column=None, output_model_column

assert set(spwtabcols) <= set(
_spwtab.colnames()), "Measurement set conformance error - keyword table SPECTRAL_WINDOW incomplete. Perhaps disable --out-casa-gaintables or check your MS!"
self._spwtabcols = {t: _spwtab.getcol(t) for t in spwtabcols}
nrows = _spwtab.nrows()
self._spwtabcols = {t: [_spwtab.getcol(t, row, 1) for row in range(nrows)] for t in spwtabcols}

# read observation details
obstabcols = ["TIME_RANGE", "LOG", "SCHEDULE", "FLAG_ROW",
Expand Down Expand Up @@ -470,18 +472,19 @@ def __init__(self, ms_name, data_column, output_column=None, output_model_column
" ".join([str(ch) for ch in freqchunks + [nchan]])), file=log(0))

# now accumulate list of all frequencies, and also see if selected DDIDs have a uniform rebinning and chunking map
all_freqs = set(self.chanfreqs[self._ddids[0]])
self.do_freq_rebin = any([m is not None for m in list(self.rebin_chan_maps.values())])
self._ddids_unequal = False
ddid0_map = self.rebin_chan_maps[self._ddids[0]]
for ddid in self._ddids[1:]:
if len(self.chanfreqs[0]) != len(self.chanfreqs[ddid]):
self._ddids_unequal = True
break
map1 = self.rebin_chan_maps[ddid]
if ddid0_map is None and map1 is None:
continue
if (ddid0_map is None and map1 is not None) or (ddid0_map is not None and map1 is None) or \
len(ddid0_map) != len(map1) or (ddid0_map!=map1).any():
self._ddids_unequal = True
all_freqs.update(self.chanfreqs[ddid])

if self._ddids_unequal:
print("Selected DDIDs have differing channel structure. Processing may be less efficient.", file=log(0,"red"))
Expand Down Expand Up @@ -811,6 +814,40 @@ def fetch(self, colname, first_row=0, nrows=-1, subset=None):
(subset or self.data).getcolnp(str(colname), prealloc, first_row, nrows)
return prealloc

@staticmethod
def _get_row_chunk(array):
    """
    Establishes max row chunk that can be used for a single table I/O call.
    Workaround for https://github.com/casacore/python-casacore/issues/130

    Args:
        array: array to be read into or written from. The first axis is rows;
            any number of trailing axes is accepted (typically the shape is
            (nrows, nfreq, ncorr), but 1-D and 2-D arrays work too).

    Returns:
        Tuple of (maxrows, nrows): the largest number of rows to transfer in
        one call (always at least 1), and the total number of rows in array.
    """
    # max number of elements to write, see https://github.com/casacore/python-casacore/issues/130#issuecomment-748150854
    _maxchunk = 2**29
    nrows = array.shape[0]
    # number of elements held by a single row, i.e. product of all trailing dimensions
    # (1 for a 1-D array, since there are no trailing dimensions)
    row_size = 1
    for dim in array.shape[1:]:
        row_size *= dim
    # guard against zero-sized trailing dimensions, which would otherwise divide by zero
    maxrows = max(1, _maxchunk // max(1, row_size))
    #if maxrows < nrows:
    log(1).print(f" table I/O request of {nrows} rows: max chunk size is {maxrows} rows")
    return maxrows, nrows

@staticmethod
def _getcolnp_wrapper(table, column, array, startrow):
    """
    Calls table.getcolnp() in chunks of rows.
    Workaround for https://github.com/casacore/python-casacore/issues/130

    Args:
        table: table object providing getcolnp().
        column: name of the column to read.
        array: preallocated output array, first axis is rows. Filled in place.
        startrow: table row corresponding to array[0].
    """
    maxrows, nrows = MSDataHandler._get_row_chunk(array)
    for row0 in range(0, nrows, maxrows):
        chunk = array[row0:row0 + maxrows]
        # request exactly as many rows as the chunk holds: the final chunk (or the only
        # chunk, when maxrows > nrows) is shorter than maxrows
        table.getcolnp(column, chunk, startrow + row0, len(chunk))

@staticmethod
def _getcolslicenp_wrapper(table, column, array, begin, end, incr, startrow):
    """
    Calls table.getcolslicenp() in chunks of rows.
    Workaround for https://github.com/casacore/python-casacore/issues/130

    Args:
        table: table object providing getcolslicenp().
        column: name of the column to read.
        array: preallocated output array, first axis is rows. Filled in place.
        begin: slice start, passed through to getcolslicenp() unchanged.
        end: slice end, passed through to getcolslicenp() unchanged.
        incr: slice increment, passed through to getcolslicenp() unchanged.
        startrow: table row corresponding to array[0].
    """
    maxrows, nrows = MSDataHandler._get_row_chunk(array)
    for row0 in range(0, nrows, maxrows):
        chunk = array[row0:row0 + maxrows]
        # request exactly as many rows as the chunk holds: the final chunk (or the only
        # chunk, when maxrows > nrows) is shorter than maxrows
        table.getcolslicenp(column, chunk, begin, end, incr, startrow + row0, len(chunk))

@staticmethod
def _putcol_wrapper(table, column, array, startrow):
    """
    Calls table.putcol() in chunks of rows.
    Workaround for https://github.com/casacore/python-casacore/issues/130

    Args:
        table: table object providing putcol().
        column: name of the column to write.
        array: array of values to write, first axis is rows.
        startrow: table row that array[0] is written to.
    """
    maxrows, nrows = MSDataHandler._get_row_chunk(array)
    for row0 in range(0, nrows, maxrows):
        chunk = array[row0:row0 + maxrows]
        # write exactly as many rows as the chunk holds: the final chunk (or the only
        # chunk, when maxrows > nrows) is shorter than maxrows
        table.putcol(column, chunk, startrow + row0, len(chunk))

def fetchslice(self, column, startrow=0, nrows=-1, subset=None):
"""
Convenience function similar to fetch(), but assumes a column of NFREQxNCORR shape,
Expand Down Expand Up @@ -839,16 +876,16 @@ def fetchslice(self, column, startrow=0, nrows=-1, subset=None):
shape = tuple([nrows] + [s for s in cell.shape]) if hasattr(cell, "shape") else nrows

prealloc = np.empty(shape, dtype=dtype)
subset.getcolnp(str(column), prealloc, startrow, nrows)
self._getcolnp_wrapper(subset, str(column), prealloc, startrow)
return prealloc
# ugly hack because getcell returns a different dtype to getcol
cell = (subset or self.data).getcol(str(column), startrow, nrow=1)[0, ...]
cell = subset.getcol(str(column), startrow, nrow=1)[0, ...]
dtype = getattr(cell, "dtype", type(cell))

shape = tuple([len(list(range(l, r + 1, i))) #inclusive in cc
for l, r, i in zip(self._ms_blc, self._ms_trc, self._ms_incr)])
prealloc = np.empty(shape, dtype=dtype)
subset.getcolslicenp(str(column), prealloc, self._ms_blc, self._ms_trc, self._ms_incr, startrow, nrows)
self._getcolslicenp_wrapper(subset, str(column), prealloc, self._ms_blc, self._ms_trc, self._ms_incr, startrow)
return prealloc

def fetchslicenp(self, column, data, startrow=0, nrows=-1, subset=None):
Expand All @@ -869,8 +906,8 @@ def fetchslicenp(self, column, data, startrow=0, nrows=-1, subset=None):
"""
subset = subset or self.data
if self._ms_blc == None:
return subset.getcolnp(column, data, startrow, nrows)
return subset.getcolslicenp(column, data, self._ms_blc, self._ms_trc, self._ms_incr, startrow, nrows)
return self._getcolnp_wrapper(subset, column, data, startrow)
return self._getcolslicenp_wrapper(subset, column, data, self._ms_blc, self._ms_trc, self._ms_incr, startrow)

def putslice(self, column, value, startrow=0, nrows=-1, subset=None):
"""
Expand All @@ -895,7 +932,7 @@ def putslice(self, column, value, startrow=0, nrows=-1, subset=None):
# if no slicing, just use putcol to put the whole thing. This always works,
# unless the MS is screwed up
if self._ms_blc == None:
return subset.putcol(str(column), value, startrow, nrows)
return self._putcol_wrapper(subset, str(column), value, startrow)
if nrows<0:
nrows = subset.nrows()

Expand All @@ -919,19 +956,21 @@ def putslice(self, column, value, startrow=0, nrows=-1, subset=None):
value[:] = np.bitwise_or.reduce(value, axis=2)[:,:,np.newaxis]

if self._channel_slice == slice(None) and self._corr_slice == slice(None):
return subset.putcol(column, value, startrow, nrows)
return self._putcol_wrapper(subset, value, startrow)
else:
# for bitflags, we want to preserve flags we haven't touched -- read the column
if column == "BITFLAG" or column == "FLAG":
value0 = subset.getcol(column)
value0 = np.empty_like(value)
self._getcolnp_wrapper(subset, column, value0, startrow)
# cheekily propagate per-corr flags to all corrs
value0[:] = np.bitwise_or.reduce(value0, axis=2)[:,:,np.newaxis]
# otherwise, init empty column
else:
ddid = subset.getcol("DATA_DESC_ID", 0, 1)[0]
shape = (nrows, self._nchan0_orig[ddid], self.nmscorrs)
value0 = np.zeros(shape, value.dtype)
value0[:, self._channel_slice, self._corr_slice] = value
return subset.putcol(str(column), value0, startrow, nrows)
return self._putcol_wrapper(subset, str(column), value0, startrow, nrows)

def define_chunk(self, chunk_time, rebin_time, fdim=1, chunk_by=None, chunk_by_jump=0, chunks_per_tile=4, max_chunks_per_tile=0):
"""
Expand Down
Loading