Skip to content

Commit

Permalink
extract draw_plot method for plotting measures using matplotlib
Browse files Browse the repository at this point in the history
  • Loading branch information
wang-boyu committed Nov 6, 2024
1 parent 87d937e commit abc97b3
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 22 deletions.
4 changes: 3 additions & 1 deletion mesa/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Solara based visualization for Mesa models."""

from mesa.visualization.mpl_space_drawing import (
from mesa.visualization.mpl_drawing import (
draw_plot,
draw_space,
)

Expand All @@ -15,6 +16,7 @@
"Slider",
"make_space_altair",
"draw_space",
"draw_plot",
"make_plot_component",
"make_space_component",
]
23 changes: 3 additions & 20 deletions mesa/visualization/components/matplotlib_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import warnings
from collections.abc import Callable

import matplotlib.pyplot as plt
import solara
from matplotlib.figure import Figure

from mesa.visualization.mpl_space_drawing import draw_space
from mesa.visualization.mpl_drawing import draw_plot, draw_space
from mesa.visualization.utils import update_counter


Expand Down Expand Up @@ -151,26 +150,10 @@ def PlotMatplotlib(
"""
update_counter.get()
fig = Figure()
ax = fig.subplots()
df = model.datacollector.get_model_vars_dataframe()
if isinstance(measure, str):
ax.plot(df.loc[:, measure])
ax.set_ylabel(measure)
elif isinstance(measure, dict):
for m, color in measure.items():
ax.plot(df.loc[:, m], label=m, color=color)
ax.legend(loc="best")
elif isinstance(measure, list | tuple):
for m in measure:
ax.plot(df.loc[:, m], label=m)
ax.legend(loc="best")

ax = fig.add_subplot()
draw_plot(model, measure, ax)
if post_process is not None:
post_process(ax)

ax.set_xlabel("Step")
# Set integer x axis
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
solara.FigureMatplotlib(
fig, format=save_format, bbox_inches="tight", dependencies=dependencies
)
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,41 @@ def _scatter(ax: Axes, arguments, **kwargs):
**{k: v[logical] for k, v in arguments.items()},
**kwargs,
)


def draw_plot(
model,
measure,
ax: Axes | None = None,
):
"""Create a Matplotlib-based plot for a measure or measures.
Args:
model (mesa.Model): The model instance.
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
ax: the axes upon which to draw the plot
post_process: a user-specified callable to do post-processing called with the Axes instance.
Returns:
plt.Axes: The Axes object with the plot drawn onto it.
"""
if ax is None:
_, ax = plt.subplots()
df = model.datacollector.get_model_vars_dataframe()
if isinstance(measure, str):
ax.plot(df.loc[:, measure])
ax.set_ylabel(measure)
elif isinstance(measure, dict):
for m, color in measure.items():
ax.plot(df.loc[:, m], label=m, color=color)
ax.legend(loc="best")
elif isinstance(measure, list | tuple):
for m in measure:
ax.plot(df.loc[:, m], label=m)
ax.legend(loc="best")

ax.set_xlabel("Step")
# Set integer x axis
ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

return ax
2 changes: 1 addition & 1 deletion tests/test_components_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
PropertyLayer,
SingleGrid,
)
from mesa.visualization.mpl_space_drawing import (
from mesa.visualization.mpl_drawing import (
draw_continuous_space,
draw_hex_grid,
draw_network,
Expand Down

0 comments on commit abc97b3

Please sign in to comment.