Skip to content

Commit

Permalink
Add summary plots for ant,scan,corr
Browse files Browse the repository at this point in the history
  • Loading branch information
Athanaseus committed Nov 6, 2023
1 parent 933b904 commit 6e66de3
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 3 deletions.
10 changes: 10 additions & 0 deletions rfinder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def go(self,cfg_par):
1: 2d plot of RFI flagged by frequency and baseline lenght (plot_rfi_im)
2: 1d plot of RFI flagged by frequency channel (baselines_from_ms)
3: 1d plot of noise increase by frequency channel (for long and short baselines) (priors_flag)
4: 1d plot of overall RFI flagged by scan, antenna and correlation.
If cfg_par['rfi']['chunks']['time_chunks'] is enabled
1: executes 'rfi' and 'plots' procedure dividing the observation in time-steps given by cfg_par['rfi']['chunks']['time_step']
2: collects the info about the % of RFI for each time step in Alt/Az plots
Expand Down Expand Up @@ -427,6 +428,15 @@ def go(self,cfg_par):
self.logger.warning("------ End of RFI analysis ------\n")

task = 'plots'
if self.cfg_par[task]['plot_summary']['enable']==True:
if not self.cfg_par['rfi']['rfi_enable']:
rfi.load_from_ms(self.cfg_par,0,0)
for axis in self.cfg_par[task]['plot_summary']['axis']:
flag_stats = rfiST.get_flags_summary_stats(self.cfg_par, axis)
self.logger.warning(f" ------ Plotting {axis} summary plots ------\n")
rfiPL.plot_summary_stats(flag_stats, self.cfg_par, axis)
self.logger.info("------ Summary plot done ------\n")

if self.cfg_par[task]['plot_enable']==True:

if self.cfg_par['rfi']['chunks']['time_enable']==True:
Expand Down
20 changes: 17 additions & 3 deletions rfinder/rfi.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ def load_from_ms(self,cfg_par,times=0,counter=0):
Chan_width, Chan_freq
'''
STOKES_TYPES = {
0: "Undefined", 1: "I", 2: "Q", 3: "U", 4: "V",
5: "RR", 6: "RL", 7: "LR", 8: "LL",
9: "XX", 10: "XY", 11: "YX", 12: "YY",
13: "RX", 14: "RY", 15: "LX", 16: "LY",
17: "XR", 18: "XL", 19: "YR", 20: "YL",
21: "PP", 22: "PQ", 23: "QP", 24: "QQ"
}

if counter == 0 :
self.logger.warning("\t ... Field, Antenna & Bandwidth Info ...\n")
Expand Down Expand Up @@ -100,10 +108,12 @@ def load_from_ms(self,cfg_par,times=0,counter=0):
antennas.close()
self.ant_names = np.arange(0,self.ant_names.shape[0],1)


spw=tables.table(self.msfile+'/SPECTRAL_WINDOW')
self.channelWidths=spw.getcol('CHAN_WIDTH')

self.channelFreqs=spw.getcol('CHAN_FREQ')
cfg_par['rfi']['freqs'] = self.channelFreqs[0]
cfg_par['rfi']['chan_widths'] = self.channelWidths[0][0]
cfg_par['rfi']['lowfreq'] = float(self.channelFreqs[0][0])
cfg_par['rfi']['highfreq'] = float(self.channelFreqs[-1][-1])
Expand All @@ -122,6 +132,12 @@ def load_from_ms(self,cfg_par,times=0,counter=0):

t=tables.table(self.msfile)

if counter == 0 :
cfg_par['rfi']['scans'] = list(set(t.getcol('SCAN_NUMBER')))
corr = tables.table(self.msfile +'/POLARIZATION')
corr_types = corr.getcol('CORR_TYPE')[0]
cfg_par['rfi']['corrs'] = [STOKES_TYPES[corr_type] for corr_type in corr_types]

if counter !=0:
value_end = times[1]
value_start = times[0]
Expand Down Expand Up @@ -172,8 +188,6 @@ def load_from_ms(self,cfg_par,times=0,counter=0):
if cfg_par['rfi']['RFInder_mode'] == 'rms_clip':
self.logger.warning('\t Correct noise_measure_edges in rfi of parameter file ###')
empty_table=1



if not self.aperfi_badant:
nrbadant =len(int(self.aperfi_badant))
Expand All @@ -190,13 +204,13 @@ def load_from_ms(self,cfg_par,times=0,counter=0):
rfiST.predict_noise(cfg_par,self.channelWidths,self.interval,self.flag)
cfg_par['rfi']['vis_alltimes_baseline'] = self.flag.shape[0]/nrBaseline


t.close()

self.logger.info("\t ... info from MS file loaded \n\n")

return empty_table


def baselines_from_ms(self,cfg_par):
'''
Reads which baselines were used in the observations
Expand Down
53 changes: 53 additions & 0 deletions rfinder/rfinder_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,59 @@ def plot_altaz_short(self,cfg_par):

return 0

def plot_summary_stats(self, flag_stats, cfg_par, key='corr'):

summaryplot = cfg_par['general']['plotdir'] + f'{key}-summary.png'

