diff --git a/rfinder/main.py b/rfinder/main.py index 3986eac..d6fbbea 100644 --- a/rfinder/main.py +++ b/rfinder/main.py @@ -351,6 +351,7 @@ def go(self,cfg_par): 1: 2d plot of RFI flagged by frequency and baseline lenght (plot_rfi_im) 2: 1d plot of RFI flagged by frequency channel (baselines_from_ms) 3: 1d plot of noise increase by frequency channel (for long and short baselines) (priors_flag) + 4: 1d plot of overall RFI flagged by scan, antenna and correlation. If cfg_par['rfi']['chunks']['time_chunks'] is enabled 1: executes 'rfi' and 'plots' procedure dividing the observation in time-steps given by cfg_par['rfi']['chunks']['time_step'] 2: collects the info about the % of RFI for each time step in Alt/Az plots @@ -427,6 +428,15 @@ def go(self,cfg_par): self.logger.warning("------ End of RFI analysis ------\n") task = 'plots' + if self.cfg_par[task]['plot_summary']['enable']==True: + if not self.cfg_par['rfi']['rfi_enable']: + rfi.load_from_ms(self.cfg_par,0,0) + for axis in self.cfg_par[task]['plot_summary']['axis']: + flag_stats = rfiST.get_flags_summary_stats(self.cfg_par, axis) + self.logger.warning(f" ------ Plotting {axis} summary plots ------\n") + rfiPL.plot_summary_stats(flag_stats, self.cfg_par, axis) + self.logger.info("------ Summary plot done ------\n") + if self.cfg_par[task]['plot_enable']==True: if self.cfg_par['rfi']['chunks']['time_enable']==True: diff --git a/rfinder/rfi.py b/rfinder/rfi.py index 04270cf..f718e33 100644 --- a/rfinder/rfi.py +++ b/rfinder/rfi.py @@ -60,6 +60,14 @@ def load_from_ms(self,cfg_par,times=0,counter=0): Chan_width, Chan_freq ''' + STOKES_TYPES = { + 0: "Undefined", 1: "I", 2: "Q", 3: "U", 4: "V", + 5: "RR", 6: "RL", 7: "LR", 8: "LL", + 9: "XX", 10: "XY", 11: "YX", 12: "YY", + 13: "RX", 14: "RY", 15: "LX", 16: "LY", + 17: "XR", 18: "XL", 19: "YR", 20: "YL", + 21: "PP", 22: "PQ", 23: "QP", 24: "QQ" + } if counter == 0 : self.logger.warning("\t ... 
Field, Antenna & Bandwidth Info ...\n") @@ -100,10 +108,12 @@ def load_from_ms(self,cfg_par,times=0,counter=0): antennas.close() self.ant_names = np.arange(0,self.ant_names.shape[0],1) + spw=tables.table(self.msfile+'/SPECTRAL_WINDOW') self.channelWidths=spw.getcol('CHAN_WIDTH') self.channelFreqs=spw.getcol('CHAN_FREQ') + cfg_par['rfi']['freqs'] = self.channelFreqs[0] cfg_par['rfi']['chan_widths'] = self.channelWidths[0][0] cfg_par['rfi']['lowfreq'] = float(self.channelFreqs[0][0]) cfg_par['rfi']['highfreq'] = float(self.channelFreqs[-1][-1]) @@ -122,6 +132,12 @@ def load_from_ms(self,cfg_par,times=0,counter=0): t=tables.table(self.msfile) + if counter == 0 : + cfg_par['rfi']['scans'] = list(set(t.getcol('SCAN_NUMBER'))) + corr = tables.table(self.msfile +'/POLARIZATION') + corr_types = corr.getcol('CORR_TYPE')[0] + cfg_par['rfi']['corrs'] = [STOKES_TYPES[corr_type] for corr_type in corr_types] + if counter !=0: value_end = times[1] value_start = times[0] @@ -172,8 +188,6 @@ def load_from_ms(self,cfg_par,times=0,counter=0): if cfg_par['rfi']['RFInder_mode'] == 'rms_clip': self.logger.warning('\t Correct noise_measure_edges in rfi of parameter file ###') empty_table=1 - - if not self.aperfi_badant: nrbadant =len(int(self.aperfi_badant)) @@ -190,13 +204,13 @@ def load_from_ms(self,cfg_par,times=0,counter=0): rfiST.predict_noise(cfg_par,self.channelWidths,self.interval,self.flag) cfg_par['rfi']['vis_alltimes_baseline'] = self.flag.shape[0]/nrBaseline - t.close() self.logger.info("\t ... 
def plot_summary_stats(self, flag_stats, cfg_par, key='corr'):
    '''
    Draws a bar chart of the percentage of flagged visibilities, one bar per
    entry of flag_stats, and saves it as a png.

    INPUT
        flag_stats: mapping label -> flagged percentage
        cfg_par:    configuration dictionary; cfg_par['general']['plotdir']
                    is used as prefix of the output file name
        key:        axis being summarised:
                    'antenna'/'ant', 'scan' or 'correlation'/'corr'
    OUTPUT
        <plotdir><key>-summary.png (only written when key is recognised)
    RETURN
        0
    '''

    summaryplot = cfg_par['general']['plotdir'] + f'{key}-summary.png'

    # initialize plotting parameters
    params = {'font.family': 'serif',
              'font.style': 'normal',
              'font.weight': 'book',
              'font.size': 20.0,
              'axes.linewidth': 1,
              'lines.linewidth': 1,
              'xtick.labelsize': 16,
              'ytick.labelsize': 16,
              'xtick.direction': 'in',
              'ytick.direction': 'in',
              'xtick.major.size': 4,
              'xtick.major.width': 1,
              'xtick.minor.size': 2,
              'xtick.minor.width': 1,
              'ytick.major.size': 4,
              'ytick.major.width': 1,
              'ytick.minor.size': 2,
              'ytick.minor.width': 1,
              'text.usetex': False,
              }
    plt.rcParams.update(params)

    # (title, x-axis label, rotate tick labels) for every accepted spelling
    antenna_layout = ("Antenna flags", "Antenna", True)
    scan_layout = ("Scan flags", "Scan Number", False)
    corr_layout = ("Correlation flags", "Correlation", True)
    layouts = {
        'antenna': antenna_layout,
        'ant': antenna_layout,
        'scan': scan_layout,
        'correlation': corr_layout,
        'corr': corr_layout,
    }

    layout = layouts.get(key)
    if layout is None:
        # nothing sensible to draw: do not pretend a file has been written
        self.logger.warning(f" ------ Unknown summary axis '{key}': no plot saved ------\n")
        return 0
    title, xlabel, rotate_ticks = layout

    # materialise keys/values so that any mapping type can be plotted
    labels = list(flag_stats.keys())
    percentages = list(flag_stats.values())

    fig, ax = plt.subplots(figsize=(16, 9))
    try:
        ax.bar(labels, percentages, color="orange", ec="red", align='center')
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel("% flagged visibilities")
        if rotate_ticks:
            ax.tick_params(axis='x', labelrotation=90)
        fig.savefig(summaryplot)
    finally:
        # release the figure: pyplot keeps every open figure alive otherwise
        plt.close(fig)

    self.logger.info(f" ------ Saving: {summaryplot} ------\n")

    return 0
def get_flags_summary_stats(self, cfg_par, axis):
    '''
    Computes the percentage of flagged visibilities per antenna, scan or
    correlation of the field cfg_par['general']['field'].

    One worker process is started per antenna/scan/correlation, in batches of
    cfg_par['general']['ncpu'] processes; every batch is joined before the
    next one is started.

    INPUT
        cfg_par: configuration dictionary. Keys read here:
                 ['general']['msfullpath'], ['general']['field'],
                 ['general']['ncpu'], ['rfi']['ant_names'], ['rfi']['corrs']
        axis:    'ant'/'antenna', 'scan' or 'corr'/'correlation'
    RETURN
        dictionary label -> flagged percentage, in the order in which the
        antennas/scans/correlations were queried. The same dictionary is
        stored in self.flag_percents. It is empty if axis is not recognised.
    '''

    def data_query(t, taql, name, axis='ant'):
        # Executed in a worker process: selects the rows matching taql and
        # stores the flagged percentage of the selection in the shared dict.
        flagtab = t.query(query=taql, columns='DATA_DESC_ID,FLAG')
        try:
            nrows = flagtab.nrows()
            if nrows == 0:
                # no visibilities selected: there is no percentage to report
                self.logger.warning(f"\t ... no rows found for {axis} {name}: skipped ...\n")
                return
            cell_shape = flagtab.getcell('FLAG', 0).shape
            flag_col = np.empty((nrows, cell_shape[0], cell_shape[1]), dtype=bool)
            flagtab.getcolnp('FLAG', flag_col)
        finally:
            flagtab.done()

        if axis in ['corr']:
            # name is the index of the correlation along the last FLAG axis
            selected = flag_col[:, :, name]
            name = cfg_par['rfi']['corrs'][name]
        else:
            selected = flag_col
            name = str(name)

        if selected.size == 0:
            return
        # counting the True values avoids sorting the whole flag cube
        flagged = float(np.count_nonzero(selected))
        shared[name] = round(100.0 * flagged / float(selected.size), 2)

    def flag_bars(flag_stats, key):
        """Displays output directly to the screen console or logfile
        https://github.com/IanHeywood/ms_info/blob/master/ms_flags.py
        """
        self.logger.info('')
        self.logger.info(f'Flagged percentages per {key}:')
        self.logger.info('')
        self.logger.info(' 0% 20% 40% 60% 80% 100%')
        self.logger.info(' | | | | | |')
        for name, average_pc in flag_stats.items():
            # one bar character every 2 per cent
            length = int(average_pc / 2.0)
            self.logger.info(' %-9s %-7s %s'% (name,str(round(average_pc,1))+'%','∎' * length))
        self.logger.info('')

    def run_batched(table, jobs):
        # Starts one process per job, never more than ncpu at the same time.
        # NOTE(review): the target is a nested function, which presumably
        # requires the 'fork' start method of multiprocessing - confirm.
        running = []
        for taql, name, query_axis, _ in jobs:
            proc = multiprocessing.Process(target=data_query,
                                           args=(table, taql, name, query_axis))
            proc.start()
            running.append(proc)
            if len(running) == ncpu:
                for proc in running:
                    proc.join()
                running = []
        # join what is left, also when it is a single process
        for proc in running:
            proc.join()

    ncpu = cfg_par['general']['ncpu']
    self.logger.info("\t ... Observing time Info ... \n")
    self.msfile = cfg_par['general']['msfullpath']

    ms = tables.table(self.msfile)
    selection = None
    try:
        selection = ms.query(query=f"FIELD_ID=={cfg_par['general']['field']}")

        # every job is (taql, name given to data_query, axis given to
        # data_query, label under which the result is stored)
        if axis in ['ant', 'antenna']:
            jobs = [(f"ANTENNA1=={index} || ANTENNA2=={index}", ant_name, 'ant', str(ant_name))
                    for index, ant_name in enumerate(cfg_par['rfi']['ant_names'])]
        elif axis in ['scan']:
            # sorted, so that scans are reported in increasing order
            scan_ids = sorted(set(selection.getcol('SCAN_NUMBER')))
            jobs = [(f'SCAN_NUMBER=={str(scan_id)}', scan_id, 'ant', str(scan_id))
                    for scan_id in scan_ids]
        elif axis in ['corr', 'correlation']:
            # an empty query keeps all the rows; the correlation is picked
            # by index inside data_query
            jobs = [('', index, 'corr', corr_type)
                    for index, corr_type in enumerate(cfg_par['rfi']['corrs'])]
        else:
            self.logger.warning(f"\t ... unknown summary axis '{axis}': nothing to compute ...\n")
            jobs = []

        # the manager process is only needed while the workers are running
        with multiprocessing.Manager() as manager:
            shared = manager.dict()
            run_batched(selection, jobs)
            collected = dict(shared)
    finally:
        if selection is not None:
            selection.close()
        ms.close()

    # workers finish in arbitrary order: restore the order of the queries
    self.flag_percents = {label: collected[label]
                          for _, _, _, label in jobs if label in collected}

    flag_bars(self.flag_percents, axis)

    return self.flag_percents