diff --git a/mesa/visualization/UserParam.py b/mesa/visualization/UserParam.py new file mode 100644 index 00000000000..5b342471ddb --- /dev/null +++ b/mesa/visualization/UserParam.py @@ -0,0 +1,56 @@ +class UserParam: + _ERROR_MESSAGE = "Missing or malformed inputs for '{}' Option '{}'" + + def maybe_raise_error(self, param_type, valid): + if valid: + return + msg = self._ERROR_MESSAGE.format(param_type, self.label) + raise ValueError(msg) + + +class Slider(UserParam): + """ + A number-based slider input with settable increment. + + Example: + + slider_option = Slider("My Slider", value=123, min=10, max=200, step=0.1) + + Args: + label: The displayed label in the UI + value: The initial value of the slider + min: The minimum possible value of the slider + max: The maximum possible value of the slider + step: The step between min and max for a range of possible values + dtype: either int or float + """ + + def __init__( + self, + label="", + value=None, + min=None, + max=None, + step=1, + dtype=None, + ): + self.label = label + self.value = value + self.min = min + self.max = max + self.step = step + + # Validate option type to make sure values are supplied properly + valid = not (self.value is None or self.min is None or self.max is None) + self.maybe_raise_error("slider", valid) + + if dtype is None: + self.is_float_slider = self._check_values_are_float(value, min, max, step) + else: + self.is_float_slider = dtype is float + + def _check_values_are_float(self, value, min, max, step): + return any(isinstance(n, float) for n in (value, min, max, step)) + + def get(self, attr): + return getattr(self, attr) diff --git a/mesa/visualization/__init__.py b/mesa/visualization/__init__.py index e69de29bb2d..d8a0ebecf86 100644 --- a/mesa/visualization/__init__.py +++ b/mesa/visualization/__init__.py @@ -0,0 +1,14 @@ +from .components.altair import make_space_altair +from .components.matplotlib import make_plot_measure, make_space_matplotlib +from .solara_viz import JupyterViz, SolaraViz, make_text +from .UserParam import Slider + +__all__ = [ + "JupyterViz", + "SolaraViz", + "make_text", + "Slider", + "make_space_altair", + "make_space_matplotlib", + "make_plot_measure", +] diff --git a/mesa/visualization/components/altair.py b/mesa/visualization/components/altair.py new file mode 100644 index 00000000000..1d23b170bda --- /dev/null +++ b/mesa/visualization/components/altair.py @@ -0,0 +1,86 @@ +import contextlib + +import solara + +with contextlib.suppress(ImportError): + import altair as alt + +from mesa.visualization.utils import update_counter + + +def make_space_altair(agent_portrayal=None): + if agent_portrayal is None: + + def agent_portrayal(a): + return {"id": a.unique_id} + + def MakeSpaceAltair(model): + return SpaceAltair(model, agent_portrayal) + + return MakeSpaceAltair + + +@solara.component +def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): + update_counter.get() + space = getattr(model, "grid", None) + if space is None: + # Sometimes the space is defined as model.space instead of model.grid + space = model.space + chart = _draw_grid(space, agent_portrayal) + solara.FigureAltair(chart) + + +def _draw_grid(space, agent_portrayal): + def portray(g): + all_agent_data = [] + for content, (x, y) in g.coord_iter(): + if not content: + continue + if not hasattr(content, "__iter__"): + # Is a single grid + content = [content] # noqa: PLW2901 + for agent in content: + # use all data from agent portrayal, and add x,y coordinates + agent_data = agent_portrayal(agent) + agent_data["x"] = x + agent_data["y"] = y + all_agent_data.append(agent_data) + return all_agent_data + + all_agent_data = portray(space) + invalid_tooltips = ["color", "size", "x", "y"] + + encoding_dict = { + # no x-axis label + "x": alt.X("x", axis=None, type="ordinal"), + # no y-axis label + "y": alt.Y("y", axis=None, type="ordinal"), + "tooltip": [ + alt.Tooltip(key, type=alt.utils.infer_vegalite_type([value])) + for key, value in all_agent_data[0].items() + if key not in invalid_tooltips + ], + } + has_color = "color" in all_agent_data[0] + if has_color: + encoding_dict["color"] = alt.Color("color", type="nominal") + has_size = "size" in all_agent_data[0] + if has_size: + encoding_dict["size"] = alt.Size("size", type="quantitative") + + chart = ( + alt.Chart( + alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict) + ) + .mark_point(filled=True) + .properties(width=280, height=280) + # .configure_view(strokeOpacity=0) # hide grid/chart lines + ) + # This is the default value for the marker size, which auto-scales + # according to the grid area. + if not has_size: + length = min(space.width, space.height) + chart = chart.mark_point(size=30000 / length**2, filled=True) + + return chart diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py new file mode 100644 index 00000000000..b1c61581b71 --- /dev/null +++ b/mesa/visualization/components/matplotlib.py @@ -0,0 +1,246 @@ +from collections import defaultdict + +import networkx as nx +import solara +from matplotlib.figure import Figure +from matplotlib.ticker import MaxNLocator + +import mesa +from mesa.experimental.cell_space import VoronoiGrid +from mesa.visualization.utils import update_counter + + +def make_space_matplotlib(agent_portrayal=None): + if agent_portrayal is None: + + def agent_portrayal(a): + return {"id": a.unique_id} + + def MakeSpaceMatplotlib(model): + return SpaceMatplotlib(model, agent_portrayal) + + return MakeSpaceMatplotlib + + +@solara.component +def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = None): + update_counter.get() + space_fig = Figure() + space_ax = space_fig.subplots() + space = getattr(model, "grid", None) + if space is None: + # Sometimes the space is defined as model.space instead of model.grid + space = model.space + if isinstance(space, mesa.space.NetworkGrid): + _draw_network_grid(space, space_ax, agent_portrayal) + elif isinstance(space, mesa.space.ContinuousSpace): + _draw_continuous_space(space, space_ax, agent_portrayal) + elif isinstance(space, VoronoiGrid): + _draw_voronoi(space, space_ax, agent_portrayal) + else: + _draw_grid(space, space_ax, agent_portrayal) + solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies) + + +# matplotlib scatter does not allow for multiple shapes in one call +def _split_and_scatter(portray_data, space_ax): + grouped_data = defaultdict(lambda: {"x": [], "y": [], "s": [], "c": []}) + + # Extract data from the dictionary + x = portray_data["x"] + y = portray_data["y"] + s = portray_data["s"] + c = portray_data["c"] + m = portray_data["m"] + + if not (len(x) == len(y) == len(s) == len(c) == len(m)): + raise ValueError( + "Length mismatch in portrayal data lists: " + f"x: {len(x)}, y: {len(y)}, size: {len(s)}, " + f"color: {len(c)}, marker: {len(m)}" + ) + + # Group the data by marker + for i in range(len(x)): + marker = m[i] + grouped_data[marker]["x"].append(x[i]) + grouped_data[marker]["y"].append(y[i]) + grouped_data[marker]["s"].append(s[i]) + grouped_data[marker]["c"].append(c[i]) + + # Plot each group with the same marker + for marker, data in grouped_data.items(): + space_ax.scatter(data["x"], data["y"], s=data["s"], c=data["c"], marker=marker) + + +def _draw_grid(space, space_ax, agent_portrayal): + def portray(g): + x = [] + y = [] + s = [] # size + c = [] # color + m = [] # shape + for i in range(g.width): + for j in range(g.height): + content = g._grid[i][j] + if not content: + continue + if not hasattr(content, "__iter__"): + # Is a single grid + content = [content] + for agent in content: + data = agent_portrayal(agent) + x.append(i) + y.append(j) + + # This is the default value for the marker size, which auto-scales + # according to the grid area. + default_size = (180 / max(g.width, g.height)) ** 2 + # establishing a default prevents misalignment if some agents are not given size, color, etc. + size = data.get("size", default_size) + s.append(size) + color = data.get("color", "tab:blue") + c.append(color) + mark = data.get("shape", "o") + m.append(mark) + out = {"x": x, "y": y, "s": s, "c": c, "m": m} + return out + + space_ax.set_xlim(-1, space.width) + space_ax.set_ylim(-1, space.height) + _split_and_scatter(portray(space), space_ax) + + +def _draw_network_grid(space, space_ax, agent_portrayal): + graph = space.G + pos = nx.spring_layout(graph, seed=0) + nx.draw( + graph, + ax=space_ax, + pos=pos, + **agent_portrayal(graph), + ) + + +def _draw_continuous_space(space, space_ax, agent_portrayal): + def portray(space): + x = [] + y = [] + s = [] # size + c = [] # color + m = [] # shape + for agent in space._agent_to_index: + data = agent_portrayal(agent) + _x, _y = agent.pos + x.append(_x) + y.append(_y) + + # This is matplotlib's default marker size + default_size = 20 + # establishing a default prevents misalignment if some agents are not given size, color, etc. + size = data.get("size", default_size) + s.append(size) + color = data.get("color", "tab:blue") + c.append(color) + mark = data.get("shape", "o") + m.append(mark) + out = {"x": x, "y": y, "s": s, "c": c, "m": m} + return out + + # Determine border style based on space.torus + border_style = "solid" if not space.torus else (0, (5, 10)) + + # Set the border of the plot + for spine in space_ax.spines.values(): + spine.set_linewidth(1.5) + spine.set_color("black") + spine.set_linestyle(border_style) + + width = space.x_max - space.x_min + x_padding = width / 20 + height = space.y_max - space.y_min + y_padding = height / 20 + space_ax.set_xlim(space.x_min - x_padding, space.x_max + x_padding) + space_ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding) + + # Portray and scatter the agents in the space + _split_and_scatter(portray(space), space_ax) + + +def _draw_voronoi(space, space_ax, agent_portrayal): + def portray(g): + x = [] + y = [] + s = [] # size + c = [] # color + + for cell in g.all_cells: + for agent in cell.agents: + data = agent_portrayal(agent) + x.append(cell.coordinate[0]) + y.append(cell.coordinate[1]) + if "size" in data: + s.append(data["size"]) + if "color" in data: + c.append(data["color"]) + out = {"x": x, "y": y} + # This is the default value for the marker size, which auto-scales + # according to the grid area. + out["s"] = s + if len(c) > 0: + out["c"] = c + + return out + + x_list = [i[0] for i in space.centroids_coordinates] + y_list = [i[1] for i in space.centroids_coordinates] + x_max = max(x_list) + x_min = min(x_list) + y_max = max(y_list) + y_min = min(y_list) + + width = x_max - x_min + x_padding = width / 20 + height = y_max - y_min + y_padding = height / 20 + space_ax.set_xlim(x_min - x_padding, x_max + x_padding) + space_ax.set_ylim(y_min - y_padding, y_max + y_padding) + space_ax.scatter(**portray(space)) + + for cell in space.all_cells: + polygon = cell.properties["polygon"] + space_ax.fill( + *zip(*polygon), + alpha=min(1, cell.properties[space.cell_coloring_property]), + c="red", + ) # Plot filled polygon + space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in red + + +def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]): + def MakePlotMeasure(model): + return PlotMatplotlib(model, measure) + + return MakePlotMeasure + + +@solara.component +def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): + update_counter.get() + fig = Figure() + ax = fig.subplots() + df = model.datacollector.get_model_vars_dataframe() + if isinstance(measure, str): + ax.plot(df.loc[:, measure]) + ax.set_ylabel(measure) + elif isinstance(measure, dict): + for m, color in measure.items(): + ax.plot(df.loc[:, m], label=m, color=color) + fig.legend() + elif isinstance(measure, list | tuple): + for m in measure: + ax.plot(df.loc[:, m], label=m) + fig.legend() + # Set integer x axis + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + solara.FigureMatplotlib(fig, dependencies=dependencies) diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py new file mode 100644 index 00000000000..0d91adfbfab --- /dev/null +++ b/mesa/visualization/solara_viz.py @@ -0,0 +1,452 @@ +""" +Mesa visualization module for creating interactive model visualizations. + +This module provides components to create browser- and Jupyter notebook-based visualizations of +Mesa models, allowing users to watch models run step-by-step and interact with model parameters. + +Key features: + - SolaraViz: Main component for creating visualizations, supporting grid displays and plots + - ModelController: Handles model execution controls (step, play, pause, reset) + - UserInputs: Generates UI elements for adjusting model parameters + - Card: Renders individual visualization elements (space, measures) + +The module uses Solara for rendering in Jupyter notebooks or as standalone web applications. +It supports various types of visualizations including matplotlib plots, agent grids, and +custom visualization components. + +Usage: + 1. Define an agent_portrayal function to specify how agents should be displayed + 2. Set up model_params to define adjustable parameters + 3. Create a SolaraViz instance with your model, parameters, and desired measures + 4. Display the visualization in a Jupyter notebook or run as a Solara app + +See the Visualization Tutorial and example models for more details. +""" + +import copy +import threading +from typing import TYPE_CHECKING + +import reacton.ipywidgets as widgets +import solara +from solara.alias import rv + +import mesa.visualization.components.altair as components_altair +import mesa.visualization.components.matplotlib as components_matplotlib +from mesa.visualization.UserParam import Slider +from mesa.visualization.utils import force_update, update_counter + +if TYPE_CHECKING: + from mesa.model import Model + + +# TODO: Turn this function into a Solara component once the current_step.value +# dependency is passed to measure() +def Card( + model, measures, agent_portrayal, space_drawer, dependencies, color, layout_type +): + """ + Create a card component for visualizing model space or measures. + + Args: + model: The Mesa model instance + measures: List of measures to be plotted + agent_portrayal: Function to define agent appearance + space_drawer: Method to render agent space + dependencies: List of dependencies for updating the visualization + color: Background color of the card + layout_type: Type of layout (Space or Measure) + + Returns: + rv.Card: A card component containing the visualization + """ + with rv.Card( + style_=f"background-color: {color}; width: 100%; height: 100%" + ) as main: + if "Space" in layout_type: + rv.CardTitle(children=["Space"]) + if space_drawer == "default": + # draw with the default implementation + components_matplotlib.SpaceMatplotlib( + model, agent_portrayal, dependencies=dependencies + ) + elif space_drawer == "altair": + components_altair.SpaceAltair( + model, agent_portrayal, dependencies=dependencies + ) + elif space_drawer: + # if specified, draw agent space with an alternate renderer + space_drawer(model, agent_portrayal, dependencies=dependencies) + elif "Measure" in layout_type: + rv.CardTitle(children=["Measure"]) + measure = measures[layout_type["Measure"]] + if callable(measure): + # Is a custom object + measure(model) + else: + components_matplotlib.PlotMatplotlib( + model, measure, dependencies=dependencies + ) + return main + + +@solara.component +def SolaraViz( + model: "Model" | solara.Reactive["Model"], + components: list[solara.component] | None = None, + *args, + play_interval=150, + model_params=None, + seed=0, + name: str | None = None, +): + if components is None: + components = [] + + # Convert model to reactive + if not isinstance(model, solara.Reactive): + model = solara.use_reactive(model) + + def connect_to_model(): + # Patch the step function to force updates + original_step = model.value.step + + def step(): + original_step() + force_update() + + model.value.step = step + # Add a trigger to model itself + model.value.force_update = force_update + force_update() + + solara.use_effect(connect_to_model, [model.value]) + + with solara.AppBar(): + solara.AppBarTitle(name if name else model.value.__class__.__name__) + + with solara.Sidebar(): + with solara.Card("Controls", margin=1, elevation=2): + if model_params is not None: + ModelCreator( + model, + model_params, + seed=seed, + ) + ModelController(model, play_interval) + with solara.Card("Information", margin=1, elevation=2): + ShowSteps(model.value) + + solara.Column( + [ + *(component(model.value) for component in components), + ] + ) + + +JupyterViz = SolaraViz + + +@solara.component +def ModelController(model: solara.Reactive["Model"], play_interval): + """ + Create controls for model execution (step, play, pause, reset). + + Args: + model: The model being visualized + play_interval: Interval between steps during play + current_step: Reactive value for the current step + reset_counter: Counter to trigger model reset + """ + playing = solara.use_reactive(False) + thread = solara.use_reactive(None) + # We track the previous step to detect if user resets the model via + # clicking the reset button or changing the parameters. If previous_step > + # current_step, it means a model reset happens while the simulation is + # still playing. + previous_step = solara.use_reactive(0) + original_model = solara.use_reactive(None) + + def save_initial_model(): + """Save the initial model for comparison.""" + original_model.set(copy.deepcopy(model.value)) + + solara.use_effect(save_initial_model, [model.value]) + + def on_value_play(change): + """Handle play/pause state changes.""" + if previous_step.value > model.value.steps and model.value.steps == 0: + # We add extra checks for model.value.steps == 0, just to be sure. + # We automatically stop the playing if a model is reset. + playing.value = False + elif model.value.running: + do_step() + else: + playing.value = False + + def do_step(): + """Advance the model by one step.""" + previous_step.value = model.value.steps + model.value.step() + + def do_play(): + """Run the model continuously.""" + model.value.running = True + while model.value.running: + do_step() + + def threaded_do_play(): + """Start a new thread for continuous model execution.""" + if thread is not None and thread.is_alive(): + return + thread.value = threading.Thread(target=do_play) + thread.start() + + def do_pause(): + """Pause the model execution.""" + if (thread is None) or (not thread.is_alive()): + return + model.value.running = False + thread.join() + + def do_reset(): + """Reset the model""" + model.value = copy.deepcopy(original_model.value) + previous_step.value = 0 + force_update() + + def do_set_playing(value): + """Set the playing state.""" + if model.value.steps == 0: + # This means the model has been recreated, and the step resets to + # 0. We want to avoid triggering the playing.value = False in the + # on_value_play function. + previous_step.value = model.value.steps + playing.set(value) + + with solara.Row(): + solara.Button(label="Step", color="primary", on_click=do_step) + # This style is necessary so that the play widget has almost the same + # height as typical Solara buttons. + solara.Style( + """ + .widget-play { + height: 35px; + } + .widget-play button { + color: white; + background-color: #1976D2; // Solara blue color + } + """ + ) + widgets.Play( + value=0, + interval=play_interval, + repeat=True, + show_repeat=False, + on_value=on_value_play, + playing=playing.value, + on_playing=do_set_playing, + ) + solara.Button(label="Reset", color="primary", on_click=do_reset) + # threaded_do_play is not used for now because it + # doesn't work in Google colab. We use + # ipywidgets.Play until it is fixed. The threading + # version is definite a much better implementation, + # if it works. + # solara.Button(label="▶", color="primary", on_click=viz.threaded_do_play) + # solara.Button(label="⏸︎", color="primary", on_click=viz.do_pause) + # solara.Button(label="Reset", color="primary", on_click=do_reset) + + +def split_model_params(model_params): + """ + Split model parameters into user-adjustable and fixed parameters. + + Args: + model_params: Dictionary of all model parameters + + Returns: + tuple: (user_adjustable_params, fixed_params) + """ + model_params_input = {} + model_params_fixed = {} + for k, v in model_params.items(): + if check_param_is_fixed(v): + model_params_fixed[k] = v + else: + model_params_input[k] = v + return model_params_input, model_params_fixed + + +def check_param_is_fixed(param): + """ + Check if a parameter is fixed (not user-adjustable). + + Args: + param: Parameter to check + + Returns: + bool: True if parameter is fixed, False otherwise + """ + if isinstance(param, Slider): + return False + if not isinstance(param, dict): + return True + if "type" not in param: + return True + + +@solara.component +def ModelCreator(model, model_params, seed=1): + user_params, fixed_params = split_model_params(model_params) + + reactive_seed = solara.use_reactive(seed) + + model_parameters, set_model_parameters = solara.use_state( + { + **fixed_params, + **{k: v.get("value") for k, v in user_params.items()}, + } + ) + + def do_reseed(): + """Update the random seed for the model.""" + reactive_seed.value = model.value.random.random() + + def on_change(name, value): + set_model_parameters({**model_parameters, name: value}) + + def create_model(): + model.value = model.value.__class__.__new__( + model.value.__class__, **model_parameters, seed=reactive_seed.value + ) + model.value.__init__(**model_parameters) + + solara.use_effect(create_model, [model_parameters, reactive_seed.value]) + + solara.InputText( + label="Seed", + value=reactive_seed, + continuous_update=True, + ) + + solara.Button(label="Reseed", color="primary", on_click=do_reseed) + + UserInputs(user_params, on_change=on_change) + + +@solara.component +def UserInputs(user_params, on_change=None): + """ + Initialize user inputs for configurable model parameters. + Currently supports :class:`solara.SliderInt`, :class:`solara.SliderFloat`, + :class:`solara.Select`, and :class:`solara.Checkbox`. + + Args: + user_params: Dictionary with options for the input, including label, + min and max values, and other fields specific to the input type. + on_change: Function to be called with (name, value) when the value of an input changes. + """ + + for name, options in user_params.items(): + + def change_handler(value, name=name): + on_change(name, value) + + if isinstance(options, Slider): + slider_class = ( + solara.SliderFloat if options.is_float_slider else solara.SliderInt + ) + slider_class( + options.label, + value=options.value, + on_value=change_handler, + min=options.min, + max=options.max, + step=options.step, + ) + continue + + # label for the input is "label" from options or name + label = options.get("label", name) + input_type = options.get("type") + if input_type == "SliderInt": + solara.SliderInt( + label, + value=options.get("value"), + on_value=change_handler, + min=options.get("min"), + max=options.get("max"), + step=options.get("step"), + ) + elif input_type == "SliderFloat": + solara.SliderFloat( + label, + value=options.get("value"), + on_value=change_handler, + min=options.get("min"), + max=options.get("max"), + step=options.get("step"), + ) + elif input_type == "Select": + solara.Select( + label, + value=options.get("value"), + on_value=change_handler, + values=options.get("values"), + ) + elif input_type == "Checkbox": + solara.Checkbox( + label=label, + on_value=change_handler, + value=options.get("value"), + ) + else: + raise ValueError(f"{input_type} is not a supported input type") + + +def make_text(renderer): + """ + Create a function that renders text using Markdown. + + Args: + renderer: Function that takes a model and returns a string + + Returns: + function: A function that renders the text as Markdown + """ + + def function(model): + solara.Markdown(renderer(model)) + + return function + + +def make_initial_grid_layout(layout_types): + """ + Create an initial grid layout for visualization components. + + Args: + layout_types: List of layout types (Space or Measure) + + Returns: + list: Initial grid layout configuration + """ + return [ + { + "i": i, + "w": 6, + "h": 10, + "moved": False, + "x": 6 * (i % 2), + "y": 16 * (i - i % 2), + } + for i in range(len(layout_types)) + ] + + +@solara.component +def ShowSteps(model): + update_counter.get() + return solara.Text(f"Step: {model.steps}") diff --git a/mesa/visualization/utils.py b/mesa/visualization/utils.py new file mode 100644 index 00000000000..c49b35e3664 --- /dev/null +++ b/mesa/visualization/utils.py @@ -0,0 +1,7 @@ +import solara + +update_counter = solara.reactive(0) + + +def force_update(): + update_counter.value += 1 diff --git a/tests/test_solara_viz.py b/tests/test_solara_viz.py index 660277b3d7e..3ff8164065e 100644 --- a/tests/test_solara_viz.py +++ b/tests/test_solara_viz.py @@ -5,6 +5,7 @@ import solara import mesa +from mesa.visualization.components.matplotlib import make_space_matplotlib from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs @@ -95,50 +96,22 @@ def test_call_space_drawer(mocker): "Shape": "circle", "color": "gray", } - current_step = 0 - seed = 0 - dependencies = [current_step, seed] # initialize with space drawer unspecified (use default) # component must be rendered for code to run - solara.render( - SolaraViz( - model_class=mesa.Model, - model_params={}, - agent_portrayal=agent_portrayal, - ) - ) + solara.render(SolaraViz(model, components=[make_space_matplotlib(agent_portrayal)])) # should call default method with class instance and agent portrayal - mock_space_matplotlib.assert_called_with( - model, agent_portrayal, dependencies=dependencies - ) + mock_space_matplotlib.assert_called_with(model, agent_portrayal) - # specify no space should be drawn; any false value should work - for falsy_value in [None, False, 0]: - mock_space_matplotlib.reset_mock() - solara.render( - SolaraViz( - model_class=mesa.Model, - model_params={}, - agent_portrayal=agent_portrayal, - space_drawer=falsy_value, - ) - ) - # should call default method with class instance and agent portrayal - assert mock_space_matplotlib.call_count == 0 + # specify no space should be drawn + mock_space_matplotlib.reset_mock() + solara.render(SolaraViz(model, components=[])) + # should call default method with class instance and agent portrayal + assert mock_space_matplotlib.call_count == 0 # specify a custom space method altspace_drawer = Mock() - solara.render( - SolaraViz( - model_class=mesa.Model, - model_params={}, - agent_portrayal=agent_portrayal, - space_drawer=altspace_drawer, - ) - ) - altspace_drawer.assert_called_with( - model, agent_portrayal, dependencies=dependencies - ) + solara.render(SolaraViz(model, components=[altspace_drawer])) + altspace_drawer.assert_called_with(model) # check voronoi space drawer voronoi_model = mesa.Model() @@ -146,9 +119,7 @@ def test_call_space_drawer(mocker): centroids_coordinates=[(0, 1), (0, 0), (1, 0)], ) solara.render( - SolaraViz( - model_class=voronoi_model, model_params={}, agent_portrayal=agent_portrayal - ) + SolaraViz(voronoi_model, components=[make_space_matplotlib(agent_portrayal)]) )