Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Waveform plot improvements #252

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mtuq/graphics/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from matplotlib import pyplot
from obspy.geodetics import gps2dist_azimuth
# location to degree distance with obspy
from obspy.geodetics import locations2degrees


def station_label_writer(ax, station, origin, units='km'):
Expand Down Expand Up @@ -31,7 +33,8 @@ def station_label_writer(ax, station, origin, units='km'):
label = '%d km' % round(distance_in_m/1000.)

elif units=='deg':
label = '%d%s' % (round(m_to_deg(distance_in_m)), u'\N{DEGREE SIGN}')
label = '%d%s' % (round(locations2degrees(origin.latitude, origin.longitude,
station.latitude, station.longitude)), u'\N{DEGREE SIGN}')

pyplot.text(0.2,0.35, label, fontsize=11, transform=ax.transAxes)

Expand Down
38 changes: 36 additions & 2 deletions mtuq/graphics/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(self, process_bw, process_sw, misfit_bw, misfit_sw,
if not process_bw:
pass
if not process_sw:
raise Excpetion()
raise Exception()

if process_sw.freq_max > 1.:
units = 'Hz'
Expand Down Expand Up @@ -243,6 +243,23 @@ def write(self, height, width, margin_left, margin_top):

line = _focal_mechanism(self.lune_dict)
line += ', '+_gamma_delta(self.lune_dict)


# Get additional header info. Only N, Np and Ns possible at the moment, but this could change in the future
# Check which fields are present as attributes in MomentTensorHeader object.
# Possible way to extend this is to create additional "formating" functions and call them here.
if hasattr(self, 'N') and hasattr(self, 'Np') and hasattr(self, 'Ns'):
self.additional_info = "N-Np-Ns : " + str(self.N) + "-" + str(self.Np) + "-" + str(self.Ns)
elif hasattr(self, 'N'):
self.additional_info = "N : " + str(self.N)
else:
self.additional_info = None

# After checking which fields are present, add them to the line
if self.additional_info:
line += ', ' + self.additional_info

# Write the modified line
_write_text(line, px, py, ax, fontsize=14)


Expand Down Expand Up @@ -291,7 +308,7 @@ def __init__(self, process_bw, process_sw, misfit_bw, misfit_sw,
if not process_bw:
pass
if not process_sw:
raise Excpetion()
raise Exception()

if process_sw.freq_max > 1.:
units = 'Hz'
Expand Down Expand Up @@ -366,6 +383,23 @@ def write(self, height, width, margin_left, margin_top):
py -= 0.30

line = _phi_theta(self.force_dict)

# Same as in MomentTensorHeader above.
if hasattr(self, 'N') and hasattr(self, 'Np') and hasattr(self, 'Ns'):
self.additional_info = "N-Np-Ns : " + str(self.N) + "-" + str(self.Np) + "-" + str(self.Ns)
elif hasattr(self, 'N'):
self.additional_info = "N : " + str(self.N)
else:
self.additional_info = None

# After checking which fields are present, add them to the line
if self.additional_info:
line += ', ' + self.additional_info

# Write the modified line
_write_text(line, px, py, ax, fontsize=14)

# Write the modified line
_write_text(line, px, py, ax, fontsize=14)


Expand Down
111 changes: 83 additions & 28 deletions mtuq/graphics/waveforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,26 @@ def plot_waveforms1(filename,

_add_component_labels1(axes)

# determine maximum trace amplitudes
max_amplitude = _max(data, synthetics)

# Calculate maximum amplitudes for dataset
max_amplitudes = np.asarray([
_max(data[i], synthetics[i]) if data[i] and synthetics[i] else 0.0
for i in range(len(data))
])


if normalize == 'median_amplitude':
max_amplitudes_median = np.median(max_amplitudes[max_amplitudes>0])
max_amplitudes = [max_amplitudes_median if data[i] and synthetics[i] else 0.0 for i in range(len(data))]
elif normalize == 'maximum_amplitude':
max_amplitudes = [max_amplitude if data[i] and synthetics[i] else 0.0 for i in range(len(data))]
elif normalize == 'station_amplitude' or normalize == 'trace_amplitude':
pass
else:
raise ValueError("Invalid normalization method specified.")

#
# loop over stations
#
Expand Down Expand Up @@ -100,7 +118,7 @@ def plot_waveforms1(filename,
continue

_plot_ZRT(axes[ir], 1, dat, syn, component,
normalize, trace_label_writer, max_amplitude, total_misfit)
normalize, trace_label_writer, max_amplitudes[_i], total_misfit)

ir += 1

Expand Down Expand Up @@ -153,6 +171,31 @@ def plot_waveforms2(filename,
max_amplitude_sw = _max(data_sw, synthetics_sw)


# Calculate maximum amplitudes for body wave and surface wave data
max_amplitudes_bw = np.asarray([
_max(data_bw[i], synthetics_bw[i]) if data_bw[i] and synthetics_bw[i] else 0.0
for i in range(len(data_bw))
])

max_amplitudes_sw = np.asarray([
_max(data_sw[i], synthetics_sw[i]) if data_sw[i] and synthetics_sw[i] else 0.0
for i in range(len(data_sw))
])

# Normalize amplitudes based on the specified method
if normalize == 'median_amplitude':
max_amplitudes_bw_median = np.median(max_amplitudes_bw[max_amplitudes_bw > 0])
max_amplitudes_sw_median = np.median(max_amplitudes_sw[max_amplitudes_sw > 0])
max_amplitudes_bw = [max_amplitudes_bw_median if data_bw[i] and synthetics_bw[i] else 0.0 for i in range(len(data_bw))]
max_amplitudes_sw = [max_amplitudes_sw_median if data_sw[i] and synthetics_sw[i] else 0.0 for i in range(len(data_sw))]
elif normalize == 'maximum_amplitude':
max_amplitudes_bw = [max_amplitude_bw if data_bw[i] and synthetics_bw[i] else 0.0 for i in range(len(data_bw))]
max_amplitudes_sw = [max_amplitude_sw if data_sw[i] and synthetics_sw[i] else 0.0 for i in range(len(data_sw))]
elif normalize == 'station_amplitude' or normalize == 'trace_amplitude':
pass
else:
raise ValueError("Invalid normalization method specified.")

#
# loop over stations
#
Expand Down Expand Up @@ -191,8 +234,7 @@ def plot_waveforms2(filename,
continue

_plot_ZR(axes[ir], 1, dat, syn, component,
normalize, trace_label_writer, max_amplitude_bw, total_misfit_bw)

normalize, trace_label_writer, max_amplitudes_bw[_i], total_misfit_bw)

#
# plot surface wave traces
Expand All @@ -216,7 +258,7 @@ def plot_waveforms2(filename,
continue

_plot_ZRT(axes[ir], 3, dat, syn, component,
normalize, trace_label_writer, max_amplitude_sw, total_misfit_sw)
normalize, trace_label_writer, max_amplitudes_sw[_i], total_misfit_sw)


ir += 1
Expand Down Expand Up @@ -247,6 +289,9 @@ def plot_data_greens1(filename,
# calculate total misfit for display in figure header
total_misfit = misfit(data, greens.select(origin), source, optimization_level=0)

# Get the number of stations used
N_total = _count([data])

# prepare figure header
if 'header' in kwargs:
header = kwargs.pop('header')
Expand All @@ -257,7 +302,8 @@ def plot_data_greens1(filename,

header = _prepare_header(
model, solver, source, source_dict, origin,
process_data, misfit, total_misfit)
process_data, misfit, total_misfit,
additional_header_info={'N': N_total})

plot_waveforms1(filename,
data, synthetics, stations, origin,
Expand Down Expand Up @@ -299,6 +345,10 @@ def plot_data_greens2(filename,

total_misfit_sw = misfit_sw(
data_sw, greens_sw.select(origin), source, optimization_level=0)

N_total = len(stations)
N_p_used = _count([data_bw])
N_s_used = _count([data_sw])


# prepare figure header
Expand All @@ -312,7 +362,8 @@ def plot_data_greens2(filename,
header = _prepare_header(
model, solver, source, source_dict, origin,
process_data_bw, process_data_sw,
misfit_bw, misfit_sw, total_misfit_bw, total_misfit_sw)
misfit_bw, misfit_sw, total_misfit_bw, total_misfit_sw,
additional_header_info={'N': N_total, 'Np': N_p_used, 'Ns': N_s_used})

plot_waveforms2(filename,
data_bw, data_sw, synthetics_bw, synthetics_sw, stations, origin,
Expand Down Expand Up @@ -373,7 +424,7 @@ def _initialize(nrows=None, ncolumns=None, column_width_ratios=None,

def _plot_ZRT(axes, ic, dat, syn, component,
normalize='maximum_amplitude', trace_label_writer=None,
max_amplitude=1., total_misfit=1.):
normalization_amplitude=1., total_misfit=1.):

# plot traces
if component=='Z':
Expand All @@ -387,26 +438,21 @@ def _plot_ZRT(axes, ic, dat, syn, component,

_plot(axis, dat, syn)

# normalize amplitude
# normalize amplitude -- logic for station_amplitude, median_amplitude, and maximum_amplitude is done at higher level
if normalize=='trace_amplitude':
max_trace = _max(dat, syn)
ylim = [-1.5*max_trace, +1.5*max_trace]
axis.set_ylim(*ylim)
elif normalize=='station_amplitude':
max_stream = _max(stream_dat, stream_syn)
ylim = [-1.5*max_stream, +1.5*max_stream]
axis.set_ylim(*ylim)
elif normalize=='maximum_amplitude':
ylim = [-0.75*max_amplitude, +0.75*max_amplitude]
elif normalize=='station_amplitude' or normalize=='median_amplitude' or normalize=='maximum_amplitude':
ylim = [-1.25*normalization_amplitude, +1.25*normalization_amplitude]
axis.set_ylim(*ylim)

if trace_label_writer is not None:
trace_label_writer(axis, dat, syn, total_misfit)


def _plot_ZR(axes, ic, dat, syn, component,
normalize='maximum_amplitude', trace_label_writer=None,
max_amplitude=1., total_misfit=1.):
normalization_amplitude=1., total_misfit=1.):

# plot traces
if component=='Z':
Expand All @@ -418,24 +464,18 @@ def _plot_ZR(axes, ic, dat, syn, component,

_plot(axis, dat, syn)

# normalize amplitude
# normalize amplitude -- logic for station_amplitude, median_amplitude, and maximum_amplitude is done at higher level
if normalize=='trace_amplitude':
max_trace = _max(dat, syn)
ylim = [-1.5*max_trace, +1.5*max_trace]
axis.set_ylim(*ylim)
elif normalize=='station_amplitude':
max_stream = _max(stream_dat, stream_syn)
ylim = [-1.5*max_stream, +1.5*max_stream]
elif normalize=='station_amplitude' or normalize=='median_amplitude' or normalize=='maximum_amplitude':
ylim = [-1.25*normalization_amplitude, +1.25*normalization_amplitude]
axis.set_ylim(*ylim)
elif normalize=='maximum_amplitude':
ylim = [-0.75*max_amplitude, +0.75*max_amplitude]
axis.set_ylim(*ylim)


if trace_label_writer is not None:
trace_label_writer(axis, dat, syn, total_misfit)


def _plot(axis, dat, syn, label=None):
""" Plots data and synthetics time series on current axes
"""
Expand All @@ -450,9 +490,9 @@ def _plot(axis, dat, syn, label=None):
s = syn.data

axis.plot(t, d, 'k', linewidth=1.5,
clip_on=False, zorder=10)
clip_on=True, zorder=10)
axis.plot(t, s[start:stop], 'r', linewidth=1.25,
clip_on=False, zorder=10)
clip_on=True, zorder=10)


def _add_component_labels1(axes, body_wave_labels=True, surface_wave_labels=True):
Expand Down Expand Up @@ -566,8 +606,23 @@ def _hide_axes(axes):
col.get_yaxis().set_ticks([])
col.patch.set_visible(False)

def station_number_header_decorator(header_function):
def wrapper(*args, **kwargs):
# Call the original header function with all args except 'additional_header_info'
header = header_function(*args, **{k: v for k, v in kwargs.items() if k != 'additional_header_info'})

# Now handle the 'additional_header_info' specifically
additional_header_info = kwargs.get('additional_header_info', {})
for key, value in additional_header_info.items():
setattr(header, key, value) # Dynamically add new attributes to the header object

return header
return wrapper



def _prepare_header(model, solver, source, source_dict, origin, *args):
@station_number_header_decorator
def _prepare_header(model, solver, source, source_dict, origin, *args, **kwargs):
# prepares figure header

if len(args)==3:
Expand Down
Loading