Skip to content

Commit

Permalink
Added generic pygmt plotting backend for attributes
Browse files Browse the repository at this point in the history
Implemented a full "generic" pygmt spider plot, with dynamic topography, coastlines, moment tensor and header option. The backend comes with a utility class where most individual functions are stored.
  • Loading branch information
thurinj committed Mar 1, 2024
1 parent 568e49a commit 77228de
Show file tree
Hide file tree
Showing 2 changed files with 302 additions and 14 deletions.
2 changes: 1 addition & 1 deletion mtuq/graphics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from mtuq.graphics.attrs import\
plot_time_shifts, plot_amplitude_ratios, plot_log_amplitude_ratios,\
_plot_attrs
_plot_attrs, plot_cross_corr, _pygmt_backend

from mtuq.graphics.beachball import\
plot_beachball, plot_polarities
Expand Down
314 changes: 301 additions & 13 deletions mtuq/graphics/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from os.path import join

from mtuq.util import defaults, warn
from mtuq.graphics._pygmt import exists_pygmt
from mtuq.event import MomentTensor


def plot_time_shifts(dirname, attrs, stations, origin, key='total_shift',
Expand All @@ -22,15 +24,15 @@ def plot_time_shifts(dirname, attrs, stations, origin, key='total_shift',
MTUQ distinguishes between the following different types of
time shifts
- `static_shift` is an initial user-supplied time shift applied during
data processing
data processing
- `time_shift` is a subsequent cross-correlation time shift applied
during misfit evaluation
during misfit evaluation
- `total_shift` is the total correction, or in other words the sum of
static and cross-correlation time shifts
static and cross-correlation time shifts
.. rubric :: Required input arguments
Expand Down Expand Up @@ -87,7 +89,7 @@ def plot_amplitude_ratios(dirname, attrs, stations, origin, **kwargs):
"""
defaults(kwargs, {
'colormap': 'Reds',
'colormap': 'inferno',
'label': '$A_{obs}/A_{syn}$',
'zero_centered': False,
})
Expand Down Expand Up @@ -127,8 +129,8 @@ def plot_log_amplitude_ratios(dirname, attrs, stations, origin, **kwargs):


def _plot_attrs(dirname, stations, origin, attrs, key,
components=['Z', 'R', 'T'], format='png', backend=None,
**kwargs):
components=['Z', 'R', 'T'], format='png', backend=None,
**kwargs):

""" Reads the attribute given by `key` from the `attrs` data structure, and
plots how this attribute varies
Expand Down Expand Up @@ -162,27 +164,28 @@ def _plot_attrs(dirname, stations, origin, attrs, key,

if backend is None:
backend = _default_backend
elif backend == _pygmt_backend and not exists_pygmt():
warn('PyGMT backend requested but PyGMT not found'); backend = _default_backend

if not callable(backend):
raise TypeError


os.makedirs(dirname, exist_ok=True)

for component in components:
values = []
station_list = []
active_stations_list = []

for _i, station in enumerate(stations):
if component not in attrs[_i]:
continue

values += [attrs[_i][component][key]]
station_list += [stations[_i]]
active_stations_list += [stations[_i]]

if len(values) > 0:
filename = join(dirname, component+'.'+format)
backend(filename, values, station_list, origin, **kwargs)
backend(filename, values, active_stations_list, origin, stations_list = stations, **kwargs)


#
Expand All @@ -191,7 +194,7 @@ def _plot_attrs(dirname, stations, origin, attrs, key,

def _default_backend(filename, values, stations, origin,
colormap='coolwarm', zero_centered=True, colorbar=True,
label='', width=5., height=5.):
label='', width=5., height=5., **kwargs):

""" Default backend for all other `mtuq.graphics.attrs` functions
Expand Down Expand Up @@ -230,7 +233,7 @@ def _default_backend(filename, values, stations, origin,
else:
min_val = np.min(values)
max_val = np.max(values)

# plot stations
im = pyplot.scatter(
[station.longitude for station in stations],
Expand Down Expand Up @@ -284,3 +287,288 @@ def _default_backend(filename, values, stations, origin,
pyplot.close()


def _pygmt_backend(filename, values, active_stations, origin,
colormap='polar', zero_centered=True, display_topo=True,
label='', width=5, moment_tensor=None, process=None,
stations_list=None, station_labels=True, min_val=None, max_val=None, **kwargs):
"""
PyGMT backend for plotting station attributes with hillshading using the
Miller Cylindrical projection, with an azimuth of 0/90 and a normalization
of t1 for the hillshade intensity.
"""
import pygmt

if not stations_list:
stations_list = active_stations
print('Complete station list not passed to pygmt plotting backend \nWill plot only active stations')
# Collection of longitudes and latitudes from all available stations
longitudes = [s.longitude for s in stations_list + [origin]]
latitudes = [s.latitude for s in stations_list + [origin]]

# Calculate the region to display with a buffer around the stations
region, lat_buffer = PyGMTUtilities.calculate_plotting_region(stations_list, origin, buffer_percentage=0.1)

# Setting up the figure
fig = pygmt.Figure()

# Dynamically determine the grid resolution for topography based on the range of longitudes and latitudes
# (etopo topography file will be downloaded if not found)
resolution = PyGMTUtilities.get_resolution(max(longitudes) - min(longitudes), max(latitudes) - min(latitudes))
grid = pygmt.datasets.load_earth_relief(region=region, resolution=resolution)

# Define a grayscale colormap for topography
pygmt.makecpt(cmap='gray', series=[-7000, 7000])

# Calculate the gradient (hillshade) grid with azimuth 0/300 and normalization t1
# <https://www.pygmt.org/dev/gallery/images/grdgradient_shading.html>
shade = pygmt.grdgradient(grid=grid, azimuth="0/300", normalize="t1")
# Plot the hillshade grid as an image
if display_topo:
fig.grdimage(grid=grid, shading=shade, projection=f'J{width}i', frame='a', cmap='gray', no_clip=True)

# Overlay coastlines
PyGMTUtilities.draw_coastlines(fig)

# Configure the colormap for station values
colormap, cmap_reverse_flag = PyGMTUtilities.configure_colormap(colormap)
if zero_centered:
pygmt.makecpt(cmap=colormap, series=[-np.max(np.abs(values))*1.01, np.max(np.abs(values))*1.01], reverse=cmap_reverse_flag)
elif min_val is not None and max_val is not None:
pygmt.makecpt(cmap=colormap, series=[min_val, max_val], continuous=True, reverse=cmap_reverse_flag)
else:
pygmt.makecpt(cmap=colormap, series=[np.min(values), np.max(values)], continuous=True, reverse=cmap_reverse_flag)


# Plotting lines from origin to stations
for station in stations_list:
if station in active_stations:
# Plot line for active station as colored line
value = values[active_stations.index(station)] if station in active_stations else 0
fig.plot(
x=[origin.longitude, station.longitude],
y=[origin.latitude, station.latitude],
cmap=True,
zvalue=value,
pen="thick,+z,-"
)

# Plotting stations as triangles
fig.plot(
x=[station.longitude for station in active_stations],
y=[station.latitude for station in active_stations],
style='i0.8c', # Triangle
color=values,
cmap=True,
pen="0.5p,black"
)

# Plotting non-active stations as hollow triangles
non_active_stations = [station for station in stations_list if station not in active_stations]
if len(non_active_stations) > 0:
fig.plot(
x=[station.longitude for station in non_active_stations],
y=[station.latitude for station in non_active_stations],
style='i0.8c', # Triangle
color=None, # Hollow (white) triangle
pen="0.5p,black" # Outline color
)
fig.plot(
x=[station.longitude for station in non_active_stations],
y=[station.latitude for station in non_active_stations],
style='i0.6c', # Triangle
color=None, # Hollow (white) triangle
pen="0.5p,white" # Outline color
)

# Plotting the origin as a star
fig.plot(
x=[origin.longitude],
y=[origin.latitude],
style='a0.6c', # Star, size 0.5 cm
color='yellow',
pen="0.5p,black"
)

if moment_tensor is not None:
# Normalize the moment tensor components to the desired exponent

if type(moment_tensor) is MomentTensor:
moment_tensor = moment_tensor.as_vector()

moment_tensor = np.array(moment_tensor)/np.linalg.norm(moment_tensor)

moment_tensor_spec = {
'mrr': moment_tensor[0],
'mtt': moment_tensor[1],
'mff': moment_tensor[2],
'mrt': moment_tensor[3],
'mrf': moment_tensor[4],
'mtf': moment_tensor[5],
'exponent': 21 # Merely for size control, as the MT is normalized prior to plotting
}

# Plot the moment tensor as a beachball
fig.meca(
spec=moment_tensor_spec,
scale="1c", # Sets a fixed size for the beachball plot
longitude=origin.longitude,
latitude=origin.latitude,
depth=10, # Depth is required, even if not used, set to a small number
convention="mt", # Use GMT's mt convention <https://www.pygmt.org/dev/api/generated/pygmt.Figure.meca.html>
compressionfill="red",
extensionfill="white",
pen="black"
)

if station_labels is True:
# Plotting station labels
for station in stations_list:
fig.text(
x=station.longitude,
y=station.latitude,
text=station.station,
font="5p,Helvetica-Bold,black",
justify="LM",
offset="-0.45c/0.125c",
fill='white'
)

fig.colorbar(frame=f'+l"{PyGMTUtilities.prepare_latex_annotations(label)}"', position="JMR+o1.5c/0c+w7c/0.5c")

fig.basemap(region=region, projection=f'J{width}i', frame=True)

# Now starts the header text above the plot -- It is not a title and can be modified.
# Add an integer increment to the text_line_val bellow to add a new line above.
text_line_val = 1
header_lines = PyGMTUtilities.get_header(label, origin, filename, process)

for header_line in header_lines:
fig.text(x=-148, y=(max(latitudes) + lat_buffer)+(text_line_val)*0.25, text=header_line, font="14p,Helvetica-Bold,black", justify="MC", no_clip=True)
text_line_val += 1

# Saving the figure
fig.savefig(filename, crop=True, dpi=300)

class PyGMTUtilities:
@staticmethod
def calculate_plotting_region(stations, origin, buffer_percentage=0.1):
longitudes = [station.longitude for station in stations] + [origin.longitude]
latitudes = [station.latitude for station in stations] + [origin.latitude]

lon_buffer = (max(longitudes) - min(longitudes)) * buffer_percentage
lat_buffer = (max(latitudes) - min(latitudes)) * buffer_percentage

region = [min(longitudes) - lon_buffer, max(longitudes) + lon_buffer,
min(latitudes) - lat_buffer, max(latitudes) + lat_buffer]
return region, lat_buffer


@staticmethod
def get_resolution(lon_range, lat_range):
"""
Determines the resolution based on the given longitude and latitude ranges.
Args:
lon_range (float): The range of longitudes.
lat_range (float): The range of latitudes.
Returns:
str: pygmt etopo grid resolution based on the given ranges.
"""

if lon_range > 10 or lat_range > 10:
return '01m'
elif lon_range > 5 or lat_range > 5:
return '15s'
elif lon_range > 2 or lat_range > 2:
return '03s'
elif lon_range > 1 or lat_range > 1:
return '01s'
else:
return '05m'

@staticmethod
def configure_colormap(colormap):
"""
Configures the colormap based on the given input - as conventions for matplotlib and pygmt can differ
Args:
colormap (str): The name of the colormap.
Returns:
tuple: A tuple containing the modified colormap name and a flag indicating
whether the colormap should be reversed.
"""
cmap_reverse_flag = True if colormap.endswith('_r') else False
colormap = colormap[:-2] if cmap_reverse_flag else colormap
return colormap, cmap_reverse_flag

@staticmethod
def prepare_latex_annotations(label):
"""
Prepares LaTeX annotations for plotting. Uses HTML for compatibility with PyGMT/GMT.
Args:
label (str): The LaTeX label to be prepared.
Returns:
str: The prepared label.
"""
if label.startswith('$') and label.endswith('$'):
# Convert LaTeX to HTML for compatibility with PyGMT/GMT
return f"<math>{label[1:-1]}</math>"
else:
return label

@staticmethod
def get_header(label, origin, filename, process = None):
"""
Generates a header for a plot based on the provided parameters.
Args:
label (str): The label for the plot. Defined in default kwargs.
origin (Origin): mtuq.event.Origin object.
filename (str): The filename of the plot. Defined by default the high-level function. Used to retrieve the component.
process (Process, optional): mtuq.process_data.ProcessData object for appropriate dataset.
Returns:
list: A list containing two lines of the header.
"""
if process is not None:
# get type of waves used for the window
window_type = process.window_type
if window_type == 'surface_wave':
window_type = 'Surface wave'
elif window_type == 'body_wave':
window_type = 'Body wave'

component = filename.split('/')[-1].split('.')[0]
origin_time = str(origin.time)[0:19]
origin_depth = origin.depth_in_m/1000

label = PyGMTUtilities.prepare_latex_annotations(label)

# if window_type exists, define Rayleigh or Love wave
if process is not None:
if window_type == 'Surface wave' and component == 'Z' or window_type == 'Surface wave' and component == 'R':
# First line of the header defined as: label - Rayleigh wave (component)
header_line_1 = f"{label} - Rayleigh wave ({component})"
elif window_type == 'Surface wave' and component == 'T':
# First line of the header defined as: label - Love wave (component)
header_line_1 = f"{label} - Love wave ({component})"
elif window_type == 'Body wave':
# First line of the header defined as: label - (component)
header_line_1 = f"{label} - Body wave ({component})"
else:
# First line of the header defined as: label - (component)
header_line_1 = f"{label} - ({component})"

header_line_2 = f"Event Time: {origin_time} UTC, Depth: {origin_depth:.1f} km"

return [header_line_1, header_line_2]

@staticmethod
def draw_coastlines(fig, area_thresh=100, water_color='paleturquoise', water_transparency=55):
fig.coast(shorelines=True, area_thresh=area_thresh)
fig.coast(shorelines=False, water=water_color, transparency=water_transparency, area_thresh=area_thresh)

0 comments on commit 77228de

Please sign in to comment.