# initialize plotting parameters
params = {'font.family' :' serif',
'font.style' : 'normal',
'font.weight' : 'book',
'font.size' : 20.0,
'axes.linewidth' : 1,
'lines.linewidth' : 1,
'xtick.labelsize' : 16,
'ytick.labelsize' : 16,
'xtick.direction' :'in',
'ytick.direction' :'in',
'xtick.major.size' : 4,
'xtick.major.width' : 1,
'xtick.minor.size' : 2,
'xtick.minor.width' : 1,
'ytick.major.size' : 4,
'ytick.major.width' : 1,
'ytick.minor.size' : 2,
'ytick.minor.width' : 1,
'text.usetex' : False,
}
plt.rcParams.update(params)
fig, ax = plt.subplots(figsize=(16,9))
antenna_plot = ax.bar(flag_stats.keys(), flag_stats.values(), color="orange", ec="red", align='center')

if key in ["antenna", "ant"]:
plt.title("Antenna flags")
plt.xlabel("Antenna")
plt.ylabel("% flagged visibilities")
plt.xticks(rotation=90)
plt.savefig(summaryplot)
elif key in ["scan"]:
plt.title("Scan flags")
plt.xlabel("Scan Number")
plt.ylabel("% flagged visibilities")
#plt.xticks(rotation=90)
plt.savefig(summaryplot)
elif key in ["correlation", "corr"]:
plt.title("Correlation flags")
plt.xlabel("Correlation")
plt.ylabel("Flagged percentage (%)")
plt.ylabel("% flagged visibilities")
plt.xticks(rotation=90)
plt.savefig(summaryplot)
self.logger.info(f" ------ Saving: {summaryplot} ------\n")

return 0


def gif_me_up(self,cfg_par,filenames,outmovie):


Expand Down
95 changes: 95 additions & 0 deletions rfinder/rfinder_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import glob
import casacore.tables as tables
import logging
import multiprocessing


from astropy.time import Time
Expand Down Expand Up @@ -223,3 +224,97 @@ def alt_az(self,cfg_par,time):
##self.logger.info("\t ... Alt/Az done ... \n")

return obs_altaz


def get_flags_summary_stats(self, cfg_par, axis):

def data_query(t, taql, name, axis='ant'):
flagtab = t.query(query=taql, columns='DATA_DESC_ID,FLAG')
cell_shape = flagtab.getcell('FLAG', 0).shape
flag_col = np.empty((flagtab.nrows(), cell_shape[0], cell_shape[1]), dtype=bool)
flagtab.getcolnp('FLAG', flag_col)
ddid_col = flagtab.getcol('DATA_DESC_ID')
flagtab.done()

if axis in ['corr']:
# For this we need to index that data appropriately
vals,counts = np.unique(flag_col[:,:,name],return_counts=True)
name = cfg_par['rfi']['corrs'][name]
else:
vals,counts = np.unique(flag_col,return_counts=True)
name = str(name)

if len(vals) == 1:
flag_percent = 100.0 if vals[0] else 0.0
else:
flag_percent = round(100.0*float(counts[1])/float(np.sum(counts)),2)

self.flag_percents[name] = flag_percent

def flag_bars(flag_stats, key):
"""Displays output directly to the screen console or logfile
https://github.com/IanHeywood/ms_info/blob/master/ms_flags.py
"""
self.logger.info('')
self.logger.info(f'Flagged percentages per {key}:')
self.logger.info('')
self.logger.info(' 0% 20% 40% 60% 80% 100%')
self.logger.info(' | | | | | |')
for fs in flag_stats.items():
name = fs[0]
average_pc = fs[1]
length = int(average_pc / 2.0)
self.logger.info(' %-9s %-7s %s'% (name,str(round(average_pc,1))+'%','∎' * length))
self.logger.info('')

ncpu = cfg_par['general']['ncpu']
self.logger.info("\t ... Observing time Info ... \n")
self.msfile = cfg_par['general']['msfullpath']
t=tables.table(self.msfile)
t = t.query(query=f"FIELD_ID=={cfg_par['general']['field']}")
self.flag_percents = multiprocessing.Manager().dict()
processes = []
if axis in ['ant', 'antenna']:
for index, ant_name in enumerate(cfg_par['rfi']['ant_names']):
taql = f"ANTENNA1=={index} || ANTENNA2=={index}"
p = multiprocessing.Process(target=data_query, args=(t, taql, ant_name))
p.start()
processes.append(p)
if len(processes) == ncpu:
for p in processes:
p.join()
processes = []
if len(processes) > 1:
for p in processes:
p.join()
if axis in ['scan']:
scan_ids = list(set(t.getcol('SCAN_NUMBER')))
for index, scan_id in enumerate(scan_ids):
taql = f'SCAN_NUMBER=={str(scan_id)}'
p = multiprocessing.Process(target=data_query, args=(t, taql, scan_id))
p.start()
processes.append(p)
if len(processes) == ncpu:
for p in processes:
p.join()
processes = []
if len(processes) > 1:
for p in processes:
p.join()
if axis in ['corr']:
for index, corr_type in enumerate(cfg_par['rfi']['corrs']):
taql = f''
p = multiprocessing.Process(target=data_query, args=(t, taql, index, axis))
p.start()
processes.append(p)
if len(processes) == ncpu:
for p in processes:
p.join()
processes = []
if len(processes) > 1:
for p in processes:
p.join()
t.close()
flag_bars(self.flag_percents, axis)

return self.flag_percents

0 comments on commit 6e66de3

Please sign in to comment.