Skip to content

Commit

Permalink
Bugfix for Solara deepcopy bug (#2460)
Browse files Browse the repository at this point in the history
closes #2427

adds textinput field as an option.
  • Loading branch information
quaquel authored Nov 6, 2024
1 parent 0c6dc35 commit ef383c4
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 45 deletions.
9 changes: 7 additions & 2 deletions mesa/examples/basic/boltzmann_wealth_model/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ def agent_portrayal(agent):
"max": 100,
"step": 1,
},
"seed": {
"type": "InputText",
"value": 42,
"label": "Random Seed",
},
"width": 10,
"height": 10,
}
Expand All @@ -30,7 +35,7 @@ def post_process(ax):


# Create initial model instance
model1 = BoltzmannWealthModel(50, 10, 10)
model = BoltzmannWealthModel(50, 10, 10)

# Create visualization elements. The visualization elements are solara components
# that receive the model instance as a "prop" and display it in a certain way.
Expand All @@ -49,7 +54,7 @@ def post_process(ax):
# solara run app.py
# It will automatically update and display any changes made to this file
page = SolaraViz(
model1,
model,
components=[SpaceGraph, GiniPlot],
model_params=model_params,
name="Boltzmann Wealth Model",
Expand Down
95 changes: 54 additions & 41 deletions mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from __future__ import annotations

import asyncio
import copy
import inspect
from collections.abc import Callable
from typing import TYPE_CHECKING, Literal
Expand All @@ -48,7 +47,6 @@ def SolaraViz(
| Literal["default"] = "default",
play_interval: int = 100,
model_params=None,
seed: float = 0,
name: str | None = None,
):
"""Solara visualization component.
Expand All @@ -69,8 +67,6 @@ def SolaraViz(
This controls the speed of the model's automatic stepping. Defaults to 100 ms.
model_params (dict, optional): Parameters for (re-)instantiating a model.
Can include user-adjustable parameters and fixed parameters. Defaults to None.
seed (int, optional): Seed for the random number generator. This ensures reproducibility
of the model's behavior. Defaults to 0.
name (str | None, optional): Name of the visualization. Defaults to the models class name.
Returns:
Expand All @@ -88,7 +84,9 @@ def SolaraViz(
value results in faster stepping, while a higher value results in slower stepping.
"""
if components == "default":
components = [components_altair.make_space_altair()]
components = [components_altair.make_altair_space()]
if model_params is None:
model_params = {}

# Convert model to reactive
if not isinstance(model, solara.Reactive):
Expand All @@ -109,20 +107,23 @@ def step():

solara.use_effect(connect_to_model, [model.value])

# set up reactive model_parameters shared by ModelCreator and ModelController
reactive_model_parameters = solara.use_reactive({})

with solara.AppBar():
solara.AppBarTitle(name if name else model.value.__class__.__name__)

with solara.Sidebar(), solara.Column():
with solara.Card("Controls"):
ModelController(model, play_interval)

if model_params is not None:
with solara.Card("Model Parameters"):
ModelCreator(
model,
model_params,
seed=seed,
)
ModelController(
model,
model_parameters=reactive_model_parameters,
play_interval=play_interval,
)
with solara.Card("Model Parameters"):
ModelCreator(
model, model_params, model_parameters=reactive_model_parameters
)
with solara.Card("Information"):
ShowSteps(model.value)

Expand Down Expand Up @@ -173,24 +174,24 @@ def ComponentsView(


@solara.component
def ModelController(model: solara.Reactive[Model], play_interval=100):
def ModelController(
model: solara.Reactive[Model],
*,
model_parameters: dict | solara.Reactive[dict] = None,
play_interval: int = 100,
):
"""Create controls for model execution (step, play, pause, reset).
Args:
model (solara.Reactive[Model]): Reactive model instance
play_interval (int, optional): Interval for playing the model steps in milliseconds.
model: Reactive model instance
model_parameters: Reactive parameters for (re-)instantiating a model.
play_interval: Interval for playing the model steps in milliseconds.
"""
playing = solara.use_reactive(False)
running = solara.use_reactive(True)
original_model = solara.use_reactive(None)

def save_initial_model():
"""Save the initial model for comparison."""
original_model.set(copy.deepcopy(model.value))
playing.value = False
force_update()

solara.use_effect(save_initial_model, [model.value])
if model_parameters is None:
model_parameters = solara.use_reactive({})

async def step():
while playing.value and running.value:
Expand All @@ -210,7 +211,7 @@ def do_reset():
"""Reset the model to its initial state."""
playing.value = False
running.value = True
model.value = copy.deepcopy(original_model.value)
model.value = model.value = model.value.__class__(**model_parameters.value)

def do_play_pause():
"""Toggle play/pause."""
Expand Down Expand Up @@ -269,17 +270,22 @@ def check_param_is_fixed(param):


@solara.component
def ModelCreator(model, model_params, seed=1):
def ModelCreator(
model: solara.Reactive[Model],
user_params: dict,
*,
model_parameters: dict | solara.Reactive[dict] = None,
):
"""Solara component for creating and managing a model instance with user-defined parameters.
This component allows users to create a model instance with specified parameters and seed.
It provides an interface for adjusting model parameters and reseeding the model's random
number generator.
Args:
model (solara.Reactive[Model]): A reactive model instance. This is the main model to be created and managed.
model_params (dict): Dictionary of model parameters. This includes both user-adjustable parameters and fixed parameters.
seed (int, optional): Initial seed for the random number generator. Defaults to 1.
model: A reactive model instance. This is the main model to be created and managed.
user_params: Parameters for (re-)instantiating a model. Can include user-adjustable parameters and fixed parameters. Defaults to None.
model_parameters: reactive parameters for reinitializing the model
Returns:
solara.component: A Solara component that renders the model creation and management interface.
Expand All @@ -300,24 +306,25 @@ def ModelCreator(model, model_params, seed=1):
- The component provides an interface for adjusting user-defined parameters and reseeding the model.
"""
if model_parameters is None:
model_parameters = solara.use_reactive({})

solara.use_effect(
lambda: _check_model_params(model.value.__class__.__init__, fixed_params),
[model.value],
)
user_params, fixed_params = split_model_params(user_params)

user_params, fixed_params = split_model_params(model_params)

model_parameters, set_model_parameters = solara.use_state(
{
**fixed_params,
**{k: v.get("value") for k, v in user_params.items()},
}
)
# set model_parameters to the default values for all parameters
model_parameters.value = {
**fixed_params,
**{k: v.get("value") for k, v in user_params.items()},
}

def on_change(name, value):
new_model_parameters = {**model_parameters, name: value}
new_model_parameters = {**model_parameters.value, name: value}
model.value = model.value.__class__(**new_model_parameters)
set_model_parameters(new_model_parameters)
model_parameters.value = new_model_parameters

UserInputs(user_params, on_change=on_change)

Expand Down Expand Up @@ -409,6 +416,12 @@ def change_handler(value, name=name):
on_value=change_handler,
value=options.get("value"),
)
elif input_type == "InputText":
solara.InputText(
label=label,
on_value=change_handler,
value=options.get("value"),
)
else:
raise ValueError(f"{input_type} is not a supported input type")

Expand Down
11 changes: 9 additions & 2 deletions tests/test_solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ def test_call_space_drawer(mocker): # noqa: D103
mesa.visualization.components.altair_components, "SpaceAltair"
)

model = mesa.Model()
class MockModel(mesa.Model):
def __init__(self, seed=None):
super().__init__(seed=seed)

model = MockModel()
mocker.patch.object(mesa.Model, "__init__", return_value=None)

agent_portrayal = {
Expand All @@ -112,7 +116,10 @@ def test_call_space_drawer(mocker): # noqa: D103
# initialize with space drawer unspecified (use default)
# component must be rendered for code to run
solara.render(
SolaraViz(model, components=[make_mpl_space_component(agent_portrayal)])
SolaraViz(
model,
components=[make_mpl_space_component(agent_portrayal)],
)
)
# should call default method with class instance and agent portrayal
mock_space_matplotlib.assert_called_with(
Expand Down

0 comments on commit ef383c4

Please sign in to comment.