Skip to content

Commit

Permalink
Add better docstrings and improve layout of solara viz
Browse files Browse the repository at this point in the history
  • Loading branch information
Corvince authored Sep 19, 2024
1 parent 696f123 commit b9088dd
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 75 deletions.
3 changes: 1 addition & 2 deletions mesa/visualization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

from .components.altair import make_space_altair
from .components.matplotlib import make_plot_measure, make_space_matplotlib
from .solara_viz import JupyterViz, SolaraViz, make_text
from .solara_viz import JupyterViz, SolaraViz
from .UserParam import Slider

__all__ = [
"JupyterViz",
"SolaraViz",
"make_text",
"Slider",
"make_space_altair",
"make_space_matplotlib",
Expand Down
200 changes: 133 additions & 67 deletions mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
- SolaraViz: Main component for creating visualizations, supporting grid displays and plots
- ModelController: Handles model execution controls (step, play, pause, reset)
- UserInputs: Generates UI elements for adjusting model parameters
- Card: Renders individual visualization elements (space, measures)
The module uses Solara for rendering in Jupyter notebooks or as standalone web applications.
It supports various types of visualizations including matplotlib plots, agent grids, and
Expand All @@ -22,10 +21,14 @@
See the Visualization Tutorial and example models for more details.
"""

from __future__ import annotations

import copy
import time
from collections.abc import Callable
from typing import TYPE_CHECKING, Literal

import reacton.core
import solara
from solara.alias import rv

Expand Down Expand Up @@ -89,31 +92,57 @@ def Card(

@solara.component
def SolaraViz(
model: "Model" | solara.Reactive["Model"],
components: list[solara.component] | Literal["default"] = "default",
play_interval=100,
model: Model | solara.Reactive[Model],
components: list[reacton.core.Component]
| list[Callable[[Model], reacton.core.Component]]
| Literal["default"] = "default",
play_interval: int = 100,
model_params=None,
seed=0,
seed: float = 0,
name: str | None = None,
):
"""Solara visualization component.
This component provides a visualization interface for a given model using Solara.
It supports various visualization components and allows for interactive model
stepping and parameter adjustments.
Args:
model: a Model instance
components: list of solara components
play_interval: int
model_params: parameters for instantiating a model
seed: the seed for the rng
name: str
model (Model | solara.Reactive[Model]): A Model instance or a reactive Model.
This is the main model to be visualized. If a non-reactive model is provided,
it will be converted to a reactive model.
components (list[solara.component] | Literal["default"], optional): List of solara
components or functions that return a solara component.
These components are used to render different parts of the model visualization.
Defaults to "default", which uses the default Altair space visualization.
play_interval (int, optional): Interval for playing the model steps in milliseconds.
This controls the speed of the model's automatic stepping. Defaults to 100 ms.
model_params (dict, optional): Parameters for (re-)instantiating a model.
Can include user-adjustable parameters and fixed parameters. Defaults to None.
seed (int, optional): Seed for the random number generator. This ensures reproducibility
of the model's behavior. Defaults to 0.
name (str | None, optional): Name of the visualization. Defaults to the models class name.
Returns:
solara.component: A Solara component that renders the visualization interface for the model.
Example:
>>> model = MyModel()
>>> page = SolaraViz(model)
>>> page
Notes:
- The `model` argument can be either a direct model instance or a reactive model. If a direct
model instance is provided, it will be converted to a reactive model using `solara.use_reactive`.
- The `play_interval` argument controls the speed of the model's automatic stepping. A lower
value results in faster stepping, while a higher value results in slower stepping.
"""
update_counter.get()
if components == "default":
components = [components_altair.make_space_altair()]

# Convert model to reactive
if not isinstance(model, solara.Reactive):
model = solara.use_reactive(model)
model = solara.use_reactive(model) # noqa: SH102, RUF100

def connect_to_model():
# Patch the step function to force updates
Expand All @@ -133,39 +162,68 @@ def step():
with solara.AppBar():
solara.AppBarTitle(name if name else model.value.__class__.__name__)

with solara.Sidebar():
with solara.Card("Controls", margin=1, elevation=2):
if model_params is not None:
with solara.Sidebar(), solara.Column():
with solara.Card("Controls"):
ModelController(model, play_interval)

if model_params is not None:
with solara.Card("Model Parameters"):
ModelCreator(
model,
model_params,
seed=seed,
)
ModelController(model, play_interval)
with solara.Card("Information", margin=1, elevation=2):
with solara.Card("Information"):
ShowSteps(model.value)

solara.Column(
[
*(component(model.value) for component in components),
]
)
ComponentsView(components, model.value)


def _wrap_component(
component: reacton.core.Component | Callable[[Model], reacton.core.Component],
) -> reacton.core.Component:
"""Wrap a component in an auto-updated Solara component if needed."""
if isinstance(component, reacton.core.Component):
return component

@solara.component
def WrappedComponent(model):
update_counter.get()
return component(model)

return WrappedComponent


@solara.component
def ComponentsView(
components: list[reacton.core.Component]
| list[Callable[[Model], reacton.core.Component]],
model: Model,
):
"""Display a list of components.
Args:
components: List of components to display
model: Model instance to pass to each component
"""
wrapped_components = [_wrap_component(component) for component in components]

with solara.Column():
for component in wrapped_components:
component(model)


JupyterViz = SolaraViz


@solara.component
def ModelController(model: solara.Reactive["Model"], play_interval=100):
def ModelController(model: solara.Reactive[Model], play_interval=100):
"""Create controls for model execution (step, play, pause, reset).
Args:
model: The reactive model being visualized
play_interval: Interval between steps during play
model (solara.Reactive[Model]): Reactive model instance
play_interval (int, optional): Interval for playing the model steps in milliseconds.
"""
if not isinstance(model, solara.Reactive):
model = solara.use_reactive(model)

playing = solara.use_reactive(False)
original_model = solara.use_reactive(None)

Expand All @@ -188,24 +246,25 @@ def do_step():
"""Advance the model by one step."""
model.value.step()

def do_play():
"""Run the model continuously."""
playing.value = True

def do_pause():
"""Pause the model execution."""
playing.value = False

def do_reset():
"""Reset the model to its initial state."""
playing.value = False
model.value = copy.deepcopy(original_model.value)

def do_play_pause():
"""Toggle play/pause."""
playing.value = not playing.value

with solara.Row(justify="space-between"):
solara.Button(label="Reset", color="primary", on_click=do_reset)
solara.Button(label="Step", color="primary", on_click=do_step)
solara.Button(label="▶", color="primary", on_click=do_play)
solara.Button(label="⏸︎", color="primary", on_click=do_pause)
solara.Button(
label="▶" if not playing.value else "❚❚",
color="primary",
on_click=do_play_pause,
)
solara.Button(
label="Step", color="primary", on_click=do_step, disabled=playing.value
)


def split_model_params(model_params):
Expand Down Expand Up @@ -246,13 +305,34 @@ def check_param_is_fixed(param):

@solara.component
def ModelCreator(model, model_params, seed=1):
"""Helper class to create a new Model instance.
"""Solara component for creating and managing a model instance with user-defined parameters.
This component allows users to create a model instance with specified parameters and seed.
It provides an interface for adjusting model parameters and reseeding the model's random
number generator.
Args:
model: model instance
model_params: model parameters
seed: the seed to use for the random number generator
model (solara.Reactive[Model]): A reactive model instance. This is the main model to be created and managed.
model_params (dict): Dictionary of model parameters. This includes both user-adjustable parameters and fixed parameters.
seed (int, optional): Initial seed for the random number generator. Defaults to 1.
Returns:
solara.component: A Solara component that renders the model creation and management interface.
Example:
>>> model = solara.reactive(MyModel())
>>> model_params = {
>>> "param1": {"type": "slider", "value": 10, "min": 0, "max": 100},
>>> "param2": {"type": "slider", "value": 5, "min": 1, "max": 10},
>>> }
>>> creator = ModelCreator(model, model_params)
>>> creator
Notes:
- The `model_params` argument should be a dictionary where keys are parameter names and values either fixed values
or are dictionaries containing parameter details such as type, value, min, and max.
- The `seed` argument ensures reproducibility by setting the initial seed for the model's random number generator.
- The component provides an interface for adjusting user-defined parameters and reseeding the model.
"""
user_params, fixed_params = split_model_params(model_params)
Expand All @@ -279,13 +359,14 @@ def create_model():

solara.use_effect(create_model, [model_parameters, reactive_seed.value])

solara.InputText(
label="Seed",
value=reactive_seed,
continuous_update=True,
)
with solara.Row(justify="space-between"):
solara.InputText(
label="Seed",
value=reactive_seed,
continuous_update=True,
)

solara.Button(label="Reseed", color="primary", on_click=do_reseed)
solara.Button(label="Reseed", color="primary", on_click=do_reseed)

UserInputs(user_params, on_change=on_change)

Expand Down Expand Up @@ -358,22 +439,6 @@ def change_handler(value, name=name):
raise ValueError(f"{input_type} is not a supported input type")


def make_text(renderer):
"""Create a function that renders text using Markdown.
Args:
renderer: Function that takes a model and returns a string
Returns:
function: A function that renders the text as Markdown
"""

def function(model):
solara.Markdown(renderer(model))

return function


def make_initial_grid_layout(layout_types):
"""Create an initial grid layout for visualization components.
Expand All @@ -397,6 +462,7 @@ def make_initial_grid_layout(layout_types):


@solara.component
def ShowSteps(model): # noqa: D103
def ShowSteps(model):
"""Display the current step of the model."""
update_counter.get()
return solara.Text(f"Step: {model.steps}")
21 changes: 15 additions & 6 deletions tests/test_solara_viz.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Test Solara visualizations."""

import unittest
from unittest.mock import Mock

import ipyvuetify as vw
import solara

import mesa
import mesa.visualization.components.altair
import mesa.visualization.components.matplotlib
from mesa.visualization.components.matplotlib import make_space_matplotlib
from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs

Expand Down Expand Up @@ -86,10 +87,12 @@ def Test(user_params):


def test_call_space_drawer(mocker): # noqa: D103
mock_space_matplotlib = mocker.patch(
"mesa.visualization.components.matplotlib.SpaceMatplotlib"
mock_space_matplotlib = mocker.spy(
mesa.visualization.components.matplotlib, "SpaceMatplotlib"
)

mock_space_altair = mocker.spy(mesa.visualization.components.altair, "SpaceAltair")

model = mesa.Model()
mocker.patch.object(mesa.Model, "__init__", return_value=None)

Expand All @@ -105,13 +108,19 @@ def test_call_space_drawer(mocker): # noqa: D103

# specify no space should be drawn
mock_space_matplotlib.reset_mock()
solara.render(SolaraViz(model, components=[]))
solara.render(SolaraViz(model))
# should call default method with class instance and agent portrayal
assert mock_space_matplotlib.call_count == 0
assert mock_space_altair.call_count > 0

# specify a custom space method
altspace_drawer = Mock()
solara.render(SolaraViz(model, components=[altspace_drawer]))
class AltSpace:
@staticmethod
def drawer(model):
return

altspace_drawer = mocker.spy(AltSpace, "drawer")
solara.render(SolaraViz(model, components=[AltSpace.drawer]))
altspace_drawer.assert_called_with(model)

# check voronoi space drawer
Expand Down

0 comments on commit b9088dd

Please sign in to comment.