Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make connections in experiment cell spaces named #2096

Closed
wants to merge 14 commits into from
2 changes: 2 additions & 0 deletions mesa/experimental/cell_space/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from mesa.experimental.cell_space.cell import Cell
from mesa.experimental.cell_space.cell_agent import CellAgent
from mesa.experimental.cell_space.cell_collection import CellCollection
from mesa.experimental.cell_space.connection import Connection
from mesa.experimental.cell_space.discrete_space import DiscreteSpace
from mesa.experimental.cell_space.grid import (
Grid,
Expand All @@ -14,6 +15,7 @@
"CellCollection",
"Cell",
"CellAgent",
"Connection",
"DiscreteSpace",
"Grid",
"HexGrid",
Expand Down
7 changes: 4 additions & 3 deletions mesa/experimental/cell_space/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING

from mesa.experimental.cell_space.cell_collection import CellCollection
from mesa.experimental.cell_space.connection import Connection

if TYPE_CHECKING:
from mesa.experimental.cell_space.cell_agent import CellAgent
Expand Down Expand Up @@ -56,20 +57,20 @@ def __init__(
"""
super().__init__()
self.coordinate = coordinate
self._connections: list[Cell] = [] # TODO: change to CellCollection?
self._connections: Connection = Connection() # TODO: change to CellCollection?
self.agents = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, )
self.capacity = capacity
self.properties: dict[str, object] = {}
self.random = random

def connect(self, other: Cell) -> None:
def connect(self, other: Cell, name: str | None = None) -> None:
"""Connects this cell to another cell.

Args:
other (Cell): other cell to connect to

"""
self._connections.append(other)
self._connections.append(other, name)

def disconnect(self, other: Cell) -> None:
"""Disconnects this cell from another cell.
Expand Down
124 changes: 124 additions & 0 deletions mesa/experimental/cell_space/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from mesa.experimental.cell_space.cell import Cell

Check warning on line 6 in mesa/experimental/cell_space/connection.py

View check run for this annotation

Codecov / codecov/patch

mesa/experimental/cell_space/connection.py#L6

Added line #L6 was not covered by tests

import mesa.experimental.cell_space.cell as _cell


class Connection:
"""An immutable collection of connections"""

def __init__(self) -> None:
self._naming: dict[str, int] = {}
self._reverse_naming: dict[int, str] = {}
self._connections: list[Cell] = []

def __len__(self) -> int:
return len(self._connections)

def __getitem__(self, key: str | int) -> Cell:
"""Get the specified cell in the list of connections based on id or key

Arg:
key (str or int): the specific name or id of the connection

"""
if isinstance(key, str):
try:
conn_id = self._naming[key]
return self._connections[conn_id]
except KeyError as e:
raise KeyError("The connection name is not found.") from e

if isinstance(key, int):
try:
return self._connections[key]
except IndexError as e:
raise IndexError("The connection id is out of range.") from e

raise TypeError(
f"The connection key must be either str or int, but {type(key)} is found."
)

def __contains__(self, key: Cell | str | int) -> Cell:
"""Get the specified cell in the list of connections based on id or key

Arg:
key (str or int): the specific name or id of the connection

"""
if isinstance(key, str):
return key in self._naming

if isinstance(key, int):
return key in range(len(self._connections))

if isinstance(key, _cell.Cell):
return key in self._connections

raise TypeError(
f"The connection key must be either Cell or str or int, but {type(key)} is found."
)

def append(self, cell: Cell, name: str | None = None) -> None:
"""Add the new connection to the list of connections with an optional name

Arg:
cell (Cell): the cell to add to the list of connections
name (str, optional): the name of the connection

"""
if name is None:
self._connections.append(cell)
return

if name in self._naming:
raise ValueError("The connection key has already existed!")

conn_idx = len(self._connections)
self._naming[name] = conn_idx
self._reverse_naming[conn_idx] = name
self._connections.append(cell)

def remove(self, cell: Cell | str | int) -> None:
"""Remove a connection from the list of connections

Arg:
cell (Cell or str or int): the cell to add to the list of connections, it can be either a Cell object or a connection name or a connection index

"""
if isinstance(cell, _cell.Cell):
conn_idx = self._connections.index(cell)
self._connections.pop(conn_idx)
if conn_idx in self._reverse_naming:
conn_name = self._reverse_naming[conn_idx]
del self._naming[conn_name]
del self._reverse_naming[conn_idx]
return

if isinstance(cell, str):
if cell not in self._naming:
raise ValueError("The connection same does not exist!")

conn_idx = self._naming[cell]
self._connections.pop(conn_idx)
del self._naming[cell]
del self._reverse_naming[conn_idx]
return

if isinstance(cell, int):
if cell not in range(len(self._connections)):
raise ValueError("Connection index out of range!")
self._connections.pop(cell)
if cell in self._reverse_naming:
conn_name = self._reverse_naming[cell]
del self._naming[conn_name]
del self._reverse_naming[cell]
return

raise TypeError(
f"The argument must be either Cell or str or int, but {type(cell)} is found."
)
39 changes: 33 additions & 6 deletions mesa/experimental/cell_space/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,22 @@ def _connect_single_cell_nd(self, cell: T, offsets: list[tuple[int, ...]]) -> No
if all(0 <= nc < d for nc, d in zip(n_coord, self.dimensions)):
cell.connect(self._cells[n_coord])

def _connect_single_cell_2d(self, cell: T, offsets: list[tuple[int, int]]) -> None:
def _connect_single_cell_2d(
self,
cell: T,
offsets: list[tuple[int, int]],
conn_names: list[str] | None = None,
) -> None:
i, j = cell.coordinate
height, width = self.dimensions

for di, dj in offsets:
for idx, (di, dj) in enumerate(offsets):
name = conn_names[idx] if conn_names is not None else None
ni, nj = (i + di, j + dj)
if self.torus:
ni, nj = ni % height, nj % width
if 0 <= ni < height and 0 <= nj < width:
cell.connect(self._cells[ni, nj])
cell.connect(self._cells[ni, nj], name)


class OrthogonalMooreGrid(Grid[T]):
Expand All @@ -121,11 +127,16 @@ def _connect_cells_2d(self) -> None:
( 0, -1), ( 0, 1),
( 1, -1), ( 1, 0), ( 1, 1),
]
conn_names = [
"top left", "top", "top right",
"left", "right",
"bottom left", "bottom", "bottom right",
]
# fmt: on
height, width = self.dimensions

for cell in self.all_cells:
self._connect_single_cell_2d(cell, offsets)
self._connect_single_cell_2d(cell, offsets, conn_names)

def _connect_cells_nd(self) -> None:
offsets = list(product([-1, 0, 1], repeat=len(self.dimensions)))
Expand Down Expand Up @@ -153,11 +164,16 @@ def _connect_cells_2d(self) -> None:
(0, -1), (0, 1),
( 1, 0),
]
conn_names = [
"top",
"left", "right",
"bottom",
]
# fmt: on
height, width = self.dimensions

for cell in self.all_cells:
self._connect_single_cell_2d(cell, offsets)
self._connect_single_cell_2d(cell, offsets, conn_names)

def _connect_cells_nd(self) -> None:
offsets = []
Expand All @@ -183,17 +199,28 @@ def _connect_cells_2d(self) -> None:
( 0, -1), ( 0, 1),
( 1, -1), ( 1, 0),
]
even_names = [
"top left", "top",
"left", "right",
"bottom left", "bottom",
]
odd_offsets = [
(-1, 0), (-1, 1),
( 0, -1), ( 0, 1),
( 1, 0), ( 1, 1),
]
odd_names = [
"top", "top right",
"left", "right",
"bottom", "bottom right",
]
# fmt: on

for cell in self.all_cells:
i = cell.coordinate[0]
offsets = even_offsets if i % 2 == 0 else odd_offsets
self._connect_single_cell_2d(cell, offsets=offsets)
names = even_names if i % 2 == 0 else odd_names
self._connect_single_cell_2d(cell, offsets=offsets, conn_names=names)

def _connect_cells_nd(self) -> None:
raise NotImplementedError("HexGrids are only defined for 2 dimensions")
Expand Down
10 changes: 10 additions & 0 deletions tests/test_cell_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,16 @@ def test_orthogonal_grid_moore():
(1, 9), (1, 0), (1, 1)}
# fmt: on

# Traverse diagonally using names
current_cell = (9, 9)
assert "top left" in grid._cells[current_cell]._connections
while current_cell[0] > 0 and current_cell[1] > 0:
next_cell = grid._cells[current_cell]._connections["top left"].coordinate
# fmt: off
assert next_cell[0] == current_cell[0] - 1 and next_cell[1] == current_cell[1] - 1
# fmt: on
current_cell = next_cell


def test_orthogonal_grid_moore_3d():
width = 10
Expand Down
93 changes: 93 additions & 0 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import pytest

from mesa.experimental.cell_space import (
Cell,
)


def test_get_connection():
cell = Cell(coordinate=(0, 0))
right_cell = Cell(coordinate=(1, 0))
left_cell = Cell(coordinate=(-1, 0))

cell.connect(right_cell, "right")

assert cell._connections["right"].coordinate == right_cell.coordinate
assert cell._connections[0].coordinate == right_cell.coordinate

with pytest.raises(KeyError):
left = cell._connections["left"]

with pytest.raises(TypeError):
obj = {
"name": "right",
}
random_cell = cell._connections[obj]


def test_in_connection():
cell = Cell(coordinate=(0, 0))
right_cell = Cell(coordinate=(1, 0))
left_cell = Cell(coordinate=(-1, 0))

cell.connect(right_cell, "right")

assert right_cell in cell._connections
assert left_cell not in cell._connections

assert 0 in cell._connections
assert 2 not in cell._connections

with pytest.raises(TypeError):
obj = {
"name": "right",
}
is_in = obj in cell._connections


def test_add_connection():
cell = Cell(coordinate=(0, 0))
right_cell = Cell(coordinate=(1, 0))
left_cell = Cell(coordinate=(-1, 0))

cell.connect(right_cell, "right")

# Raises exception when adding duplicated name
with pytest.raises(ValueError):
cell.connect(left_cell, "right")


def test_remove_connection():
cell = Cell(coordinate=(0, 0))
right_cell = Cell(coordinate=(1, 0))
left_cell = Cell(coordinate=(-1, 0))
bot_cell = Cell(coordinate=(0, -1))

cell.connect(right_cell, "right")
cell.connect(left_cell, "left")
cell.connect(bot_cell, "bottom")

assert 2 in cell._connections
cell._connections.remove(2)
assert 2 not in cell._connections

assert "left" in cell._connections
cell._connections.remove(left_cell)
assert "left" not in cell._connections

assert "right" in cell._connections
cell._connections.remove("right")
assert "right" not in cell._connections

# Raises exception when removing non-exsistent name
with pytest.raises(ValueError):
cell._connections.remove("top")

with pytest.raises(ValueError):
cell._connections.remove(123)

with pytest.raises(TypeError):
obj = {
"name": "right",
}
cell._connections.remove(obj)