From e9572a607f8343986d44c3a1c97f8dc833a3a074 Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Sat, 21 Sep 2024 10:16:58 +0200 Subject: [PATCH] small update to CellCollect.select --- mesa/experimental/cell_space/cell.py | 4 +-- .../cell_space/cell_collection.py | 36 +++++++++++-------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/mesa/experimental/cell_space/cell.py b/mesa/experimental/cell_space/cell.py index 381d0f6cccb..08a37102ebb 100644 --- a/mesa/experimental/cell_space/cell.py +++ b/mesa/experimental/cell_space/cell.py @@ -48,7 +48,7 @@ class Cell: def __init__( self, coordinate: Coordinate, - capacity: float | None = None, + capacity: int | None = None, random: Random | None = None, ) -> None: """Initialise the cell. @@ -65,7 +65,7 @@ def __init__( self.agents: list[ Agent ] = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, ) - self.capacity = capacity + self.capacity: int = capacity self.properties: dict[Coordinate, object] = {} self.random = random diff --git a/mesa/experimental/cell_space/cell_collection.py b/mesa/experimental/cell_space/cell_collection.py index 14832d511be..2fb157a04a3 100644 --- a/mesa/experimental/cell_space/cell_collection.py +++ b/mesa/experimental/cell_space/cell_collection.py @@ -26,9 +26,9 @@ class CellCollection(Generic[T]): """ def __init__( - self, - cells: Mapping[T, list[CellAgent]] | Iterable[T], - random: Random | None = None, + self, + cells: Mapping[T, list[CellAgent]] | Iterable[T], + random: Random | None = None, ) -> None: """Initialize a CellCollection. @@ -83,25 +83,33 @@ def select_random_agent(self) -> CellAgent: """ return self.random.choice(list(self.agents)) - def select(self, filter_func: Callable[[T], bool] | None = None, n=0): + + def select(self, filter_func: Callable[[T], bool] | None = None, at_most: int | float = float("inf"), ): """Select cells based on filter function. Args: filter_func: filter function - n: number of cells to select + at_most: The maximum amount of cells to select. Defaults to infinity. + - If an integer, at most the first number of matching cells is selected. + - If a float between 0 and 1, at most that fraction of original number of cells Returns: CellCollection """ - # FIXME: n is not considered - if filter_func is None and n == 0: + if filter_func is None and at_most == float("inf"): return self - return CellCollection( - { - cell: agents - for cell, agents in self._cells.items() - if filter_func is None or filter_func(cell) - } - ) + if at_most <= 1.0 and isinstance(at_most, float): + at_most = int(len(self) * at_most) # Note that it rounds down (floor) + + def cell_generator(filter_func, at_most): + count = 0 + for cell in self: + if count >= at_most: + break + if (not filter_func or filter_func(cell)): + yield cell + count += 1 + + return CellCollection(cell_generator(filter_func, at_most))