Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tests: Resolve warnings by removing scheduler and updating arguments #2329

Merged
merged 6 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 32 additions & 32 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
from mesa.model import Model


class TestAgent(Agent):
class AgentTest(Agent):
"""Agent class for testing."""

def get_unique_identifier(self):
"""Return unique identifier for this agent."""
return self.unique_id


class TestAgentDo(Agent):
class AgentDoTest(Agent):
"""Agent class for testing."""

def __init__(
Expand All @@ -37,7 +37,7 @@ def get_unique_identifier(self): # noqa: D102
return self.unique_id

def do_add(self): # noqa: D102
agent = TestAgentDo(self.model)
agent = AgentDoTest(self.model)
self.agent_set.add(agent)

def do_remove(self): # noqa: D102
Expand All @@ -47,7 +47,7 @@ def do_remove(self): # noqa: D102
def test_agent_removal():
"""Test agent removal."""
model = Model()
agent = TestAgent(model)
agent = AgentTest(model)
# Check if the agent is added
assert agent in model.agents

Expand All @@ -60,7 +60,7 @@ def test_agentset():
"""Test agentset class."""
# create agentset
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]

agentset = AgentSet(agents, model)

Expand All @@ -81,13 +81,13 @@ def test_function(agent):
assert len(agentset.select(at_most=1)) == 1 # Select 1 agent

assert len(agentset.select(test_function)) == 5
assert len(agentset.select(test_function, n=2)) == 2
assert len(agentset.select(test_function, at_most=2)) == 2
assert len(agentset.select(test_function, inplace=True)) == 5
assert agentset.select(inplace=True) == agentset
assert all(a1 == a2 for a1, a2 in zip(agentset.select(), agentset))
assert all(a1 == a2 for a1, a2 in zip(agentset.select(n=5), agentset[:5]))
assert all(a1 == a2 for a1, a2 in zip(agentset.select(at_most=5), agentset[:5]))

assert len(agentset.shuffle(inplace=False).select(n=5)) == 5
assert len(agentset.shuffle(inplace=False).select(at_most=5)) == 5

def test_function(agent):
return agent.unique_id
Expand Down Expand Up @@ -132,15 +132,15 @@ def test_agentset_initialization():
empty_agentset = AgentSet([], model)
assert len(empty_agentset) == 0

agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
assert len(agentset) == 10


def test_agentset_serialization():
"""Test pickleability of agentset."""
model = Model()
agents = [TestAgent(model) for _ in range(5)]
agents = [AgentTest(model) for _ in range(5)]
agentset = AgentSet(agents, model)

serialized = pickle.dumps(agentset)
Expand All @@ -155,17 +155,17 @@ def test_agentset_serialization():
def test_agent_membership():
"""Test agent membership in AgentSet."""
model = Model()
agents = [TestAgent(model) for _ in range(5)]
agents = [AgentTest(model) for _ in range(5)]
agentset = AgentSet(agents, model)

assert agents[0] in agentset
assert TestAgent(model) not in agentset
assert AgentTest(model) not in agentset


def test_agent_add_remove_discard():
"""Test adding, removing and discarding agents from AgentSet."""
model = Model()
agent = TestAgent(model)
agent = AgentTest(model)
agentset = AgentSet([], model)

agentset.add(agent)
Expand All @@ -185,7 +185,7 @@ def test_agent_add_remove_discard():
def test_agentset_get_item():
"""Test integer based access to AgentSet."""
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)

assert agentset[0] == agents[0]
Expand All @@ -199,7 +199,7 @@ def test_agentset_get_item():
def test_agentset_do_str():
"""Test AgentSet.do with str."""
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)

with pytest.raises(AttributeError):
Expand All @@ -212,7 +212,7 @@ def test_agentset_do_str():
# setup
n = 10
model = Model()
agents = [TestAgentDo(model) for _ in range(n)]
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, model)
for agent in agents:
agent.agent_set = agentset
Expand All @@ -222,7 +222,7 @@ def test_agentset_do_str():

# setup
model = Model()
agents = [TestAgentDo(model) for _ in range(10)]
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
for agent in agents:
agent.agent_set = agentset
Expand All @@ -234,7 +234,7 @@ def test_agentset_do_str():
def test_agentset_do_callable():
"""Test AgentSet.do with callable."""
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)

