fixes #93. Adds workaround for casacore/python-casacore#130
o-smirnov committed Dec 18, 2020
1 parent 8347462 commit dfc504c
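
The gist of the fix: python-casacore can fail on getcol/putcol requests larger than roughly 2**29 elements (see python-casacore issue #130), so the data handler now splits large column reads and writes into row chunks below that limit. A minimal standalone sketch of the pattern, assuming only the casacore.tables calls already used in this file (the helper name putcol_chunked, the MS path and the column name are illustrative, not part of this commit):

import numpy as np
from casacore.tables import table

MAX_ELEMENTS = 2**29   # per-request cap, per python-casacore issue #130

def putcol_chunked(tab, column, array, startrow=0):
    """Write a (nrows, nfreq, ncorr) array in row chunks small enough for casacore."""
    nrows, nfreq, ncorr = array.shape
    maxrows = max(1, MAX_ELEMENTS // (nfreq * ncorr))
    for row0 in range(0, nrows, maxrows):
        nr = min(maxrows, nrows - row0)     # clamp the final chunk
        tab.putcol(column, array[row0:row0 + nr], startrow + row0, nr)

# illustrative usage (path and column are placeholders):
# tab = table("example.ms", readonly=False)
# putcol_chunked(tab, "CORRECTED_DATA", corrected_vis)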
Showing 2 changed files with 183 additions and 127 deletions.
59 changes: 48 additions & 11 deletions cubical/data_handler/ms_data_handler.py
@@ -471,18 +471,19 @@ def __init__(self, ms_name, data_column, output_column=None, output_model_column
" ".join([str(ch) for ch in freqchunks + [nchan]])), file=log(0))

# now accumulate list of all frequencies, and also see if selected DDIDs have a uniform rebinning and chunking map
all_freqs = set(self.chanfreqs[self._ddids[0]])
self.do_freq_rebin = any([m is not None for m in list(self.rebin_chan_maps.values())])
self._ddids_unequal = False
ddid0_map = self.rebin_chan_maps[self._ddids[0]]
for ddid in self._ddids[1:]:
if len(self.chanfreqs[0]) != len(self.chanfreqs[ddid]):
self._ddids_unequal = True
break
map1 = self.rebin_chan_maps[ddid]
if ddid0_map is None and map1 is None:
continue
if (ddid0_map is None and map1 is not None) or (ddid0_map is not None and map1 is None) or \
len(ddid0_map) != len(map1) or (ddid0_map!=map1).any():
self._ddids_unequal = True
all_freqs.update(self.chanfreqs[ddid])

if self._ddids_unequal:
print("Selected DDIDs have differing channel structure. Processing may be less efficient.", file=log(0,"red"))
@@ -812,6 +813,40 @@ def fetch(self, colname, first_row=0, nrows=-1, subset=None):
(subset or self.data).getcolnp(str(colname), prealloc, first_row, nrows)
return prealloc

@staticmethod
def _get_row_chunk(array):
"""
Establishes max row chunk that can be used. Workaround for https://github.com/casacore/python-casacore/issues/130
array: array to be written, of shape (nrows, nfreq, ncorr)
"""
_maxchunk = 2**29 # max number of elements to write, see https://github.com/casacore/python-casacore/issues/130#issuecomment-748150854
nrows, nfreq, ncorr = array.shape
maxrows = max(1, _maxchunk // (nfreq*ncorr))
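# e.g. a column with nfreq=64 channels and ncorr=4 correlations gives maxrows = 2**29 // 256 = 2097152 rows per call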
log(0).print(f" table I/O request of {nrows} rows: max chunk size is {maxrows} rows")
return maxrows, nrows

@staticmethod
def _getcolnp_wrapper(table, column, array, startrow):
"Calls table.getcolnp() in chunks of rows. Workaround for https://github.com/casacore/python-casacore/issues/130"
maxrows, nrows = MSDataHandler._get_row_chunk(array)
for row0 in range(0, nrows, maxrows):
    # clamp the final chunk so the slice length and the row count passed to casacore agree
    nr = min(maxrows, nrows - row0)
    table.getcolnp(column, array[row0:row0+nr], startrow+row0, nr)

@staticmethod
def _getcolslicenp_wrapper(table, column, array, begin, end, incr, startrow):
"Calls table.getcolnp() in chunks of rows. Workaround for https://github.com/casacore/python-casacore/issues/130"
maxrows, nrows = MSDataHandler._get_row_chunk(array)
for row0 in range(0, nrows, maxrows):
table.getcolslicenp(column, array[row0:row0+maxrows], begin, end, incr, startrow+row0, maxrows)

@staticmethod
def _putcol_wrapper(table, column, array, startrow):
"Calls table.putcol() in chunks of rows. Workaround for https://github.com/casacore/python-casacore/issues/130"
maxrows, nrows = MSDataHandler._get_row_chunk(array)
for row0 in range(0, nrows, maxrows):
    nr = min(maxrows, nrows - row0)
    table.putcol(column, array[row0:row0+nr], startrow+row0, nr)

def fetchslice(self, column, startrow=0, nrows=-1, subset=None):
"""
Convenience function similar to fetch(), but assumes a column of NFREQxNCORR shape,
Expand Down Expand Up @@ -840,16 +875,16 @@ def fetchslice(self, column, startrow=0, nrows=-1, subset=None):
shape = tuple([nrows] + [s for s in cell.shape]) if hasattr(cell, "shape") else nrows

prealloc = np.empty(shape, dtype=dtype)
subset.getcolnp(str(column), prealloc, startrow, nrows)
self._getcolnp_wrapper(subset, str(column), prealloc, startrow)
return prealloc
# ugly hack because getcell returns a different dtype to getcol
cell = (subset or self.data).getcol(str(column), startrow, nrow=1)[0, ...]
cell = subset.getcol(str(column), startrow, nrow=1)[0, ...]
dtype = getattr(cell, "dtype", type(cell))

shape = tuple([len(list(range(l, r + 1, i))) #inclusive in cc
for l, r, i in zip(self._ms_blc, self._ms_trc, self._ms_incr)])
prealloc = np.empty(shape, dtype=dtype)
subset.getcolslicenp(str(column), prealloc, self._ms_blc, self._ms_trc, self._ms_incr, startrow, nrows)
self._getcolslicenp_wrapper(subset, str(column), prealloc, self._ms_blc, self._ms_trc, self._ms_incr, startrow)
return prealloc

def fetchslicenp(self, column, data, startrow=0, nrows=-1, subset=None):
@@ -870,8 +905,8 @@ def fetchslicenp(self, column, data, startrow=0, nrows=-1, subset=None):
"""
subset = subset or self.data
if self._ms_blc == None:
return subset.getcolnp(column, data, startrow, nrows)
return subset.getcolslicenp(column, data, self._ms_blc, self._ms_trc, self._ms_incr, startrow, nrows)
return self._getcolnp_wrapper(subset, column, data, startrow)
return self._getcolslicenp_wrapper(subset, column, data, self._ms_blc, self._ms_trc, self._ms_incr, startrow)

def putslice(self, column, value, startrow=0, nrows=-1, subset=None):
"""
@@ -896,7 +931,7 @@ def putslice(self, column, value, startrow=0, nrows=-1, subset=None):
# if no slicing, just use putcol to put the whole thing. This always works,
# unless the MS is screwed up
if self._ms_blc == None:
return subset.putcol(str(column), value, startrow, nrows)
return self._putcol_wrapper(subset, str(column), value, startrow)
if nrows<0:
nrows = subset.nrows()

@@ -920,19 +955,21 @@ def putslice(self, column, value, startrow=0, nrows=-1, subset=None):
value[:] = np.bitwise_or.reduce(value, axis=2)[:,:,np.newaxis]

if self._channel_slice == slice(None) and self._corr_slice == slice(None):
return subset.putcol(column, value, startrow, nrows)
return self._putcol_wrapper(subset, column, value, startrow)
else:
# for bitflags, we want to preserve flags we haven't touched -- read the column
if column == "BITFLAG" or column == "FLAG":
value0 = subset.getcol(column)
# buffer must have the full (unsliced) column shape, since the column is read back into it
ddid = subset.getcol("DATA_DESC_ID", 0, 1)[0]
value0 = np.empty((nrows, self._nchan0_orig[ddid], self.nmscorrs), value.dtype)
self._getcolnp_wrapper(subset, column, value0, startrow)
# cheekily propagate per-corr flags to all corrs
value0[:] = np.bitwise_or.reduce(value0, axis=2)[:,:,np.newaxis]
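# e.g. with FLAG, a row flagged in only one correlation ([True, False, False, False]) becomes flagged in all four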
# otherwise, init empty column
else:
ddid = subset.getcol("DATA_DESC_ID", 0, 1)[0]
shape = (nrows, self._nchan0_orig[ddid], self.nmscorrs)
value0 = np.zeros(shape, value.dtype)
value0[:, self._channel_slice, self._corr_slice] = value
return subset.putcol(str(column), value0, startrow, nrows)
return self._putcol_wrapper(subset, str(column), value0, startrow)

def define_chunk(self, chunk_time, rebin_time, fdim=1, chunk_by=None, chunk_by_jump=0, chunks_per_tile=4, max_chunks_per_tile=0):
"""