Skip to content

Commit

Permalink
tests: Resolve warnings by removing scheduler and updating arguments (#…
Browse files Browse the repository at this point in the history
…2329)

This PR updates several test files to resolve warnings and deprecations introduced in Mesa 3.0. The changes align the tests with the new AgentSet functionality and the removal of schedulers.

* tests: Remove schedule from batch_run
* tests: Remove scheduler from datacollector tests
* tests: Replace n with at_most in AgentSet select
* tests: Replace scheduler in lifespan tests
* tests: Rename Agent classes to not start with Test
  • Loading branch information
EwoutH authored Sep 26, 2024
1 parent 28e52d5 commit 17cd62a
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 61 deletions.
64 changes: 32 additions & 32 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
from mesa.model import Model


class TestAgent(Agent):
class AgentTest(Agent):
"""Agent class for testing."""

def get_unique_identifier(self):
"""Return unique identifier for this agent."""
return self.unique_id


class TestAgentDo(Agent):
class AgentDoTest(Agent):
"""Agent class for testing."""

def __init__(
Expand All @@ -37,7 +37,7 @@ def get_unique_identifier(self): # noqa: D102
return self.unique_id

def do_add(self): # noqa: D102
agent = TestAgentDo(self.model)
agent = AgentDoTest(self.model)
self.agent_set.add(agent)

def do_remove(self): # noqa: D102
Expand All @@ -47,7 +47,7 @@ def do_remove(self): # noqa: D102
def test_agent_removal():
"""Test agent removal."""
model = Model()
agent = TestAgent(model)
agent = AgentTest(model)
# Check if the agent is added
assert agent in model.agents

Expand All @@ -60,7 +60,7 @@ def test_agentset():
"""Test agentset class."""
# create agentset
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]

agentset = AgentSet(agents, model)

Expand All @@ -81,13 +81,13 @@ def test_function(agent):
assert len(agentset.select(at_most=1)) == 1 # Select 1 agent

assert len(agentset.select(test_function)) == 5
assert len(agentset.select(test_function, n=2)) == 2
assert len(agentset.select(test_function, at_most=2)) == 2
assert len(agentset.select(test_function, inplace=True)) == 5
assert agentset.select(inplace=True) == agentset
assert all(a1 == a2 for a1, a2 in zip(agentset.select(), agentset))
assert all(a1 == a2 for a1, a2 in zip(agentset.select(n=5), agentset[:5]))
assert all(a1 == a2 for a1, a2 in zip(agentset.select(at_most=5), agentset[:5]))

assert len(agentset.shuffle(inplace=False).select(n=5)) == 5
assert len(agentset.shuffle(inplace=False).select(at_most=5)) == 5

def test_function(agent):
return agent.unique_id
Expand Down Expand Up @@ -132,15 +132,15 @@ def test_agentset_initialization():
empty_agentset = AgentSet([], model)
assert len(empty_agentset) == 0

agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
assert len(agentset) == 10


def test_agentset_serialization():
"""Test pickleability of agentset."""
model = Model()
agents = [TestAgent(model) for _ in range(5)]
agents = [AgentTest(model) for _ in range(5)]
agentset = AgentSet(agents, model)

serialized = pickle.dumps(agentset)
Expand All @@ -155,17 +155,17 @@ def test_agentset_serialization():
def test_agent_membership():
"""Test agent membership in AgentSet."""
model = Model()
agents = [TestAgent(model) for _ in range(5)]
agents = [AgentTest(model) for _ in range(5)]
agentset = AgentSet(agents, model)

assert agents[0] in agentset
assert TestAgent(model) not in agentset
assert AgentTest(model) not in agentset


def test_agent_add_remove_discard():
"""Test adding, removing and discarding agents from AgentSet."""
model = Model()
agent = TestAgent(model)
agent = AgentTest(model)
agentset = AgentSet([], model)

agentset.add(agent)
Expand All @@ -185,7 +185,7 @@ def test_agent_add_remove_discard():
def test_agentset_get_item():
"""Test integer based access to AgentSet."""
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)

assert agentset[0] == agents[0]
Expand All @@ -199,7 +199,7 @@ def test_agentset_get_item():
def test_agentset_do_str():
"""Test AgentSet.do with str."""
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)

with pytest.raises(AttributeError):
Expand All @@ -212,7 +212,7 @@ def test_agentset_do_str():
# setup
n = 10
model = Model()
agents = [TestAgentDo(model) for _ in range(n)]
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, model)
for agent in agents:
agent.agent_set = agentset
Expand All @@ -222,7 +222,7 @@ def test_agentset_do_str():

# setup
model = Model()
agents = [TestAgentDo(model) for _ in range(10)]
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
for agent in agents:
agent.agent_set = agentset
Expand All @@ -234,7 +234,7 @@ def test_agentset_do_str():
def test_agentset_do_callable():
"""Test AgentSet.do with callable."""
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)

# Test callable with non-existent function
Expand All @@ -248,7 +248,7 @@ def test_agentset_do_callable():
# setup for lambda function tests
n = 10
model = Model()
agents = [TestAgentDo(model) for _ in range(n)]
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, model)
for agent in agents:
agent.agent_set = agentset
Expand All @@ -259,7 +259,7 @@ def test_agentset_do_callable():

# setup again for lambda function tests
model = Model()
agents = [TestAgentDo(model) for _ in range(10)]
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
for agent in agents:
agent.agent_set = agentset
Expand All @@ -277,7 +277,7 @@ def remove_function(agent):

# setup again for actual function tests
model = Model()
agents = [TestAgentDo(model) for _ in range(n)]
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, model)
for agent in agents:
agent.agent_set = agentset
Expand All @@ -288,7 +288,7 @@ def remove_function(agent):

# setup again for actual function tests
model = Model()
agents = [TestAgentDo(model) for _ in range(10)]
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
for agent in agents:
agent.agent_set = agentset
Expand All @@ -301,7 +301,7 @@ def remove_function(agent):
def test_agentset_get():
"""Test AgentSet.get."""
model = Model()
_ = [TestAgent(model) for i in range(10)]
[AgentTest(model) for _ in range(10)]

agentset = model.agents

Expand Down Expand Up @@ -347,7 +347,7 @@ def test_agentset_get():
def test_agentset_agg():
"""Test agentset.agg."""
model = Model()
agents = [TestAgent(model) for i in range(10)]
agents = [AgentTest(model) for i in range(10)]

# Assign some values to attributes
for i, agent in enumerate(agents):
Expand Down Expand Up @@ -409,7 +409,7 @@ def __init__(self, model, age=None):
def test_agentset_map_str():
"""Test AgentSet.map with strings."""
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)

with pytest.raises(AttributeError):
Expand All @@ -422,7 +422,7 @@ def test_agentset_map_str():
def test_agentset_map_callable():
"""Test AgentSet.map with callable."""
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)

# Test callable with non-existent function
Expand Down Expand Up @@ -476,7 +476,7 @@ def test_method(self):
def test_agentset_get_attribute():
"""Test AgentSet.get for attributes."""
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)

unique_ids = agentset.get("unique_id")
Expand All @@ -488,7 +488,7 @@ def test_agentset_get_attribute():
model = Model()
agents = []
for i in range(10):
agent = TestAgent(model)
agent = AgentTest(model)
agent.i = i**2
agents.append(agent)
agentset = AgentSet(agents, model)
Expand Down Expand Up @@ -516,17 +516,17 @@ def test_agentset_select_by_type():
"""Test AgentSet.select for agent type."""
model = Model()
# Create a mix of agents of two different types
test_agents = [TestAgent(model) for _ in range(4)]
test_agents = [AgentTest(model) for _ in range(4)]
other_agents = [OtherAgentType(model) for _ in range(6)]

# Combine the two types of agents
mixed_agents = test_agents + other_agents
agentset = AgentSet(mixed_agents, model)

# Test selection by type
selected_test_agents = agentset.select(agent_type=TestAgent)
selected_test_agents = agentset.select(agent_type=AgentTest)
assert len(selected_test_agents) == len(test_agents)
assert all(isinstance(agent, TestAgent) for agent in selected_test_agents)
assert all(isinstance(agent, AgentTest) for agent in selected_test_agents)
assert len(selected_test_agents) == 4

selected_other_agents = agentset.select(agent_type=OtherAgentType)
Expand All @@ -542,7 +542,7 @@ def test_agentset_select_by_type():
def test_agentset_shuffle():
"""Test AgentSet.shuffle."""
model = Model()
test_agents = [TestAgent(model) for _ in range(12)]
test_agents = [AgentTest(model) for _ in range(12)]

agentset = AgentSet(test_agents, model=model)
agentset = agentset.shuffle()
Expand Down
8 changes: 2 additions & 6 deletions tests/test_batch_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from mesa.batchrunner import _make_model_kwargs
from mesa.datacollection import DataCollector
from mesa.model import Model
from mesa.time import BaseScheduler


def test_make_model_kwargs(): # noqa: D103
Expand Down Expand Up @@ -52,7 +51,6 @@ def __init__(
variable_model_param=None,
variable_agent_param=None,
fixed_model_param=None,
schedule=None,
enable_agent_reporters=True,
n_agents=3,
**kwargs,
Expand All @@ -63,13 +61,11 @@ def __init__(
variable_model_param: variable model parameters
variable_agent_param: variable agent parameters
fixed_model_param: fixed model parameters
schedule: schedule instance
enable_agent_reporters: whether to enable agent reporters
n_agents: number of agents
kwargs: keyword arguments
"""
super().__init__()
self.schedule = BaseScheduler(self) if schedule is None else schedule
self.variable_model_param = variable_model_param
self.variable_agent_param = variable_agent_param
self.fixed_model_param = fixed_model_param
Expand All @@ -92,13 +88,13 @@ def init_agents(self):
else:
agent_val = self.variable_agent_param
for _ in range(self.n_agents):
self.schedule.add(MockAgent(self, agent_val))
MockAgent(self, agent_val)

def get_local_model_param(self): # noqa: D102
return 42

def step(self): # noqa: D102
self.schedule.step()
self.agents.do("step")
self.datacollector.collect(self)


Expand Down
20 changes: 6 additions & 14 deletions tests/test_datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from mesa import Agent, Model
from mesa.datacollection import DataCollector
from mesa.time import BaseScheduler


class MockAgent(Agent):
Expand Down Expand Up @@ -60,21 +59,17 @@ def agent_function_with_params(agent, multiplier, offset): # noqa: D103
class MockModel(Model):
"""Minimalistic model for testing purposes."""

schedule = BaseScheduler(None)

def __init__(self): # noqa: D107
super().__init__()
self.schedule = BaseScheduler(self)
self.model_val = 100

self.n = 10
for i in range(1, self.n + 1):
self.schedule.add(MockAgent(self, val=i))
MockAgent(self, val=i)
self.datacollector = DataCollector(
model_reporters={
"total_agents": lambda m: m.schedule.get_agent_count(),
"total_agents": lambda m: len(m.agents),
"model_value": "model_val",
"model_calc": self.schedule.get_agent_count,
"model_calc_comp": [self.test_model_calc_comp, [3, 4]],
"model_calc_fail": [self.test_model_calc_comp, [12, 0]],
},
Expand All @@ -95,7 +90,7 @@ def test_model_calc_comp(self, input1, input2): # noqa: D102
return None

def step(self): # noqa: D102
self.schedule.step()
self.agents.do("step")
self.datacollector.collect(self)


Expand Down Expand Up @@ -135,11 +130,11 @@ def setUp(self):
self.model.datacollector.collect(self.model)
for i in range(7):
if i == 4:
self.model.schedule.remove(self.model.schedule._agents[3])
self.model.agents[3].remove()
self.model.step()

# Write to table:
for agent in self.model.schedule.agents:
for agent in self.model.agents:
agent.write_final_values()

def step_assertion(self, model_var): # noqa: D102
Expand All @@ -154,18 +149,15 @@ def test_model_vars(self):
data_collector = self.model.datacollector
assert "total_agents" in data_collector.model_vars
assert "model_value" in data_collector.model_vars
assert "model_calc" in data_collector.model_vars
assert "model_calc_comp" in data_collector.model_vars
assert "model_calc_fail" in data_collector.model_vars
length = 8
assert len(data_collector.model_vars["total_agents"]) == length
assert len(data_collector.model_vars["model_value"]) == length
assert len(data_collector.model_vars["model_calc"]) == length
assert len(data_collector.model_vars["model_calc_comp"]) == length
self.step_assertion(data_collector.model_vars["total_agents"])
for element in data_collector.model_vars["model_value"]:
assert element == 100
self.step_assertion(data_collector.model_vars["model_calc"])
for element in data_collector.model_vars["model_calc_comp"]:
assert element == 75
for element in data_collector.model_vars["model_calc_fail"]:
Expand Down Expand Up @@ -227,7 +219,7 @@ def test_exports(self):
model_vars = data_collector.get_model_vars_dataframe()
agent_vars = data_collector.get_agent_vars_dataframe()
table_df = data_collector.get_table_dataframe("Final_Values")
assert model_vars.shape == (8, 5)
assert model_vars.shape == (8, 4)
assert agent_vars.shape == (77, 4)
assert table_df.shape == (9, 2)

Expand Down
Loading

0 comments on commit 17cd62a

Please sign in to comment.