Skip to content

Commit

Permalink
Visualisation: Allow specifying Agent shapes in agent_portrayal (#2214)
Browse files Browse the repository at this point in the history
This PR allows specifying an `"shape"` in the `agent_portrayal` dictionary used by matplotlib component of the Solara visualisation. In short, it allows you represent an Agent in any [matplotlib marker](https://matplotlib.org/stable/api/markers_api.html), by adding a "shape" key-value pair to the `agent_portrayal` dictionary.

This is especially useful when you're using the default shape drawer for grid or continuous space.

For example:
```Python
def agent_portrayal(cell):
    return {
        "color": "blue"
        "size": 5,
        "shape": "h"  # marker is a hexagon!
    }
```
  • Loading branch information
rmhopkins4 authored Aug 21, 2024
1 parent 3cf1b76 commit 3ca9098
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 25 deletions.
3 changes: 2 additions & 1 deletion docs/tutorials/visualization_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@
"source": [
"#### Changing the agents\n",
"\n",
"In the visualization above, all we could see is the agents moving around -- but not how much money they had, or anything else of interest. Let's change it so that agents who are broke (wealth 0) are drawn in red, smaller. (TODO: currently, we can't predict the drawing order of the circles, so a broke agent may be overshadowed by a wealthy agent. We should fix this by doing a hollow circle instead)\n",
"In the visualization above, all we could see is the agents moving around -- but not how much money they had, or anything else of interest. Let's change it so that agents who are broke (wealth 0) are drawn in red, smaller. (TODO: Currently, we can't predict the drawing order of the circles, so a broke agent may be overshadowed by a wealthy agent. We should fix this by doing a hollow circle instead)\n",
"In addition to size and color, an agent's shape can also be customized when using the default drawer. The allowed values for shapes can be found [here](https://matplotlib.org/stable/api/markers_api.html).\n",
"\n",
"To do this, we go back to our `agent_portrayal` code and add some code to change the portrayal based on the agent properties and launch the server again."
]
Expand Down
83 changes: 60 additions & 23 deletions mesa/visualization/components/matplotlib.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import defaultdict

import networkx as nx
import solara
from matplotlib.figure import Figure
Expand All @@ -23,12 +25,44 @@ def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = Non
solara.FigureMatplotlib(space_fig, format="png", dependencies=dependencies)


# matplotlib scatter does not allow for multiple shapes in one call
def _split_and_scatter(portray_data, space_ax):
grouped_data = defaultdict(lambda: {"x": [], "y": [], "s": [], "c": []})

# Extract data from the dictionary
x = portray_data["x"]
y = portray_data["y"]
s = portray_data["s"]
c = portray_data["c"]
m = portray_data["m"]

if not (len(x) == len(y) == len(s) == len(c) == len(m)):
raise ValueError(
"Length mismatch in portrayal data lists: "
f"x: {len(x)}, y: {len(y)}, size: {len(s)}, "
f"color: {len(c)}, marker: {len(m)}"
)

# Group the data by marker
for i in range(len(x)):
marker = m[i]
grouped_data[marker]["x"].append(x[i])
grouped_data[marker]["y"].append(y[i])
grouped_data[marker]["s"].append(s[i])
grouped_data[marker]["c"].append(c[i])

# Plot each group with the same marker
for marker, data in grouped_data.items():
space_ax.scatter(data["x"], data["y"], s=data["s"], c=data["c"], marker=marker)


def _draw_grid(space, space_ax, agent_portrayal):
def portray(g):
x = []
y = []
s = [] # size
c = [] # color
m = [] # shape
for i in range(g.width):
for j in range(g.height):
content = g._grid[i][j]
Expand All @@ -41,23 +75,23 @@ def portray(g):
data = agent_portrayal(agent)
x.append(i)
y.append(j)
if "size" in data:
s.append(data["size"])
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}
# This is the default value for the marker size, which auto-scales
# according to the grid area.
out["s"] = (180 / max(g.width, g.height)) ** 2
if len(s) > 0:
out["s"] = s
if len(c) > 0:
out["c"] = c

# This is the default value for the marker size, which auto-scales
# according to the grid area.
default_size = (180 / max(g.width, g.height)) ** 2
# establishing a default prevents misalignment if some agents are not given size, color, etc.
size = data.get("size", default_size)
s.append(size)
color = data.get("color", "tab:blue")
c.append(color)
mark = data.get("shape", "o")
m.append(mark)
out = {"x": x, "y": y, "s": s, "c": c, "m": m}
return out

space_ax.set_xlim(-1, space.width)
space_ax.set_ylim(-1, space.height)
space_ax.scatter(**portray(space))
_split_and_scatter(portray(space), space_ax)


def _draw_network_grid(space, space_ax, agent_portrayal):
Expand All @@ -77,20 +111,23 @@ def portray(space):
y = []
s = [] # size
c = [] # color
m = [] # shape
for agent in space._agent_to_index:
data = agent_portrayal(agent)
_x, _y = agent.pos
x.append(_x)
y.append(_y)
if "size" in data:
s.append(data["size"])
if "color" in data:
c.append(data["color"])
out = {"x": x, "y": y}
if len(s) > 0:
out["s"] = s
if len(c) > 0:
out["c"] = c

# This is matplotlib's default marker size
default_size = 20
# establishing a default prevents misalignment if some agents are not given size, color, etc.
size = data.get("size", default_size)
s.append(size)
color = data.get("color", "tab:blue")
c.append(color)
mark = data.get("shape", "o")
m.append(mark)
out = {"x": x, "y": y, "s": s, "c": c, "m": m}
return out

# Determine border style based on space.torus
Expand All @@ -110,7 +147,7 @@ def portray(space):
space_ax.set_ylim(space.y_min - y_padding, space.y_max + y_padding)

# Portray and scatter the agents in the space
space_ax.scatter(**portray(space))
_split_and_scatter(portray(space), space_ax)


@solara.component
Expand Down
3 changes: 2 additions & 1 deletion mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def SolaraViz(
model_params: Parameters for initializing the model
measures: List of callables or data attributes to plot
name: Name for display
agent_portrayal: Options for rendering agents (dictionary)
agent_portrayal: Options for rendering agents (dictionary);
Default drawer supports custom `"size"`, `"color"`, and `"shape"`.
space_drawer: Method to render the agent space for
the model; default implementation is the `SpaceMatplotlib` component;
simulations with no space to visualize should
Expand Down

0 comments on commit 3ca9098

Please sign in to comment.