-
Notifications
You must be signed in to change notification settings - Fork 880
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add new solara viz API --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ewout ter Hoeven <[email protected]>
- Loading branch information
1 parent
d01d15d
commit 48065fd
Showing
7 changed files
with
872 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
class UserParam: | ||
_ERROR_MESSAGE = "Missing or malformed inputs for '{}' Option '{}'" | ||
|
||
def maybe_raise_error(self, param_type, valid): | ||
if valid: | ||
return | ||
msg = self._ERROR_MESSAGE.format(param_type, self.label) | ||
raise ValueError(msg) | ||
|
||
|
||
class Slider(UserParam): | ||
""" | ||
A number-based slider input with settable increment. | ||
Example: | ||
slider_option = Slider("My Slider", value=123, min=10, max=200, step=0.1) | ||
Args: | ||
label: The displayed label in the UI | ||
value: The initial value of the slider | ||
min: The minimum possible value of the slider | ||
max: The maximum possible value of the slider | ||
step: The step between min and max for a range of possible values | ||
dtype: either int or float | ||
""" | ||
|
||
def __init__( | ||
self, | ||
label="", | ||
value=None, | ||
min=None, | ||
max=None, | ||
step=1, | ||
dtype=None, | ||
): | ||
self.label = label | ||
self.value = value | ||
self.min = min | ||
self.max = max | ||
self.step = step | ||
|
||
# Validate option type to make sure values are supplied properly | ||
valid = not (self.value is None or self.min is None or self.max is None) | ||
self.maybe_raise_error("slider", valid) | ||
|
||
if dtype is None: | ||
self.is_float_slider = self._check_values_are_float(value, min, max, step) | ||
else: | ||
self.is_float_slider = dtype is float | ||
|
||
def _check_values_are_float(self, value, min, max, step): | ||
return any(isinstance(n, float) for n in (value, min, max, step)) | ||
|
||
def get(self, attr): | ||
return getattr(self, attr) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from .components.altair import make_space_altair | ||
from .components.matplotlib import make_plot_measure, make_space_matplotlib | ||
from .solara_viz import JupyterViz, SolaraViz, make_text | ||
from .UserParam import Slider | ||
|
||
__all__ = [ | ||
"JupyterViz", | ||
"SolaraViz", | ||
"make_text", | ||
"Slider", | ||
"make_space_altair", | ||
"make_space_matplotlib", | ||
"make_plot_measure", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import contextlib | ||
|
||
import solara | ||
|
||
with contextlib.suppress(ImportError): | ||
import altair as alt | ||
|
||
from mesa.visualization.utils import update_counter | ||
|
||
|
||
def make_space_altair(agent_portrayal=None): | ||
if agent_portrayal is None: | ||
|
||
def agent_portrayal(a): | ||
return {"id": a.unique_id} | ||
|
||
def MakeSpaceAltair(model): | ||
return SpaceAltair(model, agent_portrayal) | ||
|
||
return MakeSpaceAltair | ||
|
||
|
||
@solara.component | ||
def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): | ||
update_counter.get() | ||
space = getattr(model, "grid", None) | ||
if space is None: | ||
# Sometimes the space is defined as model.space instead of model.grid | ||
space = model.space | ||
chart = _draw_grid(space, agent_portrayal) | ||
solara.FigureAltair(chart) | ||
|
||
|
||
def _draw_grid(space, agent_portrayal): | ||
def portray(g): | ||
all_agent_data = [] | ||
for content, (x, y) in g.coord_iter(): | ||
if not content: | ||
continue | ||
if not hasattr(content, "__iter__"): | ||
# Is a single grid | ||
content = [content] # noqa: PLW2901 | ||
for agent in content: | ||
# use all data from agent portrayal, and add x,y coordinates | ||
agent_data = agent_portrayal(agent) | ||
agent_data["x"] = x | ||
agent_data["y"] = y | ||
all_agent_data.append(agent_data) | ||
return all_agent_data | ||
|
||
all_agent_data = portray(space) | ||
invalid_tooltips = ["color", "size", "x", "y"] | ||
|
||
encoding_dict = { | ||
# no x-axis label | ||
"x": alt.X("x", axis=None, type="ordinal"), | ||
# no y-axis label | ||
"y": alt.Y("y", axis=None, type="ordinal"), | ||
"tooltip": [ | ||
alt.Tooltip(key, type=alt.utils.infer_vegalite_type([value])) | ||
for key, value in all_agent_data[0].items() | ||
if key not in invalid_tooltips | ||
], | ||
} | ||
has_color = "color" in all_agent_data[0] | ||
if has_color: | ||
encoding_dict["color"] = alt.Color("color", type="nominal") | ||
has_size = "size" in all_agent_data[0] | ||
if has_size: | ||
encoding_dict["size"] = alt.Size("size", type="quantitative") | ||
|
||
chart = ( | ||
alt.Chart( | ||
alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict) | ||
) | ||
.mark_point(filled=True) | ||
.properties(width=280, height=280) | ||
# .configure_view(strokeOpacity=0) # hide grid/chart lines | ||
) | ||
# This is the default value for the marker size, which auto-scales | ||
# according to the grid area. | ||
if not has_size: | ||
length = min(space.width, space.height) | ||
chart = chart.mark_point(size=30000 / length**2, filled=True) | ||
|
||
return chart |
Oops, something went wrong.