Skip to content

Commit

Permalink
Solaraviz api (#2263)
Browse files Browse the repository at this point in the history
* add new solara viz API

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ewout ter Hoeven <[email protected]>
  • Loading branch information
3 people authored Sep 4, 2024
1 parent d01d15d commit 48065fd
Show file tree
Hide file tree
Showing 7 changed files with 872 additions and 40 deletions.
56 changes: 56 additions & 0 deletions mesa/visualization/UserParam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
class UserParam:
_ERROR_MESSAGE = "Missing or malformed inputs for '{}' Option '{}'"

def maybe_raise_error(self, param_type, valid):
if valid:
return
msg = self._ERROR_MESSAGE.format(param_type, self.label)
raise ValueError(msg)


class Slider(UserParam):
"""
A number-based slider input with settable increment.
Example:
slider_option = Slider("My Slider", value=123, min=10, max=200, step=0.1)
Args:
label: The displayed label in the UI
value: The initial value of the slider
min: The minimum possible value of the slider
max: The maximum possible value of the slider
step: The step between min and max for a range of possible values
dtype: either int or float
"""

def __init__(
self,
label="",
value=None,
min=None,
max=None,
step=1,
dtype=None,
):
self.label = label
self.value = value
self.min = min
self.max = max
self.step = step

# Validate option type to make sure values are supplied properly
valid = not (self.value is None or self.min is None or self.max is None)
self.maybe_raise_error("slider", valid)

if dtype is None:
self.is_float_slider = self._check_values_are_float(value, min, max, step)
else:
self.is_float_slider = dtype is float

def _check_values_are_float(self, value, min, max, step):
return any(isinstance(n, float) for n in (value, min, max, step))

def get(self, attr):
return getattr(self, attr)
14 changes: 14 additions & 0 deletions mesa/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .components.altair import make_space_altair
from .components.matplotlib import make_plot_measure, make_space_matplotlib
from .solara_viz import JupyterViz, SolaraViz, make_text
from .UserParam import Slider

__all__ = [
"JupyterViz",
"SolaraViz",
"make_text",
"Slider",
"make_space_altair",
"make_space_matplotlib",
"make_plot_measure",
]
86 changes: 86 additions & 0 deletions mesa/visualization/components/altair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import contextlib

import solara

with contextlib.suppress(ImportError):
import altair as alt

from mesa.visualization.utils import update_counter


def make_space_altair(agent_portrayal=None):
if agent_portrayal is None:

def agent_portrayal(a):
return {"id": a.unique_id}

def MakeSpaceAltair(model):
return SpaceAltair(model, agent_portrayal)

return MakeSpaceAltair


@solara.component
def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None):
update_counter.get()
space = getattr(model, "grid", None)
if space is None:
# Sometimes the space is defined as model.space instead of model.grid
space = model.space
chart = _draw_grid(space, agent_portrayal)
solara.FigureAltair(chart)


def _draw_grid(space, agent_portrayal):
def portray(g):
all_agent_data = []
for content, (x, y) in g.coord_iter():
if not content:
continue
if not hasattr(content, "__iter__"):
# Is a single grid
content = [content] # noqa: PLW2901
for agent in content:
# use all data from agent portrayal, and add x,y coordinates
agent_data = agent_portrayal(agent)
agent_data["x"] = x
agent_data["y"] = y
all_agent_data.append(agent_data)
return all_agent_data

all_agent_data = portray(space)
invalid_tooltips = ["color", "size", "x", "y"]

encoding_dict = {
# no x-axis label
"x": alt.X("x", axis=None, type="ordinal"),
# no y-axis label
"y": alt.Y("y", axis=None, type="ordinal"),
"tooltip": [
alt.Tooltip(key, type=alt.utils.infer_vegalite_type([value]))
for key, value in all_agent_data[0].items()
if key not in invalid_tooltips
],
}
has_color = "color" in all_agent_data[0]
if has_color:
encoding_dict["color"] = alt.Color("color", type="nominal")
has_size = "size" in all_agent_data[0]
if has_size:
encoding_dict["size"] = alt.Size("size", type="quantitative")

chart = (
alt.Chart(
alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict)
)
.mark_point(filled=True)
.properties(width=280, height=280)
# .configure_view(strokeOpacity=0) # hide grid/chart lines
)
# This is the default value for the marker size, which auto-scales
# according to the grid area.
if not has_size:
length = min(space.width, space.height)
chart = chart.mark_point(size=30000 / length**2, filled=True)

return chart
Loading

0 comments on commit 48065fd

Please sign in to comment.