diff --git a/HISTORY.rst b/HISTORY.rst index 46d4c7c..2d1983a 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -1,6 +1,11 @@ ======= History ======= +2.2.1 (2024-7-30) +----------------------- +*Fix bug with models that use the same names for pipes and nodes. +*Allow for custom color maps + 2.2.0 (2024-7-30) ----------------------- *Added additional colorbar style options: Colorbar location, label location, font color, and font size. diff --git a/setup.cfg b/setup.cfg index d1e34a3..2adb7b6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 2.2.0 +current_version = 2.2.1 commit = True tag = True diff --git a/setup.py b/setup.py index be4e88e..5e9dda3 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,6 @@ test_suite='tests', tests_require=test_requirements, url='https://github.com/tylertrimble/viswaternet', - version='2.2.0', + version='2.2.1', zip_safe=False, ) diff --git a/viswaternet/drawing/base.py b/viswaternet/drawing/base.py index 19269a0..f954b1f 100644 --- a/viswaternet/drawing/base.py +++ b/viswaternet/drawing/base.py @@ -90,7 +90,14 @@ def draw_nodes( parameter_results, min_size, max_size) if np.min(parameter_results) < -1e-5: # Gets the cmap object from matplotlib - cmap = mpl.colormaps[cmap] + try: + cmap = mpl.colormaps[cmap] + except Exception: + if isinstance(cmap, mpl.colors.LinearSegmentedColormap) \ + or isinstance(cmap, mpl.colors.ListedColormap): + pass + else: + raise Exception('Invalid cmap!') # If both vmin and vmax are None, set vmax to the max data # value and vmin to the negative of the max data value. This # ensures that the colorbar is centered at 0. @@ -127,7 +134,14 @@ def draw_nodes( return g else: # Gets the cmap object from matplotlib - cmap = mpl.colormaps[cmap] + try: + cmap = mpl.colormaps[cmap] + except Exception: + if isinstance(cmap, mpl.colors.LinearSegmentedColormap) \ + or isinstance(cmap, mpl.colors.ListedColormap): + pass + else: + raise Exception('Invalid cmap!') # If both vmin and vmax are None, don't pass vmin and vmax, # as networkx will handle the limits of the colorbar # itself. @@ -252,7 +266,14 @@ def draw_links( parameter_results, min_size, max_size) if np.min(parameter_results) < -1e-5: # Gets the cmap object from matplotlib - cmap = mpl.colormaps[cmap] + try: + cmap = mpl.colormaps[cmap] + except Exception: + if isinstance(cmap, mpl.colors.LinearSegmentedColormap) \ + or isinstance(cmap, mpl.colors.ListedColormap): + pass + else: + raise Exception('Invalid cmap!') # If both vmin and vmax are None, set vmax to the max data # value and vmin to the negative of the max data value. This # ensures that the colorbar is centered at 0. @@ -289,7 +310,14 @@ def draw_links( return g else: # Gets the cmap object from matplotlib - cmap = mpl.colormaps[cmap] + try: + cmap = mpl.colormaps[cmap] + except Exception: + if isinstance(cmap, mpl.colors.LinearSegmentedColormap) \ + or isinstance(cmap, mpl.colors.ListedColormap): + pass + else: + raise Exception('Invalid cmap!') # If both vmin and vmax are None, don't pass vmin and vmax, # as networkx will handle the limits of the colorbar # itself. @@ -343,6 +371,7 @@ def draw_base_elements( ax, draw_nodes=True, element_list=None, + draw_originator=None, style=None): """ Draws base elements (draw_nodes, draw_links, draw_reservoirs, draw_tanks, draw_pumps, and draw_valves) @@ -401,7 +430,7 @@ def draw_base_elements( # If draw_nodes is True, then draw draw_nodes if draw_nodes: node_list = model['node_names'] - if element_list is None: + if element_list is None or draw_originator == 'link': node_list = [node_list[node_list.index(name)] for name in node_list if ((name not in model["tank_names"] @@ -452,7 +481,7 @@ def draw_base_elements( # If draw_links is True, then draw draw_links if draw_links: pipe_name_list = model['G_pipe_name_list'] - if element_list is None: + if element_list is None or draw_originator == 'node': edgelist = [model['pipe_list'][pipe_name_list.index(name)] for name in pipe_name_list if ((name not in model["pump_names"] @@ -767,7 +796,14 @@ def draw_legend( for i, text in enumerate(legend2.get_texts()): text.set_color(color_list[i]) elif cmap: - cmap = mpl.colormaps[cmap] + try: + cmap = mpl.colormaps[cmap] + except Exception: + if isinstance(cmap, mpl.colors.LinearSegmentedColormap) \ + or isinstance(cmap, mpl.colors.ListedColormap): + pass + else: + raise Exception('Invalid cmap!') cmap_value = 1 / len(intervals) for i, text in enumerate(legend2.get_texts()): text.set_color(cmap(float(cmap_value))) diff --git a/viswaternet/drawing/continuous.py b/viswaternet/drawing/continuous.py index ba3065b..e5cbfba 100644 --- a/viswaternet/drawing/continuous.py +++ b/viswaternet/drawing/continuous.py @@ -154,6 +154,7 @@ def plot_continuous_nodes( self, ax, element_list=node_list, + draw_originator='node', style=style) if draw_color_bar is True: if color_bar_title is None: @@ -319,6 +320,7 @@ def plot_continuous_links( ax, draw_nodes=draw_nodes, element_list=link_list, + draw_originator='link', style=style) if link_arrows is True: g = fancyarrowpatch_to_linecollection( diff --git a/viswaternet/drawing/discrete.py b/viswaternet/drawing/discrete.py index 81aae24..8b04f11 100644 --- a/viswaternet/drawing/discrete.py +++ b/viswaternet/drawing/discrete.py @@ -99,7 +99,14 @@ def draw_discrete_nodes( label=label_list[j]) ax.add_artist(m) else: - cmap = mpl.colormaps[cmap] + try: + cmap = mpl.colormaps[cmap] + except Exception: + if isinstance(cmap, mpl.colors.LinearSegmentedColormap) \ + or isinstance(cmap, mpl.colors.ListedColormap): + pass + else: + raise Exception('Invalid cmap!') cmapValue = 1 / len(intervals) for j, interval_name in enumerate(intervals): interval_elements = element_list.get(interval_name) @@ -200,7 +207,14 @@ def draw_discrete_links( style=link_style[j], label=label_list[j]) else: - cmap = mpl.colormaps[cmap] + try: + cmap = mpl.colormaps[cmap] + except Exception: + if isinstance(cmap, mpl.colors.LinearSegmentedColormap) \ + or isinstance(cmap, mpl.colors.ListedColormap): + pass + else: + raise Exception('Invalid cmap!') cmapValue = 1 / len(intervals) for j, interval_name in enumerate(intervals): interval_elements = element_list.get(interval_name) @@ -404,6 +418,7 @@ def plot_discrete_nodes( ax, draw_nodes=draw_nodes, element_list=node_list, + draw_originator='node', style=style) if discrete_legend_title is None: @@ -575,6 +590,7 @@ def plot_discrete_links( ax, draw_nodes=draw_nodes, element_list=link_list, + draw_originator='link', style=style) if discrete_legend_title is None: discrete_legend_title = label_generator(parameter, value, unit) diff --git a/viswaternet/drawing/unique.py b/viswaternet/drawing/unique.py index f25659d..fead718 100644 --- a/viswaternet/drawing/unique.py +++ b/viswaternet/drawing/unique.py @@ -353,6 +353,7 @@ def call_draw_base_elements(element_list=None): ax, draw_nodes=draw_nodes, element_list=element_list, + draw_originator=parameter_type, style=style) def call_draw_legend(intervals=None, element_list=None):