# Test callable with non-existent function
Expand All @@ -248,7 +248,7 @@ def test_agentset_do_callable():
# setup for lambda function tests
n = 10
model = Model()
agents = [TestAgentDo(model) for _ in range(n)]
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, model)
for agent in agents:
agent.agent_set = agentset
Expand All @@ -259,7 +259,7 @@ def test_agentset_do_callable():

# setup again for lambda function tests
model = Model()
agents = [TestAgentDo(model) for _ in range(10)]
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
for agent in agents:
agent.agent_set = agentset
Expand All @@ -277,7 +277,7 @@ def remove_function(agent):

# setup again for actual function tests
model = Model()
agents = [TestAgentDo(model) for _ in range(n)]
agents = [AgentDoTest(model) for _ in range(n)]
agentset = AgentSet(agents, model)
for agent in agents:
agent.agent_set = agentset
Expand All @@ -288,7 +288,7 @@ def remove_function(agent):

# setup again for actual function tests
model = Model()
agents = [TestAgentDo(model) for _ in range(10)]
agents = [AgentDoTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)
for agent in agents:
agent.agent_set = agentset
Expand All @@ -301,7 +301,7 @@ def remove_function(agent):
def test_agentset_get():
"""Test AgentSet.get."""
model = Model()
_ = [TestAgent(model) for i in range(10)]
[AgentTest(model) for _ in range(10)]

agentset = model.agents

Expand Down Expand Up @@ -347,7 +347,7 @@ def test_agentset_get():
def test_agentset_agg():
"""Test agentset.agg."""
model = Model()
agents = [TestAgent(model) for i in range(10)]
agents = [AgentTest(model) for i in range(10)]

# Assign some values to attributes
for i, agent in enumerate(agents):
Expand Down Expand Up @@ -409,7 +409,7 @@ def __init__(self, model, age=None):
def test_agentset_map_str():
"""Test AgentSet.map with strings."""
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)

with pytest.raises(AttributeError):
Expand All @@ -422,7 +422,7 @@ def test_agentset_map_str():
def test_agentset_map_callable():
"""Test AgentSet.map with callable."""
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)

# Test callable with non-existent function
Expand Down Expand Up @@ -476,7 +476,7 @@ def test_method(self):
def test_agentset_get_attribute():
"""Test AgentSet.get for attributes."""
model = Model()
agents = [TestAgent(model) for _ in range(10)]
agents = [AgentTest(model) for _ in range(10)]
agentset = AgentSet(agents, model)

unique_ids = agentset.get("unique_id")
Expand All @@ -488,7 +488,7 @@ def test_agentset_get_attribute():
model = Model()
agents = []
for i in range(10):
agent = TestAgent(model)
agent = AgentTest(model)
agent.i = i**2
agents.append(agent)
agentset = AgentSet(agents, model)
Expand Down Expand Up @@ -516,17 +516,17 @@ def test_agentset_select_by_type():
"""Test AgentSet.select for agent type."""
model = Model()
# Create a mix of agents of two different types
test_agents = [TestAgent(model) for _ in range(4)]
test_agents = [AgentTest(model) for _ in range(4)]
other_agents = [OtherAgentType(model) for _ in range(6)]

# Combine the two types of agents
mixed_agents = test_agents + other_agents
agentset = AgentSet(mixed_agents, model)

# Test selection by type
selected_test_agents = agentset.select(agent_type=TestAgent)
selected_test_agents = agentset.select(agent_type=AgentTest)
assert len(selected_test_agents) == len(test_agents)
assert all(isinstance(agent, TestAgent) for agent in selected_test_agents)
assert all(isinstance(agent, AgentTest) for agent in selected_test_agents)
assert len(selected_test_agents) == 4

selected_other_agents = agentset.select(agent_type=OtherAgentType)
Expand All @@ -542,7 +542,7 @@ def test_agentset_select_by_type():
def test_agentset_shuffle():
"""Test AgentSet.shuffle."""
model = Model()
test_agents = [TestAgent(model) for _ in range(12)]
test_agents = [AgentTest(model) for _ in range(12)]

agentset = AgentSet(test_agents, model=model)
agentset = agentset.shuffle()
Expand Down
8 changes: 2 additions & 6 deletions tests/test_batch_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from mesa.batchrunner import _make_model_kwargs
from mesa.datacollection import DataCollector
from mesa.model import Model
from mesa.time import BaseScheduler


def test_make_model_kwargs(): # noqa: D103
Expand Down Expand Up @@ -52,7 +51,6 @@ def __init__(
variable_model_param=None,
variable_agent_param=None,
fixed_model_param=None,
schedule=None,
enable_agent_reporters=True,
n_agents=3,
**kwargs,
Expand All @@ -63,13 +61,11 @@ def __init__(
variable_model_param: variable model parameters
variable_agent_param: variable agent parameters
fixed_model_param: fixed model parameters
schedule: schedule instance
enable_agent_reporters: whether to enable agent reporters
n_agents: number of agents
kwargs: keyword arguments
"""
super().__init__()
self.schedule = BaseScheduler(self) if schedule is None else schedule
self.variable_model_param = variable_model_param
self.variable_agent_param = variable_agent_param
self.fixed_model_param = fixed_model_param
Expand All @@ -92,13 +88,13 @@ def init_agents(self):
else:
agent_val = self.variable_agent_param
for _ in range(self.n_agents):
self.schedule.add(MockAgent(self, agent_val))
MockAgent(self, agent_val)

def get_local_model_param(self): # noqa: D102
return 42

def step(self): # noqa: D102
self.schedule.step()
self.agents.do("step")
self.datacollector.collect(self)


Expand Down
20 changes: 6 additions & 14 deletions tests/test_datacollector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from mesa import Agent, Model
from mesa.datacollection import DataCollector
from mesa.time import BaseScheduler


class MockAgent(Agent):
Expand Down Expand Up @@ -60,21 +59,17 @@ def agent_function_with_params(agent, multiplier, offset): # noqa: D103
class MockModel(Model):
"""Minimalistic model for testing purposes."""

schedule = BaseScheduler(None)

def __init__(self): # noqa: D107
super().__init__()
self.schedule = BaseScheduler(self)
self.model_val = 100

self.n = 10
for i in range(1, self.n + 1):
self.schedule.add(MockAgent(self, val=i))
MockAgent(self, val=i)
self.datacollector = DataCollector(
model_reporters={
"total_agents": lambda m: m.schedule.get_agent_count(),
"total_agents": lambda m: len(m.agents),
"model_value": "model_val",
"model_calc": self.schedule.get_agent_count,
"model_calc_comp": [self.test_model_calc_comp, [3, 4]],
"model_calc_fail": [self.test_model_calc_comp, [12, 0]],
},
Expand All @@ -95,7 +90,7 @@ def test_model_calc_comp(self, input1, input2): # noqa: D102
return None

def step(self): # noqa: D102
self.schedule.step()
self.agents.do("step")
self.datacollector.collect(self)


Expand Down Expand Up @@ -135,11 +130,11 @@ def setUp(self):
self.model.datacollector.collect(self.model)
for i in range(7):
if i == 4:
self.model.schedule.remove(self.model.schedule._agents[3])
self.model.agents[3].remove()
self.model.step()

# Write to table:
for agent in self.model.schedule.agents:
for agent in self.model.agents:
agent.write_final_values()

def step_assertion(self, model_var): # noqa: D102
Expand All @@ -154,18 +149,15 @@ def test_model_vars(self):
data_collector = self.model.datacollector
assert "total_agents" in data_collector.model_vars
assert "model_value" in data_collector.model_vars
assert "model_calc" in data_collector.model_vars
assert "model_calc_comp" in data_collector.model_vars
assert "model_calc_fail" in data_collector.model_vars
length = 8
assert len(data_collector.model_vars["total_agents"]) == length
assert len(data_collector.model_vars["model_value"]) == length
assert len(data_collector.model_vars["model_calc"]) == length
assert len(data_collector.model_vars["model_calc_comp"]) == length
self.step_assertion(data_collector.model_vars["total_agents"])
for element in data_collector.model_vars["model_value"]:
assert element == 100
self.step_assertion(data_collector.model_vars["model_calc"])
for element in data_collector.model_vars["model_calc_comp"]:
assert element == 75
for element in data_collector.model_vars["model_calc_fail"]:
Expand Down Expand Up @@ -227,7 +219,7 @@ def test_exports(self):
model_vars = data_collector.get_model_vars_dataframe()
agent_vars = data_collector.get_agent_vars_dataframe()
table_df = data_collector.get_table_dataframe("Final_Values")
assert model_vars.shape == (8, 5)
assert model_vars.shape == (8, 4)
assert agent_vars.shape == (77, 4)
assert table_df.shape == (9, 2)

Expand Down
Loading
Loading