From 77228de54c10ea288850605fd72841d220bb7aee Mon Sep 17 00:00:00 2001 From: thurinj Date: Fri, 1 Mar 2024 14:04:05 +1100 Subject: [PATCH] Added generic pygmt plotting backend for attributes Implemented a full "generic" pygmt spider plot, with dynamic topography, coastlines, moment tensor and header option. The backend comes with a utility class where most individual functions are stored. --- mtuq/graphics/__init__.py | 2 +- mtuq/graphics/attrs.py | 314 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 302 insertions(+), 14 deletions(-) diff --git a/mtuq/graphics/__init__.py b/mtuq/graphics/__init__.py index badf8d0b8..35ae6ee0a 100644 --- a/mtuq/graphics/__init__.py +++ b/mtuq/graphics/__init__.py @@ -5,7 +5,7 @@ from mtuq.graphics.attrs import\ plot_time_shifts, plot_amplitude_ratios, plot_log_amplitude_ratios,\ - _plot_attrs + _plot_attrs, plot_cross_corr, _pygmt_backend from mtuq.graphics.beachball import\ plot_beachball, plot_polarities diff --git a/mtuq/graphics/attrs.py b/mtuq/graphics/attrs.py index 45a0d9c86..30878b2fc 100644 --- a/mtuq/graphics/attrs.py +++ b/mtuq/graphics/attrs.py @@ -7,6 +7,8 @@ from os.path import join from mtuq.util import defaults, warn +from mtuq.graphics._pygmt import exists_pygmt +from mtuq.event import MomentTensor def plot_time_shifts(dirname, attrs, stations, origin, key='total_shift', @@ -22,15 +24,15 @@ def plot_time_shifts(dirname, attrs, stations, origin, key='total_shift', MTUQ distinguishes between the following different types of time shifts - + - `static_shift` is an initial user-supplied time shift applied during - data processing + data processing - `time_shift` is a subsequent cross-correlation time shift applied - during misfit evaluation + during misfit evaluation - `total_shift` is the total correction, or in other words the sum of - static and cross-correlation time shifts + static and cross-correlation time shifts .. rubric :: Required input arguments @@ -87,7 +89,7 @@ def plot_amplitude_ratios(dirname, attrs, stations, origin, **kwargs): """ defaults(kwargs, { - 'colormap': 'Reds', + 'colormap': 'inferno', 'label': '$A_{obs}/A_{syn}$', 'zero_centered': False, }) @@ -127,8 +129,8 @@ def plot_log_amplitude_ratios(dirname, attrs, stations, origin, **kwargs): def _plot_attrs(dirname, stations, origin, attrs, key, - components=['Z', 'R', 'T'], format='png', backend=None, - **kwargs): + components=['Z', 'R', 'T'], format='png', backend=None, + **kwargs): """ Reads the attribute given by `key` from the `attrs` data structure, and plots how this attribute varies @@ -162,27 +164,28 @@ def _plot_attrs(dirname, stations, origin, attrs, key, if backend is None: backend = _default_backend + elif backend == _pygmt_backend and not exists_pygmt(): + warn('PyGMT backend requested but PyGMT not found'); backend = _default_backend if not callable(backend): raise TypeError - os.makedirs(dirname, exist_ok=True) for component in components: values = [] - station_list = [] + active_stations_list = [] for _i, station in enumerate(stations): if component not in attrs[_i]: continue values += [attrs[_i][component][key]] - station_list += [stations[_i]] + active_stations_list += [stations[_i]] if len(values) > 0: filename = join(dirname, component+'.'+format) - backend(filename, values, station_list, origin, **kwargs) + backend(filename, values, active_stations_list, origin, stations_list = stations, **kwargs) # @@ -191,7 +194,7 @@ def _plot_attrs(dirname, stations, origin, attrs, key, def _default_backend(filename, values, stations, origin, colormap='coolwarm', zero_centered=True, colorbar=True, - label='', width=5., height=5.): + label='', width=5., height=5., **kwargs): """ Default backend for all other `mtuq.graphics.attrs` functions @@ -230,7 +233,7 @@ def _default_backend(filename, values, stations, origin, else: min_val = np.min(values) max_val = np.max(values) - + # plot stations im = pyplot.scatter( [station.longitude for station in stations], @@ -284,3 +287,288 @@ def _default_backend(filename, values, stations, origin, pyplot.close() +def _pygmt_backend(filename, values, active_stations, origin, + colormap='polar', zero_centered=True, display_topo=True, + label='', width=5, moment_tensor=None, process=None, + stations_list=None, station_labels=True, min_val=None, max_val=None, **kwargs): + """ + PyGMT backend for plotting station attributes with hillshading using the + Miller Cylindrical projection, with an azimuth of 0/90 and a normalization + of t1 for the hillshade intensity. + """ + import pygmt + + if not stations_list: + stations_list = active_stations + print('Complete station list not passed to pygmt plotting backend \nWill plot only active stations') + # Collection of longitudes and latitudes from all available stations + longitudes = [s.longitude for s in stations_list + [origin]] + latitudes = [s.latitude for s in stations_list + [origin]] + + # Calculate the region to display with a buffer around the stations + region, lat_buffer = PyGMTUtilities.calculate_plotting_region(stations_list, origin, buffer_percentage=0.1) + + # Setting up the figure + fig = pygmt.Figure() + + # Dynamically determine the grid resolution for topography based on the range of longitudes and latitudes + # (etopo topography file will be downloaded if not found) + resolution = PyGMTUtilities.get_resolution(max(longitudes) - min(longitudes), max(latitudes) - min(latitudes)) + grid = pygmt.datasets.load_earth_relief(region=region, resolution=resolution) + + # Define a grayscale colormap for topography + pygmt.makecpt(cmap='gray', series=[-7000, 7000]) + + # Calculate the gradient (hillshade) grid with azimuth 0/300 and normalization t1 + # + shade = pygmt.grdgradient(grid=grid, azimuth="0/300", normalize="t1") + # Plot the hillshade grid as an image + if display_topo: + fig.grdimage(grid=grid, shading=shade, projection=f'J{width}i', frame='a', cmap='gray', no_clip=True) + + # Overlay coastlines + PyGMTUtilities.draw_coastlines(fig) + + # Configure the colormap for station values + colormap, cmap_reverse_flag = PyGMTUtilities.configure_colormap(colormap) + if zero_centered: + pygmt.makecpt(cmap=colormap, series=[-np.max(np.abs(values))*1.01, np.max(np.abs(values))*1.01], reverse=cmap_reverse_flag) + elif min_val is not None and max_val is not None: + pygmt.makecpt(cmap=colormap, series=[min_val, max_val], continuous=True, reverse=cmap_reverse_flag) + else: + pygmt.makecpt(cmap=colormap, series=[np.min(values), np.max(values)], continuous=True, reverse=cmap_reverse_flag) + + + # Plotting lines from origin to stations + for station in stations_list: + if station in active_stations: + # Plot line for active station as colored line + value = values[active_stations.index(station)] if station in active_stations else 0 + fig.plot( + x=[origin.longitude, station.longitude], + y=[origin.latitude, station.latitude], + cmap=True, + zvalue=value, + pen="thick,+z,-" + ) + + # Plotting stations as triangles + fig.plot( + x=[station.longitude for station in active_stations], + y=[station.latitude for station in active_stations], + style='i0.8c', # Triangle + color=values, + cmap=True, + pen="0.5p,black" + ) + + # Plotting non-active stations as hollow triangles + non_active_stations = [station for station in stations_list if station not in active_stations] + if len(non_active_stations) > 0: + fig.plot( + x=[station.longitude for station in non_active_stations], + y=[station.latitude for station in non_active_stations], + style='i0.8c', # Triangle + color=None, # Hollow (white) triangle + pen="0.5p,black" # Outline color + ) + fig.plot( + x=[station.longitude for station in non_active_stations], + y=[station.latitude for station in non_active_stations], + style='i0.6c', # Triangle + color=None, # Hollow (white) triangle + pen="0.5p,white" # Outline color + ) + + # Plotting the origin as a star + fig.plot( + x=[origin.longitude], + y=[origin.latitude], + style='a0.6c', # Star, size 0.5 cm + color='yellow', + pen="0.5p,black" + ) + + if moment_tensor is not None: + # Normalize the moment tensor components to the desired exponent + + if type(moment_tensor) is MomentTensor: + moment_tensor = moment_tensor.as_vector() + + moment_tensor = np.array(moment_tensor)/np.linalg.norm(moment_tensor) + + moment_tensor_spec = { + 'mrr': moment_tensor[0], + 'mtt': moment_tensor[1], + 'mff': moment_tensor[2], + 'mrt': moment_tensor[3], + 'mrf': moment_tensor[4], + 'mtf': moment_tensor[5], + 'exponent': 21 # Merely for size control, as the MT is normalized prior to plotting + } + + # Plot the moment tensor as a beachball + fig.meca( + spec=moment_tensor_spec, + scale="1c", # Sets a fixed size for the beachball plot + longitude=origin.longitude, + latitude=origin.latitude, + depth=10, # Depth is required, even if not used, set to a small number + convention="mt", # Use GMT's mt convention + compressionfill="red", + extensionfill="white", + pen="black" + ) + + if station_labels is True: + # Plotting station labels + for station in stations_list: + fig.text( + x=station.longitude, + y=station.latitude, + text=station.station, + font="5p,Helvetica-Bold,black", + justify="LM", + offset="-0.45c/0.125c", + fill='white' + ) + + fig.colorbar(frame=f'+l"{PyGMTUtilities.prepare_latex_annotations(label)}"', position="JMR+o1.5c/0c+w7c/0.5c") + + fig.basemap(region=region, projection=f'J{width}i', frame=True) + + # Now starts the header text above the plot -- It is not a title and can be modified. + # Add an integer increment to the text_line_val bellow to add a new line above. + text_line_val = 1 + header_lines = PyGMTUtilities.get_header(label, origin, filename, process) + + for header_line in header_lines: + fig.text(x=-148, y=(max(latitudes) + lat_buffer)+(text_line_val)*0.25, text=header_line, font="14p,Helvetica-Bold,black", justify="MC", no_clip=True) + text_line_val += 1 + + # Saving the figure + fig.savefig(filename, crop=True, dpi=300) + +class PyGMTUtilities: + @staticmethod + def calculate_plotting_region(stations, origin, buffer_percentage=0.1): + longitudes = [station.longitude for station in stations] + [origin.longitude] + latitudes = [station.latitude for station in stations] + [origin.latitude] + + lon_buffer = (max(longitudes) - min(longitudes)) * buffer_percentage + lat_buffer = (max(latitudes) - min(latitudes)) * buffer_percentage + + region = [min(longitudes) - lon_buffer, max(longitudes) + lon_buffer, + min(latitudes) - lat_buffer, max(latitudes) + lat_buffer] + return region, lat_buffer + + + @staticmethod + def get_resolution(lon_range, lat_range): + """ + Determines the resolution based on the given longitude and latitude ranges. + + Args: + lon_range (float): The range of longitudes. + lat_range (float): The range of latitudes. + + Returns: + str: pygmt etopo grid resolution based on the given ranges. + """ + + if lon_range > 10 or lat_range > 10: + return '01m' + elif lon_range > 5 or lat_range > 5: + return '15s' + elif lon_range > 2 or lat_range > 2: + return '03s' + elif lon_range > 1 or lat_range > 1: + return '01s' + else: + return '05m' + + @staticmethod + def configure_colormap(colormap): + """ + Configures the colormap based on the given input - as conventions for matplotlib and pygmt can differ + + Args: + colormap (str): The name of the colormap. + + Returns: + tuple: A tuple containing the modified colormap name and a flag indicating + whether the colormap should be reversed. + """ + cmap_reverse_flag = True if colormap.endswith('_r') else False + colormap = colormap[:-2] if cmap_reverse_flag else colormap + return colormap, cmap_reverse_flag + + @staticmethod + def prepare_latex_annotations(label): + """ + Prepares LaTeX annotations for plotting. Uses HTML for compatibility with PyGMT/GMT. + + Args: + label (str): The LaTeX label to be prepared. + + Returns: + str: The prepared label. + + """ + if label.startswith('$') and label.endswith('$'): + # Convert LaTeX to HTML for compatibility with PyGMT/GMT + return f"{label[1:-1]}" + else: + return label + + @staticmethod + def get_header(label, origin, filename, process = None): + """ + Generates a header for a plot based on the provided parameters. + + Args: + label (str): The label for the plot. Defined in default kwargs. + origin (Origin): mtuq.event.Origin object. + filename (str): The filename of the plot. Defined by default the high-level function. Used to retrieve the component. + process (Process, optional): mtuq.process_data.ProcessData object for appropriate dataset. + + Returns: + list: A list containing two lines of the header. + """ + if process is not None: + # get type of waves used for the window + window_type = process.window_type + if window_type == 'surface_wave': + window_type = 'Surface wave' + elif window_type == 'body_wave': + window_type = 'Body wave' + + component = filename.split('/')[-1].split('.')[0] + origin_time = str(origin.time)[0:19] + origin_depth = origin.depth_in_m/1000 + + label = PyGMTUtilities.prepare_latex_annotations(label) + + # if window_type exists, define Rayleigh or Love wave + if process is not None: + if window_type == 'Surface wave' and component == 'Z' or window_type == 'Surface wave' and component == 'R': + # First line of the header defined as: label - Rayleigh wave (component) + header_line_1 = f"{label} - Rayleigh wave ({component})" + elif window_type == 'Surface wave' and component == 'T': + # First line of the header defined as: label - Love wave (component) + header_line_1 = f"{label} - Love wave ({component})" + elif window_type == 'Body wave': + # First line of the header defined as: label - (component) + header_line_1 = f"{label} - Body wave ({component})" + else: + # First line of the header defined as: label - (component) + header_line_1 = f"{label} - ({component})" + + header_line_2 = f"Event Time: {origin_time} UTC, Depth: {origin_depth:.1f} km" + + return [header_line_1, header_line_2] + + @staticmethod + def draw_coastlines(fig, area_thresh=100, water_color='paleturquoise', water_transparency=55): + fig.coast(shorelines=True, area_thresh=area_thresh) + fig.coast(shorelines=False, water=water_color, transparency=water_transparency, area_thresh=area_thresh) \ No newline at end of file