Skip to content

Commit

Permalink
small update to CellCollect.select
Browse files Browse the repository at this point in the history
  • Loading branch information
quaquel committed Sep 21, 2024
1 parent e6874ad commit e9572a6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 16 deletions.
4 changes: 2 additions & 2 deletions mesa/experimental/cell_space/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Cell:
def __init__(
self,
coordinate: Coordinate,
capacity: float | None = None,
capacity: int | None = None,
random: Random | None = None,
) -> None:
"""Initialise the cell.
Expand All @@ -65,7 +65,7 @@ def __init__(
self.agents: list[
Agent
] = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, )
self.capacity = capacity
self.capacity: int = capacity
self.properties: dict[Coordinate, object] = {}
self.random = random

Expand Down
36 changes: 22 additions & 14 deletions mesa/experimental/cell_space/cell_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class CellCollection(Generic[T]):
"""

def __init__(
self,
cells: Mapping[T, list[CellAgent]] | Iterable[T],
random: Random | None = None,
self,
cells: Mapping[T, list[CellAgent]] | Iterable[T],
random: Random | None = None,
) -> None:
"""Initialize a CellCollection.
Expand Down Expand Up @@ -83,25 +83,33 @@ def select_random_agent(self) -> CellAgent:
"""
return self.random.choice(list(self.agents))

def select(self, filter_func: Callable[[T], bool] | None = None, n=0):

def select(self, filter_func: Callable[[T], bool] | None = None, at_most: int | float = float("inf"), ):
"""Select cells based on filter function.
Args:
filter_func: filter function
n: number of cells to select
at_most: The maximum amount of cells to select. Defaults to infinity.
- If an integer, at most the first number of matching cells is selected.
- If a float between 0 and 1, at most that fraction of original number of cells
Returns:
CellCollection
"""
# FIXME: n is not considered
if filter_func is None and n == 0:
if filter_func is None and at_most == float("inf"):
return self

return CellCollection(
{
cell: agents
for cell, agents in self._cells.items()
if filter_func is None or filter_func(cell)
}
)
if at_most <= 1.0 and isinstance(at_most, float):
at_most = int(len(self) * at_most) # Note that it rounds down (floor)

def cell_generator(filter_func, at_most):
count = 0
for cell in self:
if count >= at_most:
break
if (not filter_func or filter_func(cell)):
yield cell
count += 1

return CellCollection(cell_generator(filter_func, at_most))

0 comments on commit e9572a6

Please sign in to comment.