diff --git a/benchmarks/BoltzmannWealth/boltzmann_wealth.py b/benchmarks/BoltzmannWealth/boltzmann_wealth.py index 041445f1b8f..93d4da14aec 100644 --- a/benchmarks/BoltzmannWealth/boltzmann_wealth.py +++ b/benchmarks/BoltzmannWealth/boltzmann_wealth.py @@ -60,7 +60,7 @@ def __init__(self, seed=None, n=100, width=10, height=10): def step(self): """Run the model for a single step.""" - self.agents.shuffle().do("step") + self.agents.shuffle_do("step") # collect data self.datacollector.collect(self) diff --git a/benchmarks/WolfSheep/wolf_sheep.py b/benchmarks/WolfSheep/wolf_sheep.py index 0999fc77842..16c01e000c1 100644 --- a/benchmarks/WolfSheep/wolf_sheep.py +++ b/benchmarks/WolfSheep/wolf_sheep.py @@ -227,8 +227,8 @@ def __init__( def step(self): """Run one step of the model.""" - self.agents_by_type[Sheep].shuffle(inplace=True).do("step") - self.agents_by_type[Wolf].shuffle(inplace=True).do("step") + self.agents_by_type[Sheep].shuffle_do("step") + self.agents_by_type[Wolf].shuffle_do("step") if __name__ == "__main__": diff --git a/mesa/agent.py b/mesa/agent.py index 2d098c3549f..ec98efaad08 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -293,6 +293,23 @@ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: return self + def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet: + """Shuffle the agents in the AgentSet and then invoke a method or function on each agent. + + It's a fast, optimized version of calling shuffle() followed by do(). + """ + agents = list(self._agents.keys()) + self.random.shuffle(agents) + + if isinstance(method, str): + for agent in agents: + getattr(agent, method)(*args, **kwargs) + else: + for agent in agents: + method(agent, *args, **kwargs) + + return self + def map(self, method: str | Callable, *args, **kwargs) -> list[Any]: """Invoke a method or function on each agent in the AgentSet and return the results. diff --git a/mesa/experimental/devs/examples/epstein_civil_violence.py b/mesa/experimental/devs/examples/epstein_civil_violence.py index ce6b835e826..6f32e061356 100644 --- a/mesa/experimental/devs/examples/epstein_civil_violence.py +++ b/mesa/experimental/devs/examples/epstein_civil_violence.py @@ -293,7 +293,7 @@ def __init__( def step(self): """Run one step of the model.""" - self.active_agents.shuffle(inplace=True).do("step") + self.active_agents.shuffle_do("step") if __name__ == "__main__": diff --git a/mesa/experimental/devs/examples/wolf_sheep.py b/mesa/experimental/devs/examples/wolf_sheep.py index b76a35bf057..8d7d16d671a 100644 --- a/mesa/experimental/devs/examples/wolf_sheep.py +++ b/mesa/experimental/devs/examples/wolf_sheep.py @@ -230,8 +230,8 @@ def __init__( def step(self): """Perform one step of the model.""" - self.agents_by_type[Sheep].shuffle(inplace=True).do("step") - self.agents_by_type[Wolf].shuffle(inplace=True).do("step") + self.agents_by_type[Sheep].shuffle_do("step") + self.agents_by_type[Wolf].shuffle_do("step") if __name__ == "__main__": diff --git a/tests/test_agent.py b/tests/test_agent.py index f43a80c84ec..a0837de71a1 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -437,6 +437,42 @@ def test_agentset_map_callable(): assert all(i == entry for i, entry in zip(results, range(1, 11))) +def test_agentset_shuffle_do(): + """Test AgentSet.shuffle_do method.""" + model = Model() + + class TestAgentShuffleDo(Agent): + def __init__(self, model): + super().__init__(model) + self.called = False + + def test_method(self): + self.called = True + + agents = [TestAgentShuffleDo(model) for _ in range(100)] + agentset = AgentSet(agents, model) + + # Test shuffle_do with a string method name + agentset.shuffle_do("test_method") + assert all(agent.called for agent in agents) + + # Reset the called flag + for agent in agents: + agent.called = False + + # Test shuffle_do with a callable + agentset.shuffle_do(lambda agent: setattr(agent, "called", True)) + assert all(agent.called for agent in agents) + + # Verify that the order is indeed shuffled + original_order = list(agentset) + shuffled_order = [] + agentset.shuffle_do(lambda agent: shuffled_order.append(agent)) + assert ( + original_order != shuffled_order + ), "The order should be different after shuffle_do" + + def test_agentset_get_attribute(): """Test AgentSet.get for attributes.""" model = Model()