From 3fce5926cc846b79a35356b37dc9b6b513aff41c Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Tue, 17 Sep 2024 16:55:55 +0200 Subject: [PATCH] Enforce google docstrings (#2294) * further updates * Update benchmarks/WolfSheep/__init__.py * Update __init__.py * remove methods from class docstrings and make attributes bulleted lists * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update model.py * enforce google docstring * first pass for fixing all ruff issues with enforced google benchmarks * fix space.py for google docstring standard * ongoing fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes to model.py and time.py * further fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes for cellspaces * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update utils.py * Update conf.py add extensin for parsing google docstring * minor fixes * last outstanding fixes in mesa source code * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * start of fixing docstrings in tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remaining test fixes * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_batch_run.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes for typos * made noqa's more explicit * Update fetch_unlabeled_prs.py * Update wolf_sheep.py * fixed all noqa directives * Update conf.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- benchmarks/BoltzmannWealth/__init__.py | 1 + .../BoltzmannWealth/boltzmann_wealth.py | 40 ++- benchmarks/Flocking/__init__.py | 1 + benchmarks/Flocking/flocking.py | 38 +-- benchmarks/Schelling/__init__.py | 1 + benchmarks/Schelling/schelling.py | 38 +-- benchmarks/WolfSheep/__init__.py | 1 + benchmarks/WolfSheep/wolf_sheep.py | 57 ++-- benchmarks/compare_timings.py | 15 +- benchmarks/configurations.py | 2 + benchmarks/global_benchmark.py | 19 ++ docs/conf.py | 7 +- docs/tutorials/MoneyModel.py | 19 +- maintenance/fetch_unlabeled_prs.py | 3 +- mesa/__init__.py | 3 +- mesa/agent.py | 101 +++--- mesa/batchrunner.py | 38 +-- .../hooks/post_gen_project.py | 2 + .../{{cookiecutter.snake}}/__init__.py | 1 + mesa/datacollection.py | 8 +- mesa/experimental/UserParam.py | 23 +- mesa/experimental/__init__.py | 2 + mesa/experimental/cell_space/__init__.py | 2 + mesa/experimental/cell_space/cell.py | 9 +- mesa/experimental/cell_space/cell_agent.py | 15 +- .../cell_space/cell_collection.py | 40 ++- .../experimental/cell_space/discrete_space.py | 19 +- mesa/experimental/cell_space/grid.py | 17 +- mesa/experimental/cell_space/network.py | 10 +- mesa/experimental/cell_space/voronoi.py | 48 ++- mesa/experimental/components/altair.py | 10 + mesa/experimental/components/matplotlib.py | 18 ++ mesa/experimental/devs/__init__.py | 2 + mesa/experimental/devs/eventlist.py | 51 ++- .../devs/examples/epstein_civil_violence.py | 92 ++++-- mesa/experimental/devs/examples/wolf_sheep.py | 68 ++-- mesa/experimental/devs/simulator.py | 70 ++++- mesa/experimental/solara_viz.py | 29 +- mesa/main.py | 10 +- mesa/model.py | 64 ++-- mesa/space.py | 252 ++++++++------- mesa/time.py | 111 +++---- mesa/visualization/UserParam.py | 25 +- mesa/visualization/__init__.py | 2 + mesa/visualization/components/altair.py | 6 +- mesa/visualization/components/matplotlib.py | 10 +- mesa/visualization/solara_viz.py | 52 ++-- mesa/visualization/utils.py | 4 +- pyproject.toml | 3 + tests/__init__.py | 1 + tests/read_requirements.py | 1 + tests/test_agent.py | 46 ++- tests/test_batch_run.py | 46 ++- tests/test_cell_space.py | 15 + tests/test_datacollector.py | 68 ++-- tests/test_devs.py | 6 + tests/test_examples.py | 12 +- tests/test_grid.py | 126 +++----- tests/test_import_namespace.py | 9 +- tests/test_lifespan.py | 27 +- tests/test_model.py | 11 + tests/test_scaffold.py | 9 +- tests/test_solara_viz.py | 16 +- tests/test_space.py | 290 +++++++----------- tests/test_time.py | 121 +++----- 65 files changed, 1249 insertions(+), 1014 deletions(-) diff --git a/benchmarks/BoltzmannWealth/__init__.py b/benchmarks/BoltzmannWealth/__init__.py index e69de29bb2d..b70e37fa100 100644 --- a/benchmarks/BoltzmannWealth/__init__.py +++ b/benchmarks/BoltzmannWealth/__init__.py @@ -0,0 +1 @@ +"""init file for BoltzmannWealth module.""" diff --git a/benchmarks/BoltzmannWealth/boltzmann_wealth.py b/benchmarks/BoltzmannWealth/boltzmann_wealth.py index 56443c6f476..041445f1b8f 100644 --- a/benchmarks/BoltzmannWealth/boltzmann_wealth.py +++ b/benchmarks/BoltzmannWealth/boltzmann_wealth.py @@ -1,8 +1,21 @@ -# https://github.com/projectmesa/mesa-examples/blob/main/examples/boltzmann_wealth_model_experimental/model.py +"""boltmann wealth model for performance benchmarking. + +https://github.com/projectmesa/mesa-examples/blob/main/examples/boltzmann_wealth_model_experimental/model.py +""" + import mesa def compute_gini(model): + """Calculate gini for wealth in model. + + Args: + model: a Model instance + + Returns: + float: gini score + + """ agent_wealths = [agent.wealth for agent in model.agents] x = sorted(agent_wealths) n = model.num_agents @@ -19,7 +32,15 @@ class BoltzmannWealth(mesa.Model): """ def __init__(self, seed=None, n=100, width=10, height=10): - super().__init__() + """Initializes the model. + + Args: + seed: the seed for random number generator + n: the number of agents + width: the width of the grid + height: the height of the grid + """ + super().__init__(seed) self.num_agents = n self.grid = mesa.space.MultiGrid(width, height, True) self.schedule = mesa.time.RandomActivation(self) @@ -38,11 +59,18 @@ def __init__(self, seed=None, n=100, width=10, height=10): self.datacollector.collect(self) def step(self): + """Run the model for a single step.""" self.agents.shuffle().do("step") # collect data self.datacollector.collect(self) def run_model(self, n): + """Run the model for n steps. + + Args: + n: the number of steps for which to run the model + + """ for _i in range(n): self.step() @@ -51,10 +79,16 @@ class MoneyAgent(mesa.Agent): """An agent with fixed initial wealth.""" def __init__(self, model): + """Instantiate an agent. + + Args: + model: a Model instance + """ super().__init__(model) self.wealth = 1 def move(self): + """Move the agent to a random neighboring cell.""" possible_steps = self.model.grid.get_neighborhood( self.pos, moore=True, include_center=False ) @@ -62,6 +96,7 @@ def move(self): self.model.grid.move_agent(self, new_position) def give_money(self): + """Give money to a random cell mate.""" cellmates = self.model.grid.get_cell_list_contents([self.pos]) cellmates.pop( cellmates.index(self) @@ -72,6 +107,7 @@ def give_money(self): self.wealth -= 1 def step(self): + """Run the agent for 1 step.""" self.move() if self.wealth > 0: self.give_money() diff --git a/benchmarks/Flocking/__init__.py b/benchmarks/Flocking/__init__.py index e69de29bb2d..684c3743037 100644 --- a/benchmarks/Flocking/__init__.py +++ b/benchmarks/Flocking/__init__.py @@ -0,0 +1 @@ +"""initi for flocking benchmark model.""" diff --git a/benchmarks/Flocking/flocking.py b/benchmarks/Flocking/flocking.py index 284d831850c..9678f34bb62 100644 --- a/benchmarks/Flocking/flocking.py +++ b/benchmarks/Flocking/flocking.py @@ -1,7 +1,5 @@ -""" -Flockers -============================================================= -A Mesa implementation of Craig Reynolds's Boids flocker model. +"""A Mesa implementation of Craig Reynolds's Boids flocker model. + Uses numpy arrays to represent vectors. """ @@ -11,8 +9,7 @@ class Boid(mesa.Agent): - """ - A Boid-style flocker agent. + """A Boid-style flocker agent. The agent follows three behaviors to flock: - Cohesion: steering towards neighboring agents. @@ -36,10 +33,10 @@ def __init__( separate=0.015, match=0.05, ): - """ - Create a new Boid flocker agent. + """Create a new Boid flocker agent. Args: + model: a Model instance speed: Distance to move per step. direction: numpy vector for the Boid's direction of movement. vision: Radius to look around for nearby Boids. @@ -59,10 +56,7 @@ def __init__( self.match_factor = match def step(self): - """ - Get the Boid's neighbors, compute the new vector, and move accordingly. - """ - + """Get the Boid's neighbors, compute the new vector, and move accordingly.""" neighbors = self.model.space.get_neighbors(self.pos, self.vision, False) n = 0 match_vector, separation_vector, cohere = np.zeros((3, 2)) @@ -84,9 +78,7 @@ def step(self): class BoidFlockers(mesa.Model): - """ - Flocker model class. Handles agent creation, placement and scheduling. - """ + """Flocker model class. Handles agent creation, placement and scheduling.""" def __init__( self, @@ -102,18 +94,21 @@ def __init__( match=0.05, simulator=None, ): - """ - Create a new Flockers model. + """Create a new Flockers model. Args: + seed: seed for random number generator population: Number of Boids - width, height: Size of the space. + width: the width of the space + height: the height of the space speed: How fast should the Boids move. vision: How far around should each Boid look for its neighbors - separation: What's the minimum distance each Boid will attempt to - keep from any other - cohere, separate, match: factors for the relative importance of + separation: What's the minimum distance each Boid will attempt to keep from any other + cohere: the relative importance of matching neighbors' positions' + separate: the relative importance of avoiding close neighbors + match: factors for the relative importance of the three drives. + simulator: a Simulator Instance """ super().__init__(seed=seed) self.population = population @@ -146,6 +141,7 @@ def __init__( self.schedule.add(boid) def step(self): + """Run the model for one step.""" self.schedule.step() diff --git a/benchmarks/Schelling/__init__.py b/benchmarks/Schelling/__init__.py index e69de29bb2d..de8d0f1a187 100644 --- a/benchmarks/Schelling/__init__.py +++ b/benchmarks/Schelling/__init__.py @@ -0,0 +1 @@ +"""Schelling separation for performance benchmarking.""" diff --git a/benchmarks/Schelling/schelling.py b/benchmarks/Schelling/schelling.py index b582e7af585..47bf521e057 100644 --- a/benchmarks/Schelling/schelling.py +++ b/benchmarks/Schelling/schelling.py @@ -1,19 +1,21 @@ +"""Schelling separation for performance benchmarking.""" + from mesa import Model from mesa.experimental.cell_space import CellAgent, OrthogonalMooreGrid from mesa.time import RandomActivation class SchellingAgent(CellAgent): - """ - Schelling segregation agent - """ + """Schelling segregation agent.""" def __init__(self, model, agent_type, radius, homophily): - """ - Create a new Schelling agent. + """Create a new Schelling agent. + Args: - x, y: Agent initial location. - agent_type: Indicator for the agent's type (minority=1, majority=0) + model: model instance + agent_type: type of agent (minority=1, majority=0) + radius: size of neighborhood of agent + homophily: fraction of neighbors of the same type that triggers movement """ super().__init__(model) self.type = agent_type @@ -21,6 +23,7 @@ def __init__(self, model, agent_type, radius, homophily): self.homophily = homophily def step(self): + """Run one step of the agent.""" similar = 0 neighborhood = self.cell.neighborhood(radius=self.radius) for neighbor in neighborhood.agents: @@ -35,9 +38,7 @@ def step(self): class Schelling(Model): - """ - Model class for the Schelling segregation model. - """ + """Model class for the Schelling segregation model.""" def __init__( self, @@ -50,16 +51,17 @@ def __init__( seed=None, simulator=None, ): - """ - Create a new Schelling model. + """Create a new Schelling model. Args: - height, width: Size of the space. - density: Initial Chance for a cell to populated - minority_pc: Chances for an agent to be in minority class + height: height of the grid + width: width of the grid homophily: Minimum number of agents of same class needed to be happy radius: Search radius for checking similarity - seed: Seed for Reproducibility + density: Initial Chance for a cell to populated + minority_pc: Chances for an agent to be in minority class + seed: the seed for the random number generator + simulator: a simulator instance """ super().__init__(seed=seed) self.minority_pc = minority_pc @@ -85,9 +87,7 @@ def __init__( self.schedule.add(agent) def step(self): - """ - Run one step of the model. - """ + """Run one step of the model.""" self.happy = 0 # Reset counter of happy agents self.schedule.step() diff --git a/benchmarks/WolfSheep/__init__.py b/benchmarks/WolfSheep/__init__.py index e69de29bb2d..89c18853af6 100644 --- a/benchmarks/WolfSheep/__init__.py +++ b/benchmarks/WolfSheep/__init__.py @@ -0,0 +1 @@ +"""Wolf-Sheep Predation Model for performance benchmarking.""" diff --git a/benchmarks/WolfSheep/wolf_sheep.py b/benchmarks/WolfSheep/wolf_sheep.py index ceb6efef8b8..0999fc77842 100644 --- a/benchmarks/WolfSheep/wolf_sheep.py +++ b/benchmarks/WolfSheep/wolf_sheep.py @@ -1,6 +1,4 @@ -""" -Wolf-Sheep Predation Model -================================ +"""Wolf-Sheep Predation Model for performance benchmarking. Replication of the model found in NetLogo: Wilensky, U. (1997). NetLogo Wolf Sheep Predation model. @@ -17,16 +15,28 @@ class Animal(CellAgent): + """The base animal class.""" + def __init__(self, model, energy, p_reproduce, energy_from_food): + """Initializes an animal. + + Args: + model: a model instance + energy: starting amount of energy + p_reproduce: probability of sexless reproduction + energy_from_food: energy obtained from 1 unit of food + """ super().__init__(model) self.energy = energy self.p_reproduce = p_reproduce self.energy_from_food = energy_from_food def random_move(self): + """Move to a random neighboring cell.""" self.move_to(self.cell.neighborhood().select_random_cell()) def spawn_offspring(self): + """Create offspring.""" self.energy /= 2 offspring = self.__class__( self.model, @@ -36,13 +46,15 @@ def spawn_offspring(self): ) offspring.move_to(self.cell) - def feed(self): ... + def feed(self): ... # noqa: D102 def die(self): + """Die.""" self.cell.remove_agent(self) self.remove() def step(self): + """One step of the agent.""" self.random_move() self.energy -= 1 @@ -55,13 +67,10 @@ def step(self): class Sheep(Animal): - """ - A sheep that walks around, reproduces (asexually) and gets eaten. - - The init is the same as the RandomWalker. - """ + """A sheep that walks around, reproduces (asexually) and gets eaten.""" def feed(self): + """If possible eat the food in the current location.""" # If there is grass available, eat it grass_patch = next( obj for obj in self.cell.agents if isinstance(obj, GrassPatch) @@ -72,11 +81,10 @@ def feed(self): class Wolf(Animal): - """ - A wolf that walks around, reproduces (asexually) and eats sheep. - """ + """A wolf that walks around, reproduces (asexually) and eats sheep.""" def feed(self): + """If possible eat the food in the current location.""" sheep = [obj for obj in self.cell.agents if isinstance(obj, Sheep)] if len(sheep) > 0: sheep_to_eat = self.random.choice(sheep) @@ -87,12 +95,10 @@ def feed(self): class GrassPatch(CellAgent): - """ - A patch of grass that grows at a fixed rate and it is eaten by sheep - """ + """A patch of grass that grows at a fixed rate and it is eaten by sheep.""" @property - def fully_grown(self): + def fully_grown(self): # noqa: D102 return self._fully_grown @fully_grown.setter @@ -107,17 +113,15 @@ def fully_grown(self, value: bool) -> None: ) def __init__(self, model, fully_grown, countdown, grass_regrowth_time): - """ - TODO:: fully grown can just be an int --> so one less param (i.e. countdown) - - Creates a new patch of grass + """Creates a new patch of grass. Args: + model: a model instance fully_grown: (boolean) Whether the patch of grass is fully grown or not countdown: Time for the patch of grass to be fully grown again grass_regrowth_time : time to fully regrow grass - countdown : Time for the patch of grass to be fully regrown if fully grown is False """ + # TODO:: fully grown can just be an int --> so one less param (i.e. countdown) super().__init__(model) self._fully_grown = fully_grown self.grass_regrowth_time = grass_regrowth_time @@ -129,8 +133,7 @@ def __init__(self, model, fully_grown, countdown, grass_regrowth_time): class WolfSheep(Model): - """ - Wolf-Sheep Predation Model + """Wolf-Sheep Predation Model. A model for simulating wolf and sheep (predator-prey) ecosystem modelling. """ @@ -149,18 +152,19 @@ def __init__( sheep_gain_from_food=5, seed=None, ): - """ - Create a new Wolf-Sheep model with the given parameters. + """Create a new Wolf-Sheep model with the given parameters. Args: simulator: ABMSimulator instance + width: width of the grid + height: height of the grid initial_sheep: Number of sheep to start with initial_wolves: Number of wolves to start with sheep_reproduce: Probability of each sheep reproducing each step wolf_reproduce: Probability of each wolf reproducing each step - wolf_gain_from_food: Energy a wolf gains from eating a sheep grass_regrowth_time: How long it takes for a grass patch to regrow once it is eaten + wolf_gain_from_food: Energy a wolf gains from eating a sheep sheep_gain_from_food: Energy sheep gain from grass, if enabled. seed : the random seed """ @@ -222,6 +226,7 @@ def __init__( patch.move_to(cell) def step(self): + """Run one step of the model.""" self.agents_by_type[Sheep].shuffle(inplace=True).do("step") self.agents_by_type[Wolf].shuffle(inplace=True).do("step") diff --git a/benchmarks/compare_timings.py b/benchmarks/compare_timings.py index d6f1c11cf89..89b177045cf 100644 --- a/benchmarks/compare_timings.py +++ b/benchmarks/compare_timings.py @@ -1,3 +1,5 @@ +"""compare timings across 2 benchmarks.""" + import pickle import numpy as np @@ -13,8 +15,17 @@ timings_2 = pickle.load(handle) # noqa: S301 -# Function to calculate the percentage change and perform bootstrap to estimate the confidence interval def bootstrap_percentage_change_confidence_interval(data1, data2, n=1000): + """Calculate the percentage change and perform bootstrap to estimate the confidence interval. + + Args: + data1: benchmark dataset 1 + data2: benchmark dataset 2 + n: bootstrap sample size + + Returns: + float, mean, and lower and upper bound of confidence interval. + """ change_samples = [] for _ in range(n): sampled_indices = np.random.choice( @@ -32,8 +43,8 @@ def bootstrap_percentage_change_confidence_interval(data1, data2, n=1000): results_df = pd.DataFrame() -# Function to determine the emoji based on change and confidence interval def performance_emoji(lower, upper): + """Function to determine the emoji based on change and confidence interval.""" if upper < -3: return "🟢" # Emoji for faster performance elif lower > 3: diff --git a/benchmarks/configurations.py b/benchmarks/configurations.py index 6d80ef3a897..0f2be5410b2 100644 --- a/benchmarks/configurations.py +++ b/benchmarks/configurations.py @@ -1,3 +1,5 @@ +"""configurations for benchmarks.""" + from BoltzmannWealth.boltzmann_wealth import BoltzmannWealth from Flocking.flocking import BoidFlockers from Schelling.schelling import Schelling diff --git a/benchmarks/global_benchmark.py b/benchmarks/global_benchmark.py index 19187c219d5..41c2643f88c 100644 --- a/benchmarks/global_benchmark.py +++ b/benchmarks/global_benchmark.py @@ -1,3 +1,5 @@ +"""runner for global performance benchmarks.""" + import gc import os import pickle @@ -16,6 +18,16 @@ # Generic function to initialize and run a model def run_model(model_class, seed, parameters): + """Run model for given seed and parameter values. + + Args: + model_class: a model class + seed: the seed + parameters: parameters for the run + + Returns: + startup time and run time + """ no_simulator = ["BoltzmannWealth"] start_init = timeit.default_timer() if model_class.__name__ in no_simulator: @@ -39,6 +51,13 @@ def run_model(model_class, seed, parameters): # Function to run experiments and save the fastest replication for each seed def run_experiments(model_class, config): + """Run performance benchmarks. + + Args: + model_class: the model class to use for the benchmark + config: the benchmark configuration + + """ gc.enable() sys.path.insert(0, os.path.abspath(".")) diff --git a/docs/conf.py b/docs/conf.py index ad23acab42f..603c64c7c9a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,3 +1,4 @@ +# noqa: D100 #!/usr/bin/env python3 # # Mesa documentation build configuration file, created by @@ -12,11 +13,10 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys import os +import sys from datetime import date - # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. @@ -42,6 +42,7 @@ "sphinx.ext.mathjax", "sphinx.ext.ifconfig", "sphinx.ext.viewcode", + "sphinx.ext.napoleon", # for google style docstrings "myst_nb", # For Markdown and Jupyter notebooks ] @@ -281,4 +282,4 @@ # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'python': ('https://docs.python.org/3', None)} +intersphinx_mapping = {"python": ("https://docs.python.org/3", None)} diff --git a/docs/tutorials/MoneyModel.py b/docs/tutorials/MoneyModel.py index 93e19dec11f..d723a52b654 100644 --- a/docs/tutorials/MoneyModel.py +++ b/docs/tutorials/MoneyModel.py @@ -1,5 +1,4 @@ -#!/usr/bin/env python3 - +"""a simple version of the boltman wealth model""" import mesa @@ -16,10 +15,16 @@ class MoneyAgent(mesa.Agent): """An agent with fixed initial wealth.""" def __init__(self, model): + """initialize a MoneyAgent instance. + + Args: + model: A model instance + """ super().__init__(model) self.wealth = 1 def move(self): + """move to a random neighboring cell.""" possible_steps = self.model.grid.get_neighborhood( self.pos, moore=True, include_center=False ) @@ -27,6 +32,7 @@ def move(self): self.model.grid.move_agent(self, new_position) def give_money(self): + """give money to another agent in the same gridcell.""" cellmates = self.model.grid.get_cell_list_contents([self.pos]) if len(cellmates) > 1: other = self.random.choice(cellmates) @@ -34,6 +40,7 @@ def give_money(self): self.wealth -= 1 def step(self): + """do one step of the agent.""" self.move() if self.wealth > 0: self.give_money() @@ -43,6 +50,13 @@ class MoneyModel(mesa.Model): """A model with some number of agents.""" def __init__(self, N, width, height): + """Initialize a MoneyModel instance. + + Args: + N: The number of agents. + width: width of the grid. + height: Height of the grid. + """ super().__init__() self.num_agents = N self.grid = mesa.space.MultiGrid(width, height, True) @@ -62,5 +76,6 @@ def __init__(self, N, width, height): ) def step(self): + """do one step of the model""" self.datacollector.collect(self) self.schedule.step() diff --git a/maintenance/fetch_unlabeled_prs.py b/maintenance/fetch_unlabeled_prs.py index 0b01a291a01..1801a145a63 100644 --- a/maintenance/fetch_unlabeled_prs.py +++ b/maintenance/fetch_unlabeled_prs.py @@ -1,3 +1,4 @@ +# noqa: D100 import os from datetime import datetime @@ -66,7 +67,7 @@ def get_closed_pull_requests_since_latest_release( return pull_requests -def main() -> None: +def main() -> None: # noqa: D103 # Based on https://github.com/projectmesa/mesa/pull/1917#issuecomment-1871352058 latest_release_date = get_latest_release_date() pull_requests = get_closed_pull_requests_since_latest_release(latest_release_date) diff --git a/mesa/__init__.py b/mesa/__init__.py index dc4791b4e08..7e3b014b817 100644 --- a/mesa/__init__.py +++ b/mesa/__init__.py @@ -1,5 +1,4 @@ -""" -Mesa Agent-Based Modeling Framework +"""Mesa Agent-Based Modeling Framework. Core Objects: Model, and Agent. """ diff --git a/mesa/agent.py b/mesa/agent.py index 75d79f90563..6d63ac79849 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -1,5 +1,4 @@ -""" -The agent class for Mesa framework. +"""The agent class for Mesa framework. Core Objects: Agent """ @@ -30,8 +29,7 @@ class Agent: - """ - Base class for a model agent in Mesa. + """Base class for a model agent in Mesa. Attributes: model (Model): A reference to the model instance. @@ -49,11 +47,12 @@ class Agent: _ids = defaultdict(functools.partial(itertools.count, 1)) def __init__(self, *args, **kwargs) -> None: - """ - Create a new agent. + """Create a new agent. Args: model (Model): The model instance in which the agent exists. + args: currently ignored, to be fixed in 3.1 + kwargs: currently ignored, to be fixed in 3.1 """ # TODO: Cleanup in future Mesa version (3.1+) match args: @@ -89,28 +88,25 @@ def remove(self) -> None: def step(self) -> None: """A single step of the agent.""" - def advance(self) -> None: + def advance(self) -> None: # noqa: D102 pass @property def random(self) -> Random: + """Return a seeded rng.""" return self.model.random class AgentSet(MutableSet, Sequence): - """ - A collection class that represents an ordered set of agents within an agent-based model (ABM). This class - extends both MutableSet and Sequence, providing set-like functionality with order preservation and + """A collection class that represents an ordered set of agents within an agent-based model (ABM). + + This class extends both MutableSet and Sequence, providing set-like functionality with order preservation and sequence operations. Attributes: model (Model): The ABM model instance to which this AgentSet belongs. - Methods: - __len__, __iter__, __contains__, select, shuffle, sort, _update, do, get, __getitem__, - add, discard, remove, __getstate__, __setstate__, random - - Note: + Notes: The AgentSet maintains weak references to agents, allowing for efficient management of agent lifecycles without preventing garbage collection. It is associated with a specific model instance, enabling interactions with the model's environment and other agents.The implementation uses a WeakKeyDictionary to store agents, @@ -118,14 +114,12 @@ class AgentSet(MutableSet, Sequence): """ def __init__(self, agents: Iterable[Agent], model: Model): - """ - Initializes the AgentSet with a collection of agents and a reference to the model. + """Initializes the AgentSet with a collection of agents and a reference to the model. Args: agents (Iterable[Agent]): An iterable of Agent objects to be included in the set. model (Model): The ABM model instance to which this AgentSet belongs. """ - self.model = model self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents}) @@ -149,8 +143,7 @@ def select( agent_type: type[Agent] | None = None, n: int | None = None, ) -> AgentSet: - """ - Select a subset of agents from the AgentSet based on a filter function and/or quantity limit. + """Select a subset of agents from the AgentSet based on a filter function and/or quantity limit. Args: filter_func (Callable[[Agent], bool], optional): A function that takes an Agent and returns True if the @@ -160,6 +153,7 @@ def select( - If a float between 0 and 1, at most that fraction of original the agents are selected. inplace (bool, optional): If True, modifies the current AgentSet; otherwise, returns a new AgentSet. Defaults to False. agent_type (type[Agent], optional): The class type of the agents to select. Defaults to None, meaning no type filtering is applied. + n (int): deprecated, use at_most instead Returns: AgentSet: A new AgentSet containing the selected agents, unless inplace is True, in which case the current AgentSet is updated. @@ -200,8 +194,7 @@ def agent_generator(filter_func, agent_type, at_most): return AgentSet(agents, self.model) if not inplace else self._update(agents) def shuffle(self, inplace: bool = False) -> AgentSet: - """ - Randomly shuffle the order of agents in the AgentSet. + """Randomly shuffle the order of agents in the AgentSet. Args: inplace (bool, optional): If True, shuffles the agents in the current AgentSet; otherwise, returns a new shuffled AgentSet. Defaults to False. @@ -230,8 +223,7 @@ def sort( ascending: bool = False, inplace: bool = False, ) -> AgentSet: - """ - Sort the agents in the AgentSet based on a specified attribute or custom function. + """Sort the agents in the AgentSet based on a specified attribute or custom function. Args: key (Callable[[Agent], Any] | str): A function or attribute name based on which the agents are sorted. @@ -254,15 +246,14 @@ def sort( def _update(self, agents: Iterable[Agent]): """Update the AgentSet with a new set of agents. + This is a private method primarily used internally by other methods like select, shuffle, and sort. """ - self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents}) return self def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: - """ - Invoke a method or function on each agent in the AgentSet. + """Invoke a method or function on each agent in the AgentSet. Args: method (str, callable): the callable to do on each agent @@ -303,8 +294,7 @@ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: return self def map(self, method: str | Callable, *args, **kwargs) -> list[Any]: - """ - Invoke a method or function on each agent in the AgentSet and return the results. + """Invoke a method or function on each agent in the AgentSet and return the results. Args: method (str, callable): the callable to apply on each agent @@ -335,8 +325,7 @@ def map(self, method: str | Callable, *args, **kwargs) -> list[Any]: return res def agg(self, attribute: str, func: Callable) -> Any: - """ - Aggregate an attribute of all agents in the AgentSet using a specified function. + """Aggregate an attribute of all agents in the AgentSet using a specified function. Args: attribute (str): The name of the attribute to aggregate. @@ -370,8 +359,7 @@ def get( handle_missing="error", default_value=None, ): - """ - Retrieve the specified attribute(s) from each agent in the AgentSet. + """Retrieve the specified attribute(s) from each agent in the AgentSet. Args: attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent. @@ -418,8 +406,7 @@ def get( ) def set(self, attr_name: str, value: Any) -> AgentSet: - """ - Set a specified attribute to a given value for all agents in the AgentSet. + """Set a specified attribute to a given value for all agents in the AgentSet. Args: attr_name (str): The name of the attribute to set. @@ -433,8 +420,7 @@ def set(self, attr_name: str, value: Any) -> AgentSet: return self def __getitem__(self, item: int | slice) -> Agent: - """ - Retrieve an agent or a slice of agents from the AgentSet. + """Retrieve an agent or a slice of agents from the AgentSet. Args: item (int | slice): The index or slice for selecting agents. @@ -445,8 +431,7 @@ def __getitem__(self, item: int | slice) -> Agent: return list(self._agents.keys())[item] def add(self, agent: Agent): - """ - Add an agent to the AgentSet. + """Add an agent to the AgentSet. Args: agent (Agent): The agent to add to the set. @@ -457,8 +442,7 @@ def add(self, agent: Agent): self._agents[agent] = None def discard(self, agent: Agent): - """ - Remove an agent from the AgentSet if it exists. + """Remove an agent from the AgentSet if it exists. This method does not raise an error if the agent is not present. @@ -472,8 +456,7 @@ def discard(self, agent: Agent): del self._agents[agent] def remove(self, agent: Agent): - """ - Remove an agent from the AgentSet. + """Remove an agent from the AgentSet. This method raises an error if the agent is not present. @@ -486,8 +469,7 @@ def remove(self, agent: Agent): del self._agents[agent] def __getstate__(self): - """ - Retrieve the state of the AgentSet for serialization. + """Retrieve the state of the AgentSet for serialization. Returns: dict: A dictionary representing the state of the AgentSet. @@ -495,8 +477,7 @@ def __getstate__(self): return {"agents": list(self._agents.keys()), "model": self.model} def __setstate__(self, state): - """ - Set the state of the AgentSet during deserialization. + """Set the state of the AgentSet during deserialization. Args: state (dict): A dictionary representing the state to restore. @@ -506,8 +487,7 @@ def __setstate__(self, state): @property def random(self) -> Random: - """ - Provide access to the model's random number generator. + """Provide access to the model's random number generator. Returns: Random: The random number generator associated with the model. @@ -515,8 +495,7 @@ def random(self) -> Random: return self.model.random def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy: - """ - Group agents by the specified attribute or return from the callable + """Group agents by the specified attribute or return from the callable. Args: by (Callable, str): used to determine what to group agents by @@ -526,6 +505,7 @@ def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy: * if ``by`` is a str, it should refer to an attribute on the agent and the value of this attribute will be used for grouping result_type (str, optional): The datatype for the resulting groups {"agentset", "list"} + Returns: GroupBy @@ -557,8 +537,7 @@ def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy: class GroupBy: - """Helper class for AgentSet.groupby - + """Helper class for AgentSet.groupby. Attributes: groups (dict): A dictionary with the group_name as key and group as values @@ -566,6 +545,12 @@ class GroupBy: """ def __init__(self, groups: dict[Any, list | AgentSet]): + """Initialize a GroupBy instance. + + Args: + groups (dict): A dictionary with the group_name as key and group as values + + """ self.groups: dict[Any, list | AgentSet] = groups def map(self, method: Callable | str, *args, **kwargs) -> dict[Any, Any]: @@ -578,6 +563,8 @@ def map(self, method: Callable | str, *args, **kwargs) -> dict[Any, Any]: * if ``method`` is a str, it should refer to a method on the group Additional arguments and keyword arguments will be passed on to the callable. + args: arguments to pass to the callable + kwargs: keyword arguments to pass to the callable Returns: dict with group_name as key and the return of the method as value @@ -595,7 +582,7 @@ def map(self, method: Callable | str, *args, **kwargs) -> dict[Any, Any]: return {k: method(v, *args, **kwargs) for k, v in self.groups.items()} def do(self, method: Callable | str, *args, **kwargs) -> GroupBy: - """Apply the specified callable to each group + """Apply the specified callable to each group. Args: method (Callable, str): The callable to apply to each group, @@ -604,6 +591,8 @@ def do(self, method: Callable | str, *args, **kwargs) -> GroupBy: * if ``method`` is a str, it should refer to a method on the group Additional arguments and keyword arguments will be passed on to the callable. + args: arguments to pass to the callable + kwargs: keyword arguments to pass to the callable Returns: the original GroupBy instance @@ -622,8 +611,8 @@ def do(self, method: Callable | str, *args, **kwargs) -> GroupBy: return self - def __iter__(self): + def __iter__(self): # noqa: D105 return iter(self.groups.items()) - def __len__(self): + def __len__(self): # noqa: D105 return len(self.groups) diff --git a/mesa/batchrunner.py b/mesa/batchrunner.py index 4e7fe5cae3e..d50476c6f52 100644 --- a/mesa/batchrunner.py +++ b/mesa/batchrunner.py @@ -1,3 +1,5 @@ +"""batchrunner for running a factorial experiment design over a model.""" + import itertools import multiprocessing from collections.abc import Iterable, Mapping @@ -24,29 +26,19 @@ def batch_run( ) -> list[dict[str, Any]]: """Batch run a mesa model with a set of parameter values. - Parameters - ---------- - model_cls : Type[Model] - The model class to batch-run - parameters : Mapping[str, Union[Any, Iterable[Any]]], - Dictionary with model parameters over which to run the model. You can either pass single values or iterables. - number_processes : int, optional - Number of processes used, by default 1. Set this to None if you want to use all CPUs. - iterations : int, optional - Number of iterations for each parameter combination, by default 1 - data_collection_period : int, optional - Number of steps after which data gets collected, by default -1 (end of episode) - max_steps : int, optional - Maximum number of model steps after which the model halts, by default 1000 - display_progress : bool, optional - Display batch run process, by default True + Args: + model_cls (Type[Model]): The model class to batch-run + parameters (Mapping[str, Union[Any, Iterable[Any]]]): Dictionary with model parameters over which to run the model. You can either pass single values or iterables. + number_processes (int, optional): Number of processes used, by default 1. Set this to None if you want to use all CPUs. + iterations (int, optional): Number of iterations for each parameter combination, by default 1 + data_collection_period (int, optional): Number of steps after which data gets collected, by default -1 (end of episode) + max_steps (int, optional): Maximum number of model steps after which the model halts, by default 1000 + display_progress (bool, optional): Display batch run process, by default True - Returns - ------- - List[Dict[str, Any]] - [description] - """ + Returns: + List[Dict[str, Any]] + """ runs_list = [] run_id = 0 for iteration in range(iterations): @@ -88,7 +80,7 @@ def _make_model_kwargs( parameters : Mapping[str, Union[Any, Iterable[Any]]] Single or multiple values for each model parameter name - Returns + Returns: ------- List[Dict[str, Any]] A list of all kwargs combinations. @@ -128,7 +120,7 @@ def _model_run_func( data_collection_period : int Number of steps after which data gets collected - Returns + Returns: ------- List[Dict[str, Any]] Return model_data, agent_data from the reporters diff --git a/mesa/cookiecutter-mesa/hooks/post_gen_project.py b/mesa/cookiecutter-mesa/hooks/post_gen_project.py index 24c615d517d..1521594704b 100644 --- a/mesa/cookiecutter-mesa/hooks/post_gen_project.py +++ b/mesa/cookiecutter-mesa/hooks/post_gen_project.py @@ -1,3 +1,5 @@ +"""helper module.""" + import glob import os diff --git a/mesa/cookiecutter-mesa/{{cookiecutter.snake}}/{{cookiecutter.snake}}/__init__.py b/mesa/cookiecutter-mesa/{{cookiecutter.snake}}/{{cookiecutter.snake}}/__init__.py index e69de29bb2d..dda29ba61af 100644 --- a/mesa/cookiecutter-mesa/{{cookiecutter.snake}}/{{cookiecutter.snake}}/__init__.py +++ b/mesa/cookiecutter-mesa/{{cookiecutter.snake}}/{{cookiecutter.snake}}/__init__.py @@ -0,0 +1 @@ +"""helper modules.""" diff --git a/mesa/datacollection.py b/mesa/datacollection.py index 5a82e71c20d..bf50be2a723 100644 --- a/mesa/datacollection.py +++ b/mesa/datacollection.py @@ -1,6 +1,4 @@ -""" -Mesa Data Collection Module -=========================== +"""Mesa Data Collection Module. DataCollector is meant to provide a simple, standard way to collect data generated by a Mesa model. It collects three types of data: model-level data, @@ -59,8 +57,8 @@ def __init__( agent_reporters=None, tables=None, ): - """ - Instantiate a DataCollector with lists of model and agent reporters. + """Instantiate a DataCollector with lists of model and agent reporters. + Both model_reporters and agent_reporters accept a dictionary mapping a variable name to either an attribute name, a function, a method of a class/instance, or a function with parameters placed in a list. diff --git a/mesa/experimental/UserParam.py b/mesa/experimental/UserParam.py index 5b342471ddb..9cf5585e802 100644 --- a/mesa/experimental/UserParam.py +++ b/mesa/experimental/UserParam.py @@ -1,7 +1,10 @@ -class UserParam: +"""helper classes.""" + + +class UserParam: # noqa: D101 _ERROR_MESSAGE = "Missing or malformed inputs for '{}' Option '{}'" - def maybe_raise_error(self, param_type, valid): + def maybe_raise_error(self, param_type, valid): # noqa: D102 if valid: return msg = self._ERROR_MESSAGE.format(param_type, self.label) @@ -9,11 +12,9 @@ def maybe_raise_error(self, param_type, valid): class Slider(UserParam): - """ - A number-based slider input with settable increment. + """A number-based slider input with settable increment. Example: - slider_option = Slider("My Slider", value=123, min=10, max=200, step=0.1) Args: @@ -34,6 +35,16 @@ def __init__( step=1, dtype=None, ): + """Slider class. + + Args: + label: The displayed label in the UI + value: The initial value of the slider + min: The minimum possible value of the slider + max: The maximum possible value of the slider + step: The step between min and max for a range of possible values + dtype: either int or float + """ self.label = label self.value = value self.min = min @@ -52,5 +63,5 @@ def __init__( def _check_values_are_float(self, value, min, max, step): return any(isinstance(n, float) for n in (value, min, max, step)) - def get(self, attr): + def get(self, attr): # noqa: D102 return getattr(self, attr) diff --git a/mesa/experimental/__init__.py b/mesa/experimental/__init__.py index 753b4bd7985..a48c498f6f9 100644 --- a/mesa/experimental/__init__.py +++ b/mesa/experimental/__init__.py @@ -1,3 +1,5 @@ +"""Experimental init.""" + from mesa.experimental import cell_space from .solara_viz import JupyterViz, Slider, SolaraViz, make_text diff --git a/mesa/experimental/cell_space/__init__.py b/mesa/experimental/cell_space/__init__.py index 8db71025616..33bdbe9b76e 100644 --- a/mesa/experimental/cell_space/__init__.py +++ b/mesa/experimental/cell_space/__init__.py @@ -1,3 +1,5 @@ +"""Cell spaces.""" + from mesa.experimental.cell_space.cell import Cell from mesa.experimental.cell_space.cell_agent import CellAgent from mesa.experimental.cell_space.cell_collection import CellCollection diff --git a/mesa/experimental/cell_space/cell.py b/mesa/experimental/cell_space/cell.py index 55264f68daa..4e9a8f156cd 100644 --- a/mesa/experimental/cell_space/cell.py +++ b/mesa/experimental/cell_space/cell.py @@ -1,3 +1,5 @@ +"""The Cell in a cell space.""" + from __future__ import annotations from functools import cache @@ -46,10 +48,10 @@ def __init__( capacity: float | None = None, random: Random | None = None, ) -> None: - """ " + """Initialise the cell. Args: - coordinate: + coordinate: coordinates of the cell capacity (int) : the capacity of the cell. If None, the capacity is infinite random (Random) : the random number generator to use @@ -116,12 +118,13 @@ def is_full(self) -> bool: """Returns a bool of the contents of a cell.""" return len(self.agents) == self.capacity - def __repr__(self): + def __repr__(self): # noqa return f"Cell({self.coordinate}, {self.agents})" # FIXME: Revisit caching strategy on methods @cache # noqa: B019 def neighborhood(self, radius=1, include_center=False): + """Returns a list of all neighboring cells.""" return CellCollection( self._neighborhood(radius=radius, include_center=include_center), random=self.random, diff --git a/mesa/experimental/cell_space/cell_agent.py b/mesa/experimental/cell_space/cell_agent.py index 5f1cca5cbcc..76a7d16b866 100644 --- a/mesa/experimental/cell_space/cell_agent.py +++ b/mesa/experimental/cell_space/cell_agent.py @@ -1,3 +1,5 @@ +"""An agent with movement methods for cell spaces.""" + from __future__ import annotations from typing import TYPE_CHECKING @@ -9,8 +11,7 @@ class CellAgent(Agent): - """Cell Agent is an extension of the Agent class and adds behavior for moving in discrete spaces - + """Cell Agent is an extension of the Agent class and adds behavior for moving in discrete spaces. Attributes: unique_id (int): A unique identifier for this agent. @@ -20,17 +21,21 @@ class CellAgent(Agent): """ def __init__(self, model: Model) -> None: - """ - Create a new agent. + """Create a new agent. Args: - unique_id (int): A unique identifier for this agent. model (Model): The model instance in which the agent exists. """ super().__init__(model) self.cell: Cell | None = None def move_to(self, cell) -> None: + """Move agent to cell. + + Args: + cell: cell to which agent is to move + + """ if self.cell is not None: self.cell.remove_agent(self) self.cell = cell diff --git a/mesa/experimental/cell_space/cell_collection.py b/mesa/experimental/cell_space/cell_collection.py index 114301db100..14832d511be 100644 --- a/mesa/experimental/cell_space/cell_collection.py +++ b/mesa/experimental/cell_space/cell_collection.py @@ -1,3 +1,5 @@ +"""CellCollection class.""" + from __future__ import annotations import itertools @@ -14,7 +16,7 @@ class CellCollection(Generic[T]): - """An immutable collection of cells + """An immutable collection of cells. Attributes: cells (List[Cell]): The list of cells this collection represents @@ -28,6 +30,12 @@ def __init__( cells: Mapping[T, list[CellAgent]] | Iterable[T], random: Random | None = None, ) -> None: + """Initialize a CellCollection. + + Args: + cells: cells to add to the collection + random: a seeded random number generator. + """ if isinstance(cells, dict): self._cells = cells else: @@ -40,34 +48,52 @@ def __init__( random = Random() # FIXME self.random = random - def __iter__(self): + def __iter__(self): # noqa return iter(self._cells) - def __getitem__(self, key: T) -> Iterable[CellAgent]: + def __getitem__(self, key: T) -> Iterable[CellAgent]: # noqa return self._cells[key] # @cached_property - def __len__(self) -> int: + def __len__(self) -> int: # noqa return len(self._cells) - def __repr__(self): + def __repr__(self): # noqa return f"CellCollection({self._cells})" @cached_property - def cells(self) -> list[T]: + def cells(self) -> list[T]: # noqa return list(self._cells.keys()) @property - def agents(self) -> Iterable[CellAgent]: + def agents(self) -> Iterable[CellAgent]: # noqa return itertools.chain.from_iterable(self._cells.values()) def select_random_cell(self) -> T: + """Select a random cell.""" return self.random.choice(self.cells) def select_random_agent(self) -> CellAgent: + """Select a random agent. + + Returns: + CellAgent instance + + + """ return self.random.choice(list(self.agents)) def select(self, filter_func: Callable[[T], bool] | None = None, n=0): + """Select cells based on filter function. + + Args: + filter_func: filter function + n: number of cells to select + + Returns: + CellCollection + + """ # FIXME: n is not considered if filter_func is None and n == 0: return self diff --git a/mesa/experimental/cell_space/discrete_space.py b/mesa/experimental/cell_space/discrete_space.py index 6c92c320cb1..f9a35de5160 100644 --- a/mesa/experimental/cell_space/discrete_space.py +++ b/mesa/experimental/cell_space/discrete_space.py @@ -1,3 +1,5 @@ +"""DiscreteSpace base class.""" + from __future__ import annotations from functools import cached_property @@ -28,6 +30,13 @@ def __init__( cell_klass: type[T] = Cell, random: Random | None = None, ): + """Instantiate a DiscreteSpace. + + Args: + capacity: capacity of cells + cell_klass: base class for all cells + random: random number generator + """ super().__init__() self.capacity = capacity self._cells: dict[tuple[int, ...], T] = {} @@ -40,25 +49,27 @@ def __init__( self._empties_initialized = False @property - def cutoff_empties(self): + def cutoff_empties(self): # noqa return 7.953 * len(self._cells) ** 0.384 def _connect_single_cell(self, cell: T): ... @cached_property def all_cells(self): + """Return all cells in space.""" return CellCollection({cell: cell.agents for cell in self._cells.values()}) - def __iter__(self): + def __iter__(self): # noqa return iter(self._cells.values()) - def __getitem__(self, key): + def __getitem__(self, key): # noqa return self._cells[key] @property def empties(self) -> CellCollection: + """Return all empty in spaces.""" return self.all_cells.select(lambda cell: cell.is_empty) def select_random_empty_cell(self) -> T: - """select random empty cell""" + """Select random empty cell.""" return self.random.choice(list(self.empties)) diff --git a/mesa/experimental/cell_space/grid.py b/mesa/experimental/cell_space/grid.py index f08657d2107..ae33fbe49f6 100644 --- a/mesa/experimental/cell_space/grid.py +++ b/mesa/experimental/cell_space/grid.py @@ -1,3 +1,5 @@ +"""Various Grid Spaces.""" + from __future__ import annotations from collections.abc import Sequence @@ -11,7 +13,7 @@ class Grid(DiscreteSpace, Generic[T]): - """Base class for all grid classes + """Base class for all grid classes. Attributes: dimensions (Sequence[int]): the dimensions of the grid @@ -30,6 +32,15 @@ def __init__( random: Random | None = None, cell_klass: type[T] = Cell, ) -> None: + """Initialise the grid class. + + Args: + dimensions: the dimensions of the space + torus: whether the space wraps + capacity: capacity of the grid cell + random: a random number generator + cell_klass: the base class to use for the cells + """ super().__init__(capacity=capacity, random=random, cell_klass=cell_klass) self.torus = torus self.dimensions = dimensions @@ -63,7 +74,7 @@ def _validate_parameters(self): if self.capacity is not None and not isinstance(self.capacity, float | int): raise ValueError("Capacity must be a number or None.") - def select_random_empty_cell(self) -> T: + def select_random_empty_cell(self) -> T: # noqa # FIXME:: currently just a simple boolean to control behavior # FIXME:: basically if grid is close to 99% full, creating empty list can be faster # FIXME:: note however that the old results don't apply because in this implementation @@ -176,6 +187,8 @@ def _connect_cells_nd(self) -> None: class HexGrid(Grid[T]): + """A Grid with hexagonal tilling of the space.""" + def _connect_cells_2d(self) -> None: # fmt: off even_offsets = [ diff --git a/mesa/experimental/cell_space/network.py b/mesa/experimental/cell_space/network.py index 3983287e4ef..3920511c22e 100644 --- a/mesa/experimental/cell_space/network.py +++ b/mesa/experimental/cell_space/network.py @@ -1,3 +1,5 @@ +"""A Network grid.""" + from random import Random from typing import Any @@ -6,7 +8,7 @@ class Network(DiscreteSpace): - """A networked discrete space""" + """A networked discrete space.""" def __init__( self, @@ -15,13 +17,13 @@ def __init__( random: Random | None = None, cell_klass: type[Cell] = Cell, ) -> None: - """A Networked grid + """A Networked grid. Args: G: a NetworkX Graph instance. capacity (int) : the capacity of the cell - random (Random): - CellKlass (type[Cell]): The base Cell class to use in the Network + random (Random): a random number generator + cell_klass (type[Cell]): The base Cell class to use in the Network """ super().__init__(capacity=capacity, random=random, cell_klass=cell_klass) diff --git a/mesa/experimental/cell_space/voronoi.py b/mesa/experimental/cell_space/voronoi.py index 4395ca4ead8..0d712f3ea16 100644 --- a/mesa/experimental/cell_space/voronoi.py +++ b/mesa/experimental/cell_space/voronoi.py @@ -1,3 +1,5 @@ +"""Support for Voronoi meshed grids.""" + from collections.abc import Sequence from itertools import combinations from random import Random @@ -9,16 +11,17 @@ class Delaunay: - """ - Class to compute a Delaunay triangulation in 2D + """Class to compute a Delaunay triangulation in 2D. + ref: http://github.com/jmespadero/pyDelaunay2D """ def __init__(self, center: tuple = (0, 0), radius: int = 9999) -> None: - """ - Init and create a new frame to contain the triangulation - center: Optional position for the center of the frame. Default (0,0) - radius: Optional distance from corners to the center. + """Init and create a new frame to contain the triangulation. + + Args: + center: Optional position for the center of the frame. Default (0,0) + radius: Optional distance from corners to the center. """ center = np.asarray(center) # Create coordinates for the corners of the frame @@ -44,9 +47,7 @@ def __init__(self, center: tuple = (0, 0), radius: int = 9999) -> None: self.circles[t] = self._circumcenter(t) def _circumcenter(self, triangle: list) -> tuple: - """ - Compute circumcenter and circumradius of a triangle in 2D. - """ + """Compute circumcenter and circumradius of a triangle in 2D.""" points = np.asarray([self.coords[v] for v in triangle]) points2 = np.dot(points, points.T) a = np.bmat([[2 * points2, [[1], [1], [1]]], [[[1, 1, 1, 0]]]]) @@ -60,16 +61,12 @@ def _circumcenter(self, triangle: list) -> tuple: return (center, radius) def _in_circle(self, triangle: list, point: list) -> bool: - """ - Check if point p is inside of precomputed circumcircle of triangle. - """ + """Check if point p is inside of precomputed circumcircle of triangle.""" center, radius = self.circles[triangle] return np.sum(np.square(center - point)) <= radius def add_point(self, point: Sequence) -> None: - """ - Add a point to the current DT, and refine it using Bowyer-Watson. - """ + """Add a point to the current DT, and refine it using Bowyer-Watson.""" point_index = len(self.coords) self.coords.append(np.asarray(point)) @@ -121,9 +118,7 @@ def add_point(self, point: Sequence) -> None: self.triangles[triangle][2] = new_triangles[(i - 1) % n] # previous def export_triangles(self) -> list: - """ - Export the current list of Delaunay triangles - """ + """Export the current list of Delaunay triangles.""" triangles_list = [ (a - 4, b - 4, c - 4) for (a, b, c) in self.triangles @@ -132,9 +127,7 @@ def export_triangles(self) -> list: return triangles_list def export_voronoi_regions(self): - """ - Export coordinates and regions of Voronoi diagram as indexed data. - """ + """Export coordinates and regions of Voronoi diagram as indexed data.""" use_vertex = {i: [] for i in range(len(self.coords))} vor_coors = [] index = {} @@ -163,11 +156,13 @@ def export_voronoi_regions(self): return vor_coors, regions -def round_float(x: float) -> int: +def round_float(x: float) -> int: # noqa return int(x * 500) class VoronoiGrid(DiscreteSpace): + """Voronoi meshed GridSpace.""" + triangulation: Delaunay voronoi_coordinates: list regions: list @@ -181,8 +176,7 @@ def __init__( capacity_function: callable = round_float, cell_coloring_property: str | None = None, ) -> None: - """ - A Voronoi Tessellation Grid. + """A Voronoi Tessellation Grid. Given a set of points, this class creates a grid where a cell is centered in each point, its neighbors are given by Voronoi Tessellation cells neighbors @@ -192,7 +186,7 @@ def __init__( centroids_coordinates: coordinates of centroids to build the tessellation space capacity (int) : capacity of the cells in the discrete space random (Random): random number generator - CellKlass (type[Cell]): type of cell class + cell_klass (type[Cell]): type of cell class capacity_function (Callable): function to compute (int) capacity according to (float) area cell_coloring_property (str): voronoi visualization polygon fill property """ @@ -215,9 +209,7 @@ def __init__( self._build_cell_polygons() def _connect_cells(self) -> None: - """ - Connect cells to neighbors based on given centroids and using Delaunay Triangulation - """ + """Connect cells to neighbors based on given centroids and using Delaunay Triangulation.""" self.triangulation = Delaunay() for centroid in self.centroids_coordinates: self.triangulation.add_point(centroid) diff --git a/mesa/experimental/components/altair.py b/mesa/experimental/components/altair.py index f9d1a81c172..aaf2f2a1a4c 100644 --- a/mesa/experimental/components/altair.py +++ b/mesa/experimental/components/altair.py @@ -1,3 +1,5 @@ +"""Altair components.""" + import contextlib import solara @@ -8,6 +10,14 @@ @solara.component def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): + """A component that renders a Space using Altair. + + Args: + model: a model instance + agent_portrayal: agent portray specification + dependencies: optional list of dependencies (currently not used) + + """ space = getattr(model, "grid", None) if space is None: # Sometimes the space is defined as model.space instead of model.grid diff --git a/mesa/experimental/components/matplotlib.py b/mesa/experimental/components/matplotlib.py index c35fe97bc5e..bb3b9854193 100644 --- a/mesa/experimental/components/matplotlib.py +++ b/mesa/experimental/components/matplotlib.py @@ -1,3 +1,5 @@ +"""Support for using matplotlib to draw spaces.""" + from collections import defaultdict import networkx as nx @@ -11,6 +13,14 @@ @solara.component def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = None): + """A component for rendering a space using Matplotlib. + + Args: + model: a model instance + agent_portrayal: a specification of how to portray an agent. + dependencies: list of dependencies. + + """ space_fig = Figure() space_ax = space_fig.subplots() space = getattr(model, "grid", None) @@ -205,6 +215,14 @@ def portray(g): @solara.component def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): + """A solara component for creating a matplotlib figure. + + Args: + model: Model instance + measure: measure to plot + dependencies: list of additional dependencies + + """ fig = Figure() ax = fig.subplots() df = model.datacollector.get_model_vars_dataframe() diff --git a/mesa/experimental/devs/__init__.py b/mesa/experimental/devs/__init__.py index b6dca39e29c..fb4563b07ae 100644 --- a/mesa/experimental/devs/__init__.py +++ b/mesa/experimental/devs/__init__.py @@ -1,3 +1,5 @@ +"""Support for event scheduling.""" + from .eventlist import Priority, SimulationEvent from .simulator import ABMSimulator, DEVSimulator diff --git a/mesa/experimental/devs/eventlist.py b/mesa/experimental/devs/eventlist.py index abff889a0e6..3ecd182a349 100644 --- a/mesa/experimental/devs/eventlist.py +++ b/mesa/experimental/devs/eventlist.py @@ -1,3 +1,5 @@ +"""Eventlist which is at the core of event scheduling.""" + from __future__ import annotations import itertools @@ -10,15 +12,17 @@ class Priority(IntEnum): + """Enumeration of priority levels.""" + LOW = 10 DEFAULT = 5 HIGH = 1 class SimulationEvent: - """A simulation event + """A simulation event. - the callable is wrapped using weakref, so there is no need to explicitly cancel event if e.g., an agent + The callable is wrapped using weakref, so there is no need to explicitly cancel event if e.g., an agent is removed from the simulation. Attributes: @@ -34,7 +38,7 @@ class SimulationEvent: _ids = itertools.count() @property - def CANCELED(self) -> bool: + def CANCELED(self) -> bool: # noqa: D102 return self._canceled def __init__( @@ -45,6 +49,15 @@ def __init__( function_args: list[Any] | None = None, function_kwargs: dict[str, Any] | None = None, ) -> None: + """Initialize a simulation event. + + Args: + time: the instant of time of the simulation event + function: the callable to invoke + priority: the priority of the event + function_args: arguments for callable + function_kwargs: keyword arguments for the callable + """ super().__init__() if not callable(function): raise Exception() @@ -64,20 +77,20 @@ def __init__( self.function_kwargs = function_kwargs if function_kwargs else {} def execute(self): - """execute this event""" + """Execute this event.""" if not self._canceled: fn = self.fn() if fn is not None: fn(*self.function_args, **self.function_kwargs) def cancel(self) -> None: - """cancel this event""" + """Cancel this event.""" self._canceled = True self.fn = None self.function_args = [] self.function_kwargs = {} - def __lt__(self, other): + def __lt__(self, other): # noqa # Define a total ordering for events to be used by the heapq return (self.time, self.priority, self.unique_id) < ( other.time, @@ -87,30 +100,31 @@ def __lt__(self, other): class EventList: - """An event list + """An event list. This is a heap queue sorted list of events. Events are always removed from the left, so heapq is a performant and appropriate data structure. Events are sorted based on their time stamp, their priority, and their unique_id as a tie-breaker, guaranteeing a complete ordering. + """ def __init__(self): + """Initialize an event list.""" self._events: list[SimulationEvent] = [] heapify(self._events) def add_event(self, event: SimulationEvent): - """Add the event to the event list + """Add the event to the event list. Args: event (SimulationEvent): The event to be added """ - heappush(self._events, event) def peak_ahead(self, n: int = 1) -> list[SimulationEvent]: - """Look at the first n non-canceled event in the event list + """Look at the first n non-canceled event in the event list. Args: n (int): The number of events to look ahead @@ -139,7 +153,7 @@ def peak_ahead(self, n: int = 1) -> list[SimulationEvent]: return peek def pop_event(self) -> SimulationEvent: - """pop the first element from the event list""" + """Pop the first element from the event list.""" while self._events: event = heappop(self._events) if not event.CANCELED: @@ -147,16 +161,17 @@ def pop_event(self) -> SimulationEvent: raise IndexError("Event list is empty") def is_empty(self) -> bool: + """Return whether the event list is empty.""" return len(self) == 0 - def __contains__(self, event: SimulationEvent) -> bool: + def __contains__(self, event: SimulationEvent) -> bool: # noqa return event in self._events - def __len__(self) -> int: + def __len__(self) -> int: # noqa return len(self._events) def __repr__(self) -> str: - """Return a string representation of the event list""" + """Return a string representation of the event list.""" events_str = ", ".join( [ f"Event(time={e.time}, priority={e.priority}, id={e.unique_id})" @@ -167,7 +182,12 @@ def __repr__(self) -> str: return f"EventList([{events_str}])" def remove(self, event: SimulationEvent) -> None: - """remove an event from the event list""" + """Remove an event from the event list. + + Args: + event (SimulationEvent): The event to be removed + + """ # we cannot simply remove items from _eventlist because this breaks # heap structure invariant. So, we use a form of lazy deletion. # SimEvents have a CANCELED flag that we set to True, while popping and peak_ahead @@ -175,4 +195,5 @@ def remove(self, event: SimulationEvent) -> None: event.cancel() def clear(self): + """Clear the event list.""" self._events.clear() diff --git a/mesa/experimental/devs/examples/epstein_civil_violence.py b/mesa/experimental/devs/examples/epstein_civil_violence.py index c976bb73898..ce6b835e826 100644 --- a/mesa/experimental/devs/examples/epstein_civil_violence.py +++ b/mesa/experimental/devs/examples/epstein_civil_violence.py @@ -1,3 +1,5 @@ +"""Epstein civil violence example using ABMSimulator.""" + import enum import math @@ -7,21 +9,32 @@ class EpsteinAgent(Agent): + """Epstein Agent.""" + def __init__(self, model, vision, movement): + """Initialize the agent. + + Args: + model: a model instance + vision: size of neighborhood + movement: boolean whether agent can move or not + """ super().__init__(model) self.vision = vision self.movement = movement class AgentState(enum.IntEnum): + """Agent states.""" + QUIESCENT = enum.auto() ARRESTED = enum.auto() ACTIVE = enum.auto() class Citizen(EpsteinAgent): - """ - A member of the general population, may or may not be in active rebellion. + """A member of the general population, may or may not be in active rebellion. + Summary of rule: If grievance - risk > threshold, rebel. Attributes: @@ -55,10 +68,13 @@ def __init__( threshold, arrest_prob_constant, ): - """ - Create a new Citizen. + """Create a new Citizen. + Args: model : model instance + vision: number of cells in each direction (N, S, E and W) that + agent can inspect. Exogenous. + movement: whether agent can move or not hardship: Agent's 'perceived hardship (i.e., physical or economic privation).' Exogenous, drawn from U(0,1). regime_legitimacy: Agent's perception of regime legitimacy, equal @@ -66,8 +82,8 @@ def __init__( risk_aversion: Exogenous, drawn from U(0,1). threshold: if (grievance - (risk_aversion * arrest_probability)) > threshold, go/remain Active - vision: number of cells in each direction (N, S, E and W) that - agent can inspect. Exogenous. + arrest_prob_constant : agent's assessment of arrest probability + """ super().__init__(model, vision, movement) self.hardship = hardship @@ -80,9 +96,7 @@ def __init__( self.arrest_prob_constant = arrest_prob_constant def step(self): - """ - Decide whether to activate, then move if applicable. - """ + """Decide whether to activate, then move if applicable.""" self.update_neighbors() self.update_estimated_arrest_probability() net_risk = self.risk_aversion * self.arrest_probability @@ -95,9 +109,7 @@ def step(self): self.model.grid.move_agent(self, new_pos) def update_neighbors(self): - """ - Look around and see who my neighbors are - """ + """Look around and see who my neighbors are.""" self.neighborhood = self.model.grid.get_neighborhood( self.pos, moore=True, radius=self.vision ) @@ -107,10 +119,7 @@ def update_neighbors(self): ] def update_estimated_arrest_probability(self): - """ - Based on the ratio of cops to actives in my neighborhood, estimate the - p(Arrest | I go active). - """ + """Based on the ratio of cops to actives in my neighborhood, estimate the p(Arrest | I go active).""" cops_in_vision = len([c for c in self.neighbors if isinstance(c, Cop)]) actives_in_vision = 1.0 # citizen counts herself for c in self.neighbors: @@ -121,18 +130,25 @@ def update_estimated_arrest_probability(self): ) def sent_to_jail(self, value): + """Sent agent to jail. + + Args: + value: duration of jail sentence + + """ self.model.active_agents.remove(self) self.condition = AgentState.ARRESTED self.model.simulator.schedule_event_relative(self.release_from_jail, value) def release_from_jail(self): + """Release agent from jail.""" self.model.active_agents.add(self) self.condition = AgentState.QUIESCENT class Cop(EpsteinAgent): - """ - A cop for life. No defection. + """A cop for life. No defection. + Summary of rule: Inspect local vision and arrest a random active agent. Attributes: @@ -143,14 +159,19 @@ class Cop(EpsteinAgent): """ def __init__(self, model, vision, movement, max_jail_term): + """Initialize a Cop agent. + + Args: + model: a model instance + vision: size of neighborhood + movement: whether agent can move or not + max_jail_term: maximum jail sentence + """ super().__init__(model, vision, movement) self.max_jail_term = max_jail_term def step(self): - """ - Inspect local vision and arrest a random active agent. Move if - applicable. - """ + """Inspect local vision and arrest a random active agent. Move if applicable.""" self.update_neighbors() active_neighbors = [] for agent in self.neighbors: @@ -164,9 +185,7 @@ def step(self): self.model.grid.move_agent(self, new_pos) def update_neighbors(self): - """ - Look around and see who my neighbors are. - """ + """Look around and see who my neighbors are.""" self.neighborhood = self.model.grid.get_neighborhood( self.pos, moore=True, radius=self.vision ) @@ -177,9 +196,8 @@ def update_neighbors(self): class EpsteinCivilViolence(Model): - """ - Model 1 from "Modeling civil violence: An agent-based computational - approach," by Joshua Epstein. + """Model 1 from "Modeling civil violence: An agent-based computational approach," by Joshua Epstein. + http://www.pnas.org/content/99/suppl_3/7243.full Attributes: height: grid height @@ -218,6 +236,23 @@ def __init__( max_iters=1000, seed=None, ): + """Initialize the Eppstein civil violence model. + + Args: + width: the width of the grid + height: the height of the grid + citizen_density: density of citizens + cop_density: density of cops + citizen_vision: size of citizen vision + cop_vision: size of cop vision + legitimacy: perceived legitimacy + max_jail_term: maximum jail term + active_threshold: threshold for citizen to become active + arrest_prob_constant: arrest probability + movement: allow agent movement or not + max_iters: number of iterations + seed: seed for random number generator + """ super().__init__(seed) if cop_density + citizen_density > 1: raise ValueError("Cop density + citizen density must be less than 1") @@ -257,6 +292,7 @@ def __init__( self.active_agents = self.agents def step(self): + """Run one step of the model.""" self.active_agents.shuffle(inplace=True).do("step") diff --git a/mesa/experimental/devs/examples/wolf_sheep.py b/mesa/experimental/devs/examples/wolf_sheep.py index cd81cb2ff04..b76a35bf057 100644 --- a/mesa/experimental/devs/examples/wolf_sheep.py +++ b/mesa/experimental/devs/examples/wolf_sheep.py @@ -1,20 +1,22 @@ -""" -Wolf-Sheep Predation Model -================================ - -Replication of the model found in NetLogo: - Wilensky, U. (1997). NetLogo Wolf Sheep Predation model. - http://ccl.northwestern.edu/netlogo/models/WolfSheepPredation. - Center for Connected Learning and Computer-Based Modeling, - Northwestern University, Evanston, IL. -""" +"""Example of using ABM simulator for Wolf-Sheep Predation Model.""" import mesa from mesa.experimental.devs.simulator import ABMSimulator class Animal(mesa.Agent): + """Base Animal class.""" + def __init__(self, model, moore, energy, p_reproduce, energy_from_food): + """Initialize Animal instance. + + Args: + model: a model instance + moore: using moore grid or not + energy: initial energy + p_reproduce: probability of reproduction + energy_from_food: energy gained from 1 unit of food + """ super().__init__(model) self.energy = energy self.p_reproduce = p_reproduce @@ -22,12 +24,14 @@ def __init__(self, model, moore, energy, p_reproduce, energy_from_food): self.moore = moore def random_move(self): + """Move to random neighboring cell.""" next_moves = self.model.grid.get_neighborhood(self.pos, self.moore, True) next_move = self.random.choice(next_moves) # Now move: self.model.grid.move_agent(self, next_move) def spawn_offspring(self): + """Create offspring.""" self.energy /= 2 offspring = self.__class__( self.model, @@ -38,13 +42,15 @@ def spawn_offspring(self): ) self.model.grid.place_agent(offspring, self.pos) - def feed(self): ... + def feed(self): ... # noqa: D102 def die(self): + """Die.""" self.model.grid.remove_agent(self) self.remove() def step(self): + """Execute one step of the agent.""" self.random_move() self.energy -= 1 @@ -57,13 +63,10 @@ def step(self): class Sheep(Animal): - """ - A sheep that walks around, reproduces (asexually) and gets eaten. - - The init is the same as the RandomWalker. - """ + """A sheep that walks around, reproduces (asexually) and gets eaten.""" def feed(self): + """Eat grass and gain energy.""" # If there is grass available, eat it agents = self.model.grid.get_cell_list_contents(self.pos) grass_patch = next(obj for obj in agents if isinstance(obj, GrassPatch)) @@ -73,11 +76,10 @@ def feed(self): class Wolf(Animal): - """ - A wolf that walks around, reproduces (asexually) and eats sheep. - """ + """A wolf that walks around, reproduces (asexually) and eats sheep.""" def feed(self): + """Eat wolf and gain energy.""" agents = self.model.grid.get_cell_list_contents(self.pos) sheep = [obj for obj in agents if isinstance(obj, Sheep)] if len(sheep) > 0: @@ -89,12 +91,10 @@ def feed(self): class GrassPatch(mesa.Agent): - """ - A patch of grass that grows at a fixed rate and it is eaten by sheep - """ + """A patch of grass that grows at a fixed rate and it is eaten by sheep.""" @property - def fully_grown(self) -> bool: + def fully_grown(self) -> bool: # noqa: D102 return self._fully_grown @fully_grown.setter @@ -109,12 +109,13 @@ def fully_grown(self, value: bool): ) def __init__(self, model, fully_grown, countdown, grass_regrowth_time): - """ - Creates a new patch of grass + """Creates a new patch of grass. Args: - grown: (boolean) Whether the patch of grass is fully grown or not + model: a model instance + fully_grown: (boolean) Whether the patch of grass is fully grown or not countdown: Time for the patch of grass to be fully grown again + grass_regrowth_time: regrowth time for the grass """ super().__init__(model) self._fully_grown = fully_grown @@ -125,13 +126,12 @@ def __init__(self, model, fully_grown, countdown, grass_regrowth_time): setattr, countdown, function_args=[self, "fully_grown", True] ) - def set_fully_grown(self): + def set_fully_grown(self): # noqa self.fully_grown = True class WolfSheep(mesa.Model): - """ - Wolf-Sheep Predation Model + """Wolf-Sheep Predation Model. A model for simulating wolf and sheep (predator-prey) ecosystem modelling. """ @@ -151,10 +151,11 @@ def __init__( simulator=None, seed=None, ): - """ - Create a new Wolf-Sheep model with the given parameters. + """Create a new Wolf-Sheep model with the given parameters. Args: + height: height of the grid + width: width of the grid initial_sheep: Number of sheep to start with initial_wolves: Number of wolves to start with sheep_reproduce: Probability of each sheep reproducing each step @@ -164,7 +165,9 @@ def __init__( grass_regrowth_time: How long it takes for a grass patch to regrow once it is eaten sheep_gain_from_food: Energy sheep gain from grass, if enabled. - moore: + moore: whether to use moore or von Neumann grid + simulator: Simulator to use for simulating wolf and sheep + seed: Random seed """ super().__init__(seed=seed) # Set parameters @@ -226,6 +229,7 @@ def __init__( self.grid.place_agent(patch, pos) def step(self): + """Perform one step of the model.""" self.agents_by_type[Sheep].shuffle(inplace=True).do("step") self.agents_by_type[Wolf].shuffle(inplace=True).do("step") diff --git a/mesa/experimental/devs/simulator.py b/mesa/experimental/devs/simulator.py index 74f018e883d..8967c19ef8e 100644 --- a/mesa/experimental/devs/simulator.py +++ b/mesa/experimental/devs/simulator.py @@ -1,3 +1,10 @@ +"""Provides several simulator classes. + +A Simulator is responsible for executing a simulation model. It controls time advancement and enables event scheduling. + + +""" + from __future__ import annotations import numbers @@ -27,6 +34,12 @@ class Simulator: # TODO: add experimentation support def __init__(self, time_unit: type, start_time: int | float): + """Initialize a Simulator instance. + + Args: + time_unit: type of the smulaiton time + start_time: the starttime of the simulator + """ # should model run in a separate thread, # and we can then interact with start, stop, run_until, and step? self.event_list = EventList() @@ -36,10 +49,10 @@ def __init__(self, time_unit: type, start_time: int | float): self.time = self.start_time self.model = None - def check_time_unit(self, time: int | float) -> bool: ... + def check_time_unit(self, time: int | float) -> bool: ... # noqa: D102 def setup(self, model: Model) -> None: - """Set up the simulator with the model to simulate + """Set up the simulator with the model to simulate. Args: model (Model): The model to simulate @@ -49,12 +62,13 @@ def setup(self, model: Model) -> None: self.model = model def reset(self): - """Reset the simulator by clearing the event list and removing the model to simulate""" + """Reset the simulator by clearing the event list and removing the model to simulate.""" self.event_list.clear() self.model = None self.time = self.start_time def run_until(self, end_time: int | float) -> None: + """Run the simulator until the end time.""" while True: try: event = self.event_list.pop_event() @@ -71,7 +85,7 @@ def run_until(self, end_time: int | float) -> None: break def run_for(self, time_delta: int | float): - """run the simulator for the specified time delta + """Run the simulator for the specified time delta. Args: time_delta (float| int): The time delta. The simulator is run from the current time to the current time @@ -88,7 +102,7 @@ def schedule_event_now( function_args: list[Any] | None = None, function_kwargs: dict[str, Any] | None = None, ) -> SimulationEvent: - """Schedule event for the current time instant + """Schedule event for the current time instant. Args: function (Callable): The callable to execute for this event @@ -116,7 +130,7 @@ def schedule_event_absolute( function_args: list[Any] | None = None, function_kwargs: dict[str, Any] | None = None, ) -> SimulationEvent: - """Schedule event for the specified time instant + """Schedule event for the specified time instant. Args: function (Callable): The callable to execute for this event @@ -150,7 +164,7 @@ def schedule_event_relative( function_args: list[Any] | None = None, function_kwargs: dict[str, Any] | None = None, ) -> SimulationEvent: - """Schedule event for the current time plus the time delta + """Schedule event for the current time plus the time delta. Args: function (Callable): The callable to execute for this event @@ -174,13 +188,12 @@ def schedule_event_relative( return event def cancel_event(self, event: SimulationEvent) -> None: - """remove the event from the event list + """Remove the event from the event list. Args: event (SimulationEvent): The simulation event to remove """ - self.event_list.remove(event) def _schedule_event(self, event: SimulationEvent): @@ -204,13 +217,29 @@ class ABMSimulator(Simulator): """ def __init__(self): + """Initialize a ABM simulator.""" super().__init__(int, 0) def setup(self, model): + """Set up the simulator with the model to simulate. + + Args: + model (Model): The model to simulate + + """ super().setup(model) self.schedule_event_now(self.model.step, priority=Priority.HIGH) def check_time_unit(self, time) -> bool: + """Check whether the time is of the correct unit. + + Args: + time (int | float): the time + + Returns: + bool: whether the time is of the correct unit + + """ if isinstance(time, int): return True if isinstance(time, float): @@ -225,9 +254,9 @@ def schedule_event_next_tick( function_args: list[Any] | None = None, function_kwargs: dict[str, Any] | None = None, ) -> SimulationEvent: - """Schedule a SimulationEvent for the next tick + """Schedule a SimulationEvent for the next tick. - Args + Args: function (Callable): the callable to execute priority (Priority): the priority of the event function_args (List[Any]): List of arguments to pass to the callable @@ -243,7 +272,7 @@ def schedule_event_next_tick( ) def run_until(self, end_time: int) -> None: - """run the simulator up to and included the specified end time + """Run the simulator up to and included the specified end time. Args: end_time (float| int): The end_time delta. The simulator is until the specified end time @@ -270,7 +299,7 @@ def run_until(self, end_time: int) -> None: break def run_for(self, time_delta: int): - """run the simulator for the specified time delta + """Run the simulator for the specified time delta. Args: time_delta (float| int): The time delta. The simulator is run from the current time to the current time @@ -282,13 +311,24 @@ def run_for(self, time_delta: int): class DEVSimulator(Simulator): - """A simulator where the unit of time is a float. Can be used for full-blown discrete event simulating using - event scheduling. + """A simulator where the unit of time is a float. + + Can be used for full-blown discrete event simulating using event scheduling. """ def __init__(self): + """Initialize a DEVS simulator.""" super().__init__(float, 0.0) def check_time_unit(self, time) -> bool: + """Check whether the time is of the correct unit. + + Args: + time (float): the time + + Returns: + bool: whether the time is of the correct unit + + """ return isinstance(time, numbers.Number) diff --git a/mesa/experimental/solara_viz.py b/mesa/experimental/solara_viz.py index 28f3a3c66b2..135563bb2db 100644 --- a/mesa/experimental/solara_viz.py +++ b/mesa/experimental/solara_viz.py @@ -1,5 +1,4 @@ -""" -Mesa visualization module for creating interactive model visualizations. +"""Mesa visualization module for creating interactive model visualizations. This module provides components to create browser- and Jupyter notebook-based visualizations of Mesa models, allowing users to watch models run step-by-step and interact with model parameters. @@ -39,8 +38,7 @@ def Card( model, measures, agent_portrayal, space_drawer, dependencies, color, layout_type ): - """ - Create a card component for visualizing model space or measures. + """Create a card component for visualizing model space or measures. Args: model: The Mesa model instance @@ -95,8 +93,7 @@ def SolaraViz( play_interval=150, seed=None, ): - """ - Initialize a component to visualize a model. + """Initialize a component to visualize a model. Args: model_class: Class of the model to instantiate @@ -212,8 +209,7 @@ def do_reseed(): @solara.component def ModelController(model, play_interval, current_step, reset_counter): - """ - Create controls for model execution (step, play, pause, reset). + """Create controls for model execution (step, play, pause, reset). Args: model: The model being visualized @@ -315,8 +311,7 @@ def do_set_playing(value): def split_model_params(model_params): - """ - Split model parameters into user-adjustable and fixed parameters. + """Split model parameters into user-adjustable and fixed parameters. Args: model_params: Dictionary of all model parameters @@ -335,8 +330,7 @@ def split_model_params(model_params): def check_param_is_fixed(param): - """ - Check if a parameter is fixed (not user-adjustable). + """Check if a parameter is fixed (not user-adjustable). Args: param: Parameter to check @@ -354,8 +348,8 @@ def check_param_is_fixed(param): @solara.component def UserInputs(user_params, on_change=None): - """ - Initialize user inputs for configurable model parameters. + """Initialize user inputs for configurable model parameters. + Currently supports :class:`solara.SliderInt`, :class:`solara.SliderFloat`, :class:`solara.Select`, and :class:`solara.Checkbox`. @@ -364,7 +358,6 @@ def UserInputs(user_params, on_change=None): min and max values, and other fields specific to the input type. on_change: Function to be called with (name, value) when the value of an input changes. """ - for name, options in user_params.items(): def change_handler(value, name=name): @@ -423,8 +416,7 @@ def change_handler(value, name=name): def make_text(renderer): - """ - Create a function that renders text using Markdown. + """Create a function that renders text using Markdown. Args: renderer: Function that takes a model and returns a string @@ -440,8 +432,7 @@ def function(model): def make_initial_grid_layout(layout_types): - """ - Create an initial grid layout for visualization components. + """Create an initial grid layout for visualization components. Args: layout_types: List of layout types (Space or Measure) diff --git a/mesa/main.py b/mesa/main.py index a6bcad16f2c..1d65500ee4f 100644 --- a/mesa/main.py +++ b/mesa/main.py @@ -1,3 +1,5 @@ +"""main module for running mesa models with a server.""" + import os import sys from pathlib import Path @@ -19,13 +21,13 @@ @click.group(context_settings=CONTEXT_SETTINGS) def cli(): - "Manage Mesa projects" + """Manage Mesa projects.""" @cli.command() @click.argument("project", type=PROJECT_PATH, default=".") def runserver(project): - """Run mesa project PROJECT + """Run mesa project PROJECT. PROJECT is the path to the directory containing `run.py`, or the current directory if not specified. @@ -45,7 +47,7 @@ def runserver(project): "--no-input", is_flag=True, help="Do not prompt user for custom mesa model input." ) def startproject(no_input): - """Create a new mesa project""" + """Create a new mesa project.""" args = ["cookiecutter", COOKIECUTTER_PATH] if no_input: args.append("--no-input") @@ -54,7 +56,7 @@ def startproject(no_input): @click.command() def version(): - """Show the version of mesa""" + """Show the version of mesa.""" print(f"mesa {__version__}") diff --git a/mesa/model.py b/mesa/model.py index 836f7445ca7..1302c60e482 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -1,5 +1,4 @@ -""" -The model class for Mesa framework. +"""The model class for Mesa framework. Core Objects: Model """ @@ -28,24 +27,8 @@ class Model: Attributes: running: A boolean indicating if the model should continue running. schedule: An object to manage the order and execution of agent steps. - - Properties: - agents: An AgentSet containing all agents in the model - agent_types: A list of different agent types present in the model. - agents_by_type: A dictionary where the keys are agent types and the values are the corresponding AgentSets. - steps: An integer representing the number of steps the model has taken. - It increases automatically at the start of each step() call. - - Methods: - get_agents_of_type: Returns an AgentSet of agents of the specified type. - Deprecated: Use agents_by_type[agenttype] instead. - run_model: Runs the model's simulation until a defined end condition is reached. - step: Executes a single step of the model's simulation process. - next_id: Generates and returns the next unique identifier for an agent. - reset_randomizer: Resets the model's random number generator with a new or existing seed. - initialize_data_collector: Sets up the data collector for the model, requiring an initialized scheduler and agents. - register_agent : register an agent with the model - deregister_agent : remove an agent from the model + steps: the number of times `model.step()` has been called. + random: a seeded random number generator. Notes: Model.agents returns the AgentSet containing all agents registered with the model. Changing @@ -55,9 +38,15 @@ class Model: """ def __init__(self, *args: Any, seed: float | None = None, **kwargs: Any) -> None: - """Create a new model. Overload this method with the actual code to - start the model. Always start with super().__init__() to initialize the - model object properly. + """Create a new model. + + Overload this method with the actual code to initialize the model. Always start with super().__init__() + to initialize the model object properly. + + Args: + args: arguments to pass onto super + seed: the seed for the random number generator + kwargs: keyword arguments to pass onto super """ self.running = True self.schedule = None @@ -83,7 +72,7 @@ def _wrapped_step(self, *args: Any, **kwargs: Any) -> None: # Call the original user-defined step method self._user_step(*args, **kwargs) - def next_id(self) -> int: + def next_id(self) -> int: # noqa: D102 warnings.warn( "using model.next_id() is deprecated. Agents track their unique ID automatically", DeprecationWarning, @@ -125,7 +114,7 @@ def get_agents_of_type(self, agenttype: type[Agent]) -> AgentSet: return self.agents_by_type[agenttype] def _setup_agent_registration(self): - """helper method to initialize the agent registration datastructures""" + """Helper method to initialize the agent registration datastructures.""" self._agents = {} # the hard references to all agents in the model self._agents_by_type: dict[ type[Agent], AgentSet @@ -133,7 +122,7 @@ def _setup_agent_registration(self): self._all_agents = AgentSet([], self) # an agenset with all agents def register_agent(self, agent): - """Register the agent with the model + """Register the agent with the model. Args: agent: The agent to register. @@ -170,10 +159,13 @@ def register_agent(self, agent): self._all_agents.add(agent) def deregister_agent(self, agent): - """Deregister the agent with the model + """Deregister the agent with the model. + + Args: + agent: The agent to deregister. - Notes:: - This method is called automatically by ``Agent.remove`` + Notes: + This method is called automatically by ``Agent.remove`` """ del self._agents[agent] @@ -181,8 +173,9 @@ def deregister_agent(self, agent): self._all_agents.remove(agent) def run_model(self) -> None: - """Run the model until the end condition is reached. Overload as - needed. + """Run the model until the end condition is reached. + + Overload as needed. """ while self.running: self.step() @@ -196,7 +189,6 @@ def reset_randomizer(self, seed: int | None = None) -> None: Args: seed: A new seed for the RNG; if None, reset using the current seed """ - if seed is None: seed = self._seed self.random.seed(seed) @@ -208,6 +200,14 @@ def initialize_data_collector( agent_reporters=None, tables=None, ) -> None: + """Initialize the data collector for the model. + + Args: + model_reporters: model reporters to collect + agent_reporters: agent reporters to collect + tables: tables to collect + + """ if not hasattr(self, "schedule") or self.schedule is None: raise RuntimeError( "You must initialize the scheduler (self.schedule) before initializing the data collector." diff --git a/mesa/space.py b/mesa/space.py index c4d855d8099..2daa3d0ff90 100644 --- a/mesa/space.py +++ b/mesa/space.py @@ -1,6 +1,4 @@ -""" -Mesa Space Module -================= +"""Mesa Space Module. Objects used to add a spatial component to a model. @@ -54,9 +52,10 @@ def accept_tuple_argument(wrapped_function: F) -> F: - """Decorator to allow grid methods that take a list of (x, y) coord tuples - to also handle a single position, by automatically wrapping tuple in - single-item list rather than forcing user to do it.""" + """Decorator to allow grid methods that take a list of (x, y) coord tuples to also handle a single position. + + Tuples are wrapped in a single-item list rather than forcing user to do it. + """ def wrapper(grid_instance, positions) -> Any: if len(positions) == 2 and not isinstance(positions[0], tuple): @@ -67,11 +66,13 @@ def wrapper(grid_instance, positions) -> Any: def is_integer(x: Real) -> bool: - # Check if x is either a CPython integer or Numpy integer. + """Check if x is either a CPython integer or Numpy integer.""" return isinstance(x, _types_integer) def warn_if_agent_has_position_already(placement_func): + """Decorator to give warning if agent has position already set.""" + def wrapper(self, agent, *args, **kwargs): if agent.pos is not None: warnings.warn( @@ -102,7 +103,8 @@ def __init__(self, width: int, height: int, torus: bool) -> None: """Create a new grid. Args: - width, height: The width and height of the grid + width: The grid's width. + height: The grid's height. torus: Boolean whether the grid wraps or not. """ self.height = height @@ -159,7 +161,6 @@ def __getitem__( def __getitem__(self, index): """Access contents from the grid.""" - if isinstance(index, int): # grid[x] return self._grid[index] @@ -192,8 +193,7 @@ def __getitem__(self, index): return [cell for rows in self._grid[x] for cell in rows[y]] def __iter__(self) -> Iterator[GridContent]: - """Create an iterator that chains the rows of the grid together - as if it is one list:""" + """Create an iterator that chains the rows of the grid together as if it is one list.""" return itertools.chain(*self._grid) def coord_iter(self) -> Iterator[tuple[GridContent, Coordinate]]: @@ -209,8 +209,7 @@ def iter_neighborhood( include_center: bool = False, radius: int = 1, ) -> Iterator[Coordinate]: - """Return an iterator over cell coordinates that are in the - neighborhood of a certain point. + """Return an iterator over cell coordinates that are in the neighborhood of a certain point. Args: pos: Coordinate tuple for the neighborhood to get. @@ -237,8 +236,7 @@ def get_neighborhood( include_center: bool = False, radius: int = 1, ) -> Sequence[Coordinate]: - """Return a list of cells that are in the neighborhood of a - certain point. + """Return a list of cells that are in the neighborhood of a certain point. Args: pos: Coordinate tuple for the neighborhood to get. @@ -381,8 +379,7 @@ def torus_adj(self, pos: Coordinate) -> Coordinate: return pos[0] % self.width, pos[1] % self.height def out_of_bounds(self, pos: Coordinate) -> bool: - """Determines whether position is off the grid, returns the out of - bounds coordinate.""" + """Determines whether position is off the grid, returns the out of bounds coordinate.""" x, y = pos return x < 0 or x >= self.width or y < 0 or y >= self.height @@ -390,8 +387,7 @@ def out_of_bounds(self, pos: Coordinate) -> bool: def iter_cell_list_contents( self, cell_list: Iterable[Coordinate] ) -> Iterator[Agent]: - """Returns an iterator of the agents contained in the cells identified - in `cell_list`; cells with empty content are excluded. + """Returns an iterator of the agents contained in the cells identified in `cell_list`; cells with empty content are excluded. Args: cell_list: Array-like of (x, y) tuples, or single tuple. @@ -407,8 +403,7 @@ def iter_cell_list_contents( @accept_tuple_argument def get_cell_list_contents(self, cell_list: Iterable[Coordinate]) -> list[Agent]: - """Returns an iterator of the agents contained in the cells identified - in `cell_list`; cells with empty content are excluded. + """Returns an iterator of the agents contained in the cells identified in `cell_list`; cells with empty content are excluded. Args: cell_list: Array-like of (x, y) tuples, or single tuple. @@ -441,8 +436,7 @@ def move_agent_to_one_of( selection: str = "random", handle_empty: str | None = None, ) -> None: - """ - Move an agent to one of the given positions. + """Move an agent to one of the given positions. Args: agent: Agent object to move. Assumed to have its current location stored in a 'pos' tuple. @@ -494,9 +488,7 @@ def move_agent_to_one_of( ) def _distance_squared(self, pos1: Coordinate, pos2: Coordinate) -> float: - """ - Calculate the squared Euclidean distance between two points for performance. - """ + """Calculate the squared Euclidean distance between two points for performance.""" # Use squared Euclidean distance to avoid sqrt operation dx, dy = abs(pos1[0] - pos2[0]), abs(pos1[1] - pos2[1]) if self.torus: @@ -505,7 +497,7 @@ def _distance_squared(self, pos1: Coordinate, pos2: Coordinate) -> float: return dx**2 + dy**2 def swap_pos(self, agent_a: Agent, agent_b: Agent) -> None: - """Swap agents positions""" + """Swap agents positions.""" agents_no_pos = [] if (pos_a := agent_a.pos) is None: agents_no_pos.append(agent_a) @@ -566,16 +558,16 @@ def is_single_argument_function(function): ) -def ufunc_requires_additional_input(ufunc): +def ufunc_requires_additional_input(ufunc): # noqa: D103 # NumPy ufuncs have a 'nargs' attribute indicating the number of input arguments # For binary ufuncs (like np.add), nargs is 2 return ufunc.nargs > 1 class PropertyLayer: - """ - A class representing a layer of properties in a two-dimensional grid. Each cell in the grid - can store a value of a specified data type. + """A class representing a layer of properties in a two-dimensional grid. + + Each cell in the grid can store a value of a specified data type. Attributes: name (str): The name of the property layer. @@ -583,13 +575,6 @@ class PropertyLayer: height (int): The height of the grid (number of rows). data (numpy.ndarray): A NumPy array representing the grid data. - Methods: - set_cell(position, value): Sets the value of a single cell. - set_cells(value, condition=None): Sets the values of multiple cells, optionally based on a condition. - modify_cell(position, operation, value): Modifies the value of a single cell using an operation. - modify_cells(operation, value, condition_function): Modifies the values of multiple cells using an operation. - select_cells(condition, return_list): Selects cells that meet a specified condition. - aggregate_property(operation): Performs an aggregate operation over all cells. """ propertylayer_experimental_warning_given = False @@ -597,8 +582,7 @@ class PropertyLayer: def __init__( self, name: str, width: int, height: int, default_value, dtype=np.float64 ): - """ - Initializes a new PropertyLayer instance. + """Initializes a new PropertyLayer instance. Args: name (str): The name of the property layer. @@ -649,14 +633,11 @@ def __init__( self.__class__.propertylayer_experimental_warning_given = True def set_cell(self, position: Coordinate, value): - """ - Update a single cell's value in-place. - """ + """Update a single cell's value in-place.""" self.data[position] = value def set_cells(self, value, condition=None): - """ - Perform a batch update either on the entire grid or conditionally, in-place. + """Perform a batch update either on the entire grid or conditionally, in-place. Args: value: The value to be used for the update. @@ -685,8 +666,8 @@ def set_cells(self, value, condition=None): np.copyto(self.data, value, where=condition_result) def modify_cell(self, position: Coordinate, operation, value=None): - """ - Modify a single cell using an operation, which can be a lambda function or a NumPy ufunc. + """Modify a single cell using an operation, which can be a lambda function or a NumPy ufunc. + If a NumPy ufunc is used, an additional value should be provided. Args: @@ -707,8 +688,8 @@ def modify_cell(self, position: Coordinate, operation, value=None): raise ValueError("Invalid operation or missing value for NumPy ufunc.") def modify_cells(self, operation, value=None, condition_function=None): - """ - Modify cells using an operation, which can be a lambda function or a NumPy ufunc. + """Modify cells using an operation, which can be a lambda function or a NumPy ufunc. + If a NumPy ufunc is used, an additional value should be provided. Args: @@ -742,8 +723,7 @@ def modify_cells(self, operation, value=None, condition_function=None): self.data = np.where(condition_array, modified_data, self.data) def select_cells(self, condition, return_list=True): - """ - Find cells that meet a specified condition using NumPy's boolean indexing, in-place. + """Find cells that meet a specified condition using NumPy's boolean indexing, in-place. Args: condition: A callable that returns a boolean array when applied to the data. @@ -768,9 +748,9 @@ def aggregate_property(self, operation): class _PropertyGrid(_Grid): - """ - A private subclass of _Grid that supports the addition of property layers, enabling - the representation and manipulation of additional data layers on the grid. This class is + """A private subclass of _Grid that supports the addition of property layers. + + This enables the representation and manipulation of additional data layers on the grid. This class is intended for internal use within the Mesa framework and is currently utilized by SingleGrid and MultiGrid classes to provide enhanced grid functionality. @@ -783,22 +763,15 @@ class _PropertyGrid(_Grid): properties (dict): A dictionary mapping property layer names to PropertyLayer instances. empty_mask (np.ndarray): A boolean array indicating empty cells on the grid. - Methods: - add_property_layer(property_layer): Adds a new property layer to the grid. - remove_property_layer(property_name): Removes a property layer from the grid by its name. - get_neighborhood_mask(pos, moore, include_center, radius): Generates a boolean mask of the neighborhood. - select_cells(conditions, extreme_values, masks, only_empty, return_list): Selects cells based on multiple conditions, - extreme values, masks, with an option to select only empty cells, returning either a list of coordinates or a mask. - - Mask Usage: - Several methods in this class accept a mask as an input, which is a NumPy ndarray of boolean values. This mask - specifies the cells to be considered (True) or ignored (False) in operations. Users can create custom masks, - including neighborhood masks, to apply specific conditions or constraints. Additionally, methods that deal with - cell selection or agent movement can return either a list of cell coordinates or a mask, based on the 'return_list' - parameter. This flexibility allows for more nuanced control and customization of grid operations, catering to a wide - range of modeling requirements and scenarios. - - Note: + + Several methods in this class accept a mask as an input, which is a NumPy ndarray of boolean values. This mask + specifies the cells to be considered (True) or ignored (False) in operations. Users can create custom masks, + including neighborhood masks, to apply specific conditions or constraints. Additionally, methods that deal with + cell selection or agent movement can return either a list of cell coordinates or a mask, based on the 'return_list' + parameter. This flexibility allows for more nuanced control and customization of grid operations, catering to a wide + range of modeling requirements and scenarios. + + Notes: This class is not intended for direct use in user models but is currently used by the SingleGrid and MultiGrid. """ @@ -809,8 +782,7 @@ def __init__( torus: bool, property_layers: None | PropertyLayer | list[PropertyLayer] = None, ): - """ - Initializes a new _PropertyGrid instance with specified dimensions and optional property layers. + """Initializes a new _PropertyGrid instance with specified dimensions and optional property layers. Args: width (int): The width of the grid (number of columns). @@ -839,15 +811,12 @@ def __init__( @property def empty_mask(self) -> np.ndarray: - """ - Returns a boolean mask indicating empty cells on the grid. - """ + """Returns a boolean mask indicating empty cells on the grid.""" return self._empty_mask # Add and remove properties to the grid def add_property_layer(self, property_layer: PropertyLayer): - """ - Adds a new property layer to the grid. + """Adds a new property layer to the grid. Args: property_layer (PropertyLayer): The PropertyLayer instance to be added to the grid. @@ -865,8 +834,7 @@ def add_property_layer(self, property_layer: PropertyLayer): self.properties[property_layer.name] = property_layer def remove_property_layer(self, property_name: str): - """ - Removes a property layer from the grid by its name. + """Removes a property layer from the grid by its name. Args: property_name (str): The name of the property layer to be removed. @@ -881,8 +849,8 @@ def remove_property_layer(self, property_name: str): def get_neighborhood_mask( self, pos: Coordinate, moore: bool, include_center: bool, radius: int ) -> np.ndarray: - """ - Generate a boolean mask representing the neighborhood. + """Generate a boolean mask representing the neighborhood. + Helper method for select_cells_multi_properties() and move_agent_to_random_cell() Args: @@ -910,8 +878,7 @@ def select_cells( only_empty: bool = False, return_list: bool = True, ) -> list[Coordinate] | np.ndarray: - """ - Select cells based on property conditions, extreme values, and/or masks, with an option to only select empty cells. + """Select cells based on property conditions, extreme values, and/or masks, with an option to only select empty cells. Args: conditions (dict): A dictionary where keys are property names and values are callables that return a boolean when applied. @@ -1064,7 +1031,7 @@ def remove_agent(self, agent: Agent) -> None: self._empty_mask[agent.pos] = False agent.pos = None - def iter_neighbors( + def iter_neighbors( # noqa: D102 self, pos: Coordinate, moore: bool, @@ -1079,8 +1046,9 @@ def iter_neighbors( def iter_cell_list_contents( self, cell_list: Iterable[Coordinate] ) -> Iterator[Agent]: - """Returns an iterator of the agents contained in the cells identified - in `cell_list`; cells with empty content are excluded. + """Returns an iterator of the agents contained in the cells identified in `cell_list`. + + Cells with empty content are excluded. Args: cell_list: Array-like of (x, y) tuples, or single tuple. @@ -1118,9 +1086,9 @@ def torus_adj_2d(self, pos: Coordinate) -> Coordinate: def get_neighborhood( self, pos: Coordinate, include_center: bool = False, radius: int = 1 ) -> list[Coordinate]: - """Return a list of coordinates that are in the - neighborhood of a certain point. To calculate the neighborhood - for a HexGrid the parity of the x coordinate of the point is + """Return a list of coordinates that are in the neighborhood of a certain point. + + To calculate the neighborhood for a HexGrid the parity of the x coordinate of the point is important, the neighborhood can be sketched as: Always: (0,-), (0,+) @@ -1206,8 +1174,7 @@ def get_neighborhood( def iter_neighborhood( self, pos: Coordinate, include_center: bool = False, radius: int = 1 ) -> Iterator[Coordinate]: - """Return an iterator over cell coordinates that are in the - neighborhood of a certain point. + """Return an iterator over cell coordinates that are in the neighborhood of a certain point. Args: pos: Coordinate tuple for the neighborhood to get. @@ -1257,8 +1224,7 @@ def get_neighbors( class HexSingleGrid(_HexGrid, SingleGrid): - """Hexagonal SingleGrid: a SingleGrid where neighbors are computed - according to a hexagonal tiling of the grid. + """Hexagonal SingleGrid: a SingleGrid where neighbors are computed according to a hexagonal tiling of the grid. Functions according to odd-q rules. See http://www.redblobgames.com/grids/hexagons/#coordinates for more. @@ -1274,8 +1240,7 @@ class HexSingleGrid(_HexGrid, SingleGrid): class HexMultiGrid(_HexGrid, MultiGrid): - """Hexagonal MultiGrid: a MultiGrid where neighbors are computed - according to a hexagonal tiling of the grid. + """Hexagonal MultiGrid: a MultiGrid where neighbors are computed according to a hexagonal tiling of the grid. Functions according to odd-q rules. See http://www.redblobgames.com/grids/hexagons/#coordinates for more. @@ -1293,8 +1258,7 @@ class HexMultiGrid(_HexGrid, MultiGrid): class HexGrid(HexSingleGrid): - """Hexagonal Grid: a Grid where neighbors are computed - according to a hexagonal tiling of the grid. + """Hexagonal Grid: a Grid where neighbors are computed according to a hexagonal tiling of the grid. Functions according to odd-q rules. See http://www.redblobgames.com/grids/hexagons/#coordinates for more. @@ -1305,6 +1269,13 @@ class HexGrid(HexSingleGrid): """ def __init__(self, width: int, height: int, torus: bool) -> None: + """Initializes a HexGrid, deprecated. + + Args: + width: the width of the grid + height: the height of the grid + torus: whether the grid wraps + """ super().__init__(width, height, torus) warn( ( @@ -1341,11 +1312,13 @@ def __init__( """Create a new continuous space. Args: - x_max, y_max: Maximum x and y coordinates for the space. + x_max: the maximum x-coordinate + y_max: the maximum y-coordinate. torus: Boolean for whether the edges loop around. - x_min, y_min: (default 0) If provided, set the minimum x and y - coordinates for the space. Below them, values loop to - the other edge (if torus=True) or raise an exception. + x_min: (default 0) If provided, set the minimum x -coordinate for the space. Below them, values loop to + the other edge (if torus=True) or raise an exception. + y_min: (default 0) If provided, set the minimum y -coordinate for the space. Below them, values loop to + the other edge (if torus=True) or raise an exception. """ self.x_min = x_min self.x_max = x_max @@ -1448,11 +1421,13 @@ def get_heading( self, pos_1: FloatCoordinate, pos_2: FloatCoordinate ) -> FloatCoordinate: """Get the heading vector between two points, accounting for toroidal space. + It is possible to calculate the heading angle by applying the atan2 function to the result. Args: - pos_1, pos_2: Coordinate tuples for both points. + pos_1: Coordinate tuples for both points. + pos_2: Coordinate tuples for both points. """ one = np.array(pos_1) two = np.array(pos_2) @@ -1478,7 +1453,8 @@ def get_distance(self, pos_1: FloatCoordinate, pos_2: FloatCoordinate) -> float: """Get the distance between two point, accounting for toroidal space. Args: - pos_1, pos_2: Coordinate tuples for both points. + pos_1: Coordinate tuples for point1. + pos_2: Coordinate tuples for point2. """ x1, y1 = pos_1 x2, y2 = pos_2 @@ -1525,7 +1501,7 @@ def __init__(self, g: Any) -> None: """Create a new network. Args: - G: a NetworkX graph instance. + g: a NetworkX graph instance. """ self.G = g for node_id in self.G.nodes: @@ -1545,7 +1521,16 @@ def place_agent(self, agent: Agent, node_id: int) -> None: def get_neighborhood( self, node_id: int, include_center: bool = False, radius: int = 1 ) -> list[int]: - """Get all adjacent nodes within a certain radius""" + """Get all adjacent nodes within a certain radius. + + Args: + node_id: node id for which to get neighborhood + include_center: boolean to include node itself or not + radius: size of neighborhood + + Returns: + a list + """ if radius == 1: neighborhood = list(self.G.neighbors(node_id)) if include_center: @@ -1562,28 +1547,61 @@ def get_neighborhood( def get_neighbors( self, node_id: int, include_center: bool = False, radius: int = 1 ) -> list[Agent]: - """Get all agents in adjacent nodes (within a certain radius).""" + """Get all agents in adjacent nodes (within a certain radius). + + Args: + node_id: node id for which to get neighbors + include_center: whether to include node itself or not + radius: size of neighborhood in which to find neighbors + + Returns: + list of agents in neighborhood. + """ neighborhood = self.get_neighborhood(node_id, include_center, radius) return self.get_cell_list_contents(neighborhood) def move_agent(self, agent: Agent, node_id: int) -> None: - """Move an agent from its current node to a new node.""" + """Move an agent from its current node to a new node. + + Args: + agent: agent instance + node_id: id of node + + """ self.remove_agent(agent) self.place_agent(agent, node_id) def remove_agent(self, agent: Agent) -> None: - """Remove the agent from the network and set its pos attribute to None.""" + """Remove the agent from the network and set its pos attribute to None. + + Args: + agent: agent instance + + """ node_id = agent.pos self.G.nodes[node_id]["agent"].remove(agent) agent.pos = None def is_cell_empty(self, node_id: int) -> bool: - """Returns a bool of the contents of a cell.""" + """Returns a bool of the contents of a cell. + + Args: + node_id: id of node + + """ return self.G.nodes[node_id]["agent"] == self.default_val() def get_cell_list_contents(self, cell_list: list[int]) -> list[Agent]: - """Returns a list of the agents contained in the nodes identified - in `cell_list`; nodes with empty content are excluded. + """Returns a list of the agents contained in the nodes identified in `cell_list`. + + Nodes with empty content are excluded. + + Args: + cell_list: list of cell ids. + + Returns: + list of the agents contained in the nodes identified in `cell_list`. + """ return list(self.iter_cell_list_contents(cell_list)) @@ -1592,8 +1610,16 @@ def get_all_cell_contents(self) -> list[Agent]: return self.get_cell_list_contents(self.G) def iter_cell_list_contents(self, cell_list: list[int]) -> Iterator[Agent]: - """Returns an iterator of the agents contained in the nodes identified - in `cell_list`; nodes with empty content are excluded. + """Returns an iterator of the agents contained in the nodes identified in `cell_list`. + + Nodes with empty content are excluded. + + Args: + cell_list: list of cell ids. + + Returns: + iterator of the agents contained in the nodes identified in `cell_list`. + """ return itertools.chain.from_iterable( self.G.nodes[node_id]["agent"] diff --git a/mesa/time.py b/mesa/time.py index f50aa89b478..82de81859ac 100644 --- a/mesa/time.py +++ b/mesa/time.py @@ -1,6 +1,4 @@ -""" -Mesa Time Module -================ +"""Mesa Time Module. Objects for handling the time component of a model. In particular, this module contains Schedulers, which handle agent activation. A Scheduler is an object @@ -39,23 +37,16 @@ class BaseScheduler: - """ - A simple scheduler that activates agents one at a time, in the order they were added. + """A simple scheduler that activates agents one at a time, in the order they were added. This scheduler is designed to replicate the behavior of the scheduler in MASON, a multi-agent simulation toolkit. It assumes that each agent added has a `step` method which takes no arguments and executes the agent's actions. Attributes: - - model (Model): The model instance associated with the scheduler. - - steps (int): The number of steps the scheduler has taken. - - time (TimeT): The current time in the simulation. Can be an integer or a float. - - Methods: - - add: Adds an agent to the scheduler. - - remove: Removes an agent from the scheduler. - - step: Executes a step, which involves activating each agent once. - - get_agent_count: Returns the number of agents in the scheduler. - - agents (property): Returns a list of all agent instances. + model (Model): The model instance associated with the scheduler. + steps (int): The number of steps the scheduler has taken. + time (TimeT): The current time in the simulation. Can be an integer or a float. + """ def __init__(self, model: Model, agents: Iterable[Agent] | None = None) -> None: @@ -82,10 +73,8 @@ def add(self, agent: Agent) -> None: """Add an Agent object to the schedule. Args: - agent: An Agent to be added to the schedule. NOTE: The agent must - have a step() method. + agent (Agent): An Agent to be added to the schedule. """ - if agent not in self._agents: self._agents.add(agent) else: @@ -94,12 +83,13 @@ def add(self, agent: Agent) -> None: def remove(self, agent: Agent) -> None: """Remove all instances of a given agent from the schedule. + Args: + agent: An `Agent` instance. + Note: It is only necessary to explicitly remove agents from the schedule if the agent is not removed from the model. - Args: - agent: An agent object. """ self._agents.remove(agent) @@ -117,10 +107,12 @@ def get_agent_count(self) -> int: @property def agents(self) -> AgentSet: + """Return agents in the scheduler.""" # a bit dirty, but returns a copy of the internal agent set return self._agents.select() def get_agent_keys(self, shuffle: bool = False) -> list[int]: + """Deprecated.""" # To be able to remove and/or add agents during stepping # it's necessary to cast the keys view to a list. @@ -138,14 +130,21 @@ def get_agent_keys(self, shuffle: bool = False) -> list[int]: return agent_keys def do_each(self, method, shuffle=False): + """Perform `method` on each agent. + + Args: + method: method to call + shuffle: shuffle the agents or not prior to calling method + + + """ if shuffle: self._agents.shuffle(inplace=True) self._agents.do(method) class RandomActivation(BaseScheduler): - """ - A scheduler that activates each agent once per step, in a random order, with the order reshuffled each step. + """A scheduler that activates each agent once per step, in a random order, with the order reshuffled each step. This scheduler is equivalent to the NetLogo 'ask agents...' behavior and is a common default for ABMs. It assumes that all agents have a `step` method. @@ -155,23 +154,17 @@ class RandomActivation(BaseScheduler): Inherits all attributes and methods from BaseScheduler. - Methods: - - step: Executes a step, activating each agent in a random order. """ def step(self) -> None: - """Executes the step of all agents, one at a time, in - random order. - - """ + """Executes the step of all agents, one at a time, in random order.""" self.do_each("step", shuffle=True) self.steps += 1 self.time += 1 class SimultaneousActivation(BaseScheduler): - """ - A scheduler that simulates the simultaneous activation of all agents. + """A scheduler that simulates the simultaneous activation of all agents. This scheduler is unique in that it requires agents to have both `step` and `advance` methods. - The `step` method is for activating the agent and staging any changes without applying them immediately. @@ -182,8 +175,6 @@ class SimultaneousActivation(BaseScheduler): Inherits all attributes and methods from BaseScheduler. - Methods: - - step: Executes a step for all agents, first calling `step` then `advance` on each. """ def step(self) -> None: @@ -198,9 +189,9 @@ def step(self) -> None: class StagedActivation(BaseScheduler): - """ - A scheduler allowing agent activation to be divided into several stages, with all agents executing one stage - before moving on to the next. This class is a generalization of SimultaneousActivation. + """A scheduler allowing agent activation to be divided into several stages. + + All agents executing one stage before moving on to the next. This class is a generalization of SimultaneousActivation. This scheduler is useful for complex models where actions need to be broken down into distinct phases for each agent in each time step. Agents must implement methods for each defined stage. @@ -215,8 +206,6 @@ class StagedActivation(BaseScheduler): - shuffle (bool): Determines whether to shuffle the order of agents each step. - shuffle_between_stages (bool): Determines whether to shuffle agents between each stage. - Methods: - - step: Executes all the stages for all agents in the defined order. """ def __init__( @@ -261,8 +250,7 @@ def step(self) -> None: class RandomActivationByType(BaseScheduler): - """ - A scheduler that activates each type of agent once per step, in random order, with the order reshuffled every step. + """A scheduler that activates each type of agent once per step, in random order, with the order reshuffled every step. This scheduler is useful for models with multiple types of agents, ensuring that each type is treated equitably in terms of activation order. The randomness in activation order helps in reducing biases @@ -278,14 +266,10 @@ class RandomActivationByType(BaseScheduler): Attributes: - agents_by_type (defaultdict): A dictionary mapping agent types to dictionaries of agents. - Methods: - - step: Executes the step of each agent type in a random order. - - step_type: Activates all agents of a given type. - - get_type_count: Returns the count of agents of a specific type. """ @property - def agents_by_type(self): + def agents_by_type(self): # noqa: D102 warnings.warn( "Because of the shift to using AgentSet, in the future this attribute will return a dict with" "type as key as AgentSet as value. Future behavior is available via RandomActivationByType._agents_by_type", @@ -300,14 +284,13 @@ def agents_by_type(self): return agentsbytype def __init__(self, model: Model, agents: Iterable[Agent] | None = None) -> None: - super().__init__(model, agents) - """ + """Initialize RandomActivationByType instance. Args: model (Model): The model to which the schedule belongs agents (Iterable[Agent], None, optional): An iterable of agents who are controlled by the schedule """ - + super().__init__(model, agents) # can't be a defaultdict because we need to pass model to AgentSet self._agents_by_type: [type, AgentSet] = {} @@ -319,8 +302,7 @@ def __init__(self, model: Model, agents: Iterable[Agent] | None = None) -> None: self._agents_by_type[type(agent)] = AgentSet([agent], self.model) def add(self, agent: Agent) -> None: - """ - Add an Agent object to the schedule + """Add an Agent object to the schedule. Args: agent: An Agent to be added to the schedule. @@ -333,21 +315,21 @@ def add(self, agent: Agent) -> None: self._agents_by_type[type(agent)] = AgentSet([agent], self.model) def remove(self, agent: Agent) -> None: - """ - Remove all instances of a given agent from the schedule. + """Remove all instances of a given agent from the schedule. + + Args: + agent: An Agent to be removed from the schedule. + """ super().remove(agent) self._agents_by_type[type(agent)].remove(agent) def step(self, shuffle_types: bool = True, shuffle_agents: bool = True) -> None: - """ - Executes the step of each agent type, one at a time, in random order. + """Executes the step of each agent type, one at a time, in random order. Args: - shuffle_types: If True, the order of execution of each types is - shuffled. - shuffle_agents: If True, the order of execution of each agents in a - type group is shuffled. + shuffle_types: If True, the order of execution of each types is shuffled. + shuffle_agents: If True, the order of execution of each agents in a type group is shuffled. """ # To be able to remove and/or add agents during stepping # it's necessary to cast the keys view to a list. @@ -360,12 +342,11 @@ def step(self, shuffle_types: bool = True, shuffle_agents: bool = True) -> None: self.time += 1 def step_type(self, agenttype: type[Agent], shuffle_agents: bool = True) -> None: - """ - Shuffle order and run all agents of a given type. - This method is equivalent to the NetLogo 'ask [breed]...'. + """Shuffle order and run all agents of a given type. Args: agenttype: Class object of the type to run. + shuffle_agents: If True, shuffle agents """ agents = self._agents_by_type[agenttype] @@ -374,19 +355,15 @@ def step_type(self, agenttype: type[Agent], shuffle_agents: bool = True) -> None agents.do("step") def get_type_count(self, agenttype: type[Agent]) -> int: - """ - Returns the current number of agents of certain type in the queue. - """ + """Returns the current number of agents of certain type in the queue.""" return len(self._agents_by_type[agenttype]) class DiscreteEventScheduler(BaseScheduler): - """ - This class has been deprecated and replaced by the functionality provided by experimental.devs - """ + """This class has been deprecated and replaced by the functionality provided by experimental.devs.""" def __init__(self, model: Model, time_step: TimeT = 1) -> None: - """ + """Initialize DiscreteEventScheduler. Args: model (Model): The model to which the schedule belongs diff --git a/mesa/visualization/UserParam.py b/mesa/visualization/UserParam.py index 5b342471ddb..26660e4bd2d 100644 --- a/mesa/visualization/UserParam.py +++ b/mesa/visualization/UserParam.py @@ -1,7 +1,12 @@ +"""Solara visualization related helper classes.""" + + class UserParam: + """UserParam.""" + _ERROR_MESSAGE = "Missing or malformed inputs for '{}' Option '{}'" - def maybe_raise_error(self, param_type, valid): + def maybe_raise_error(self, param_type, valid): # noqa: D102 if valid: return msg = self._ERROR_MESSAGE.format(param_type, self.label) @@ -9,11 +14,9 @@ def maybe_raise_error(self, param_type, valid): class Slider(UserParam): - """ - A number-based slider input with settable increment. + """A number-based slider input with settable increment. Example: - slider_option = Slider("My Slider", value=123, min=10, max=200, step=0.1) Args: @@ -34,6 +37,16 @@ def __init__( step=1, dtype=None, ): + """Initializes a slider. + + Args: + label: The displayed label in the UI + value: The initial value of the slider + min: The minimum possible value of the slider + max: The maximum possible value of the slider + step: The step between min and max for a range of possible values + dtype: either int or float + """ self.label = label self.value = value self.min = min @@ -49,8 +62,8 @@ def __init__( else: self.is_float_slider = dtype is float - def _check_values_are_float(self, value, min, max, step): + def _check_values_are_float(self, value, min, max, step): # D103 return any(isinstance(n, float) for n in (value, min, max, step)) - def get(self, attr): + def get(self, attr): # noqa: D102 return getattr(self, attr) diff --git a/mesa/visualization/__init__.py b/mesa/visualization/__init__.py index d8a0ebecf86..3208642d562 100644 --- a/mesa/visualization/__init__.py +++ b/mesa/visualization/__init__.py @@ -1,3 +1,5 @@ +"""Solara based visualization for Mesa models.""" + from .components.altair import make_space_altair from .components.matplotlib import make_plot_measure, make_space_matplotlib from .solara_viz import JupyterViz, SolaraViz, make_text diff --git a/mesa/visualization/components/altair.py b/mesa/visualization/components/altair.py index 1d23b170bda..3dcb7e2c5d9 100644 --- a/mesa/visualization/components/altair.py +++ b/mesa/visualization/components/altair.py @@ -1,3 +1,5 @@ +"""Altair based solara components for visualization mesa spaces.""" + import contextlib import solara @@ -8,7 +10,7 @@ from mesa.visualization.utils import update_counter -def make_space_altair(agent_portrayal=None): +def make_space_altair(agent_portrayal=None): # noqa: D103 if agent_portrayal is None: def agent_portrayal(a): @@ -21,7 +23,7 @@ def MakeSpaceAltair(model): @solara.component -def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): +def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): # noqa: D103 update_counter.get() space = getattr(model, "grid", None) if space is None: diff --git a/mesa/visualization/components/matplotlib.py b/mesa/visualization/components/matplotlib.py index b1c61581b71..6061356af7b 100644 --- a/mesa/visualization/components/matplotlib.py +++ b/mesa/visualization/components/matplotlib.py @@ -1,3 +1,5 @@ +"""Matplotlib based solara components for visualization MESA spaces and plots.""" + from collections import defaultdict import networkx as nx @@ -10,7 +12,7 @@ from mesa.visualization.utils import update_counter -def make_space_matplotlib(agent_portrayal=None): +def make_space_matplotlib(agent_portrayal=None): # noqa: D103 if agent_portrayal is None: def agent_portrayal(a): @@ -23,7 +25,7 @@ def MakeSpaceMatplotlib(model): @solara.component -def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = None): +def SpaceMatplotlib(model, agent_portrayal, dependencies: list[any] | None = None): # noqa: D103 update_counter.get() space_fig = Figure() space_ax = space_fig.subplots() @@ -217,7 +219,7 @@ def portray(g): space_ax.plot(*zip(*polygon), color="black") # Plot polygon edges in red -def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]): +def make_plot_measure(measure: str | dict[str, str] | list[str] | tuple[str]): # noqa: D103 def MakePlotMeasure(model): return PlotMatplotlib(model, measure) @@ -225,7 +227,7 @@ def MakePlotMeasure(model): @solara.component -def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): +def PlotMatplotlib(model, measure, dependencies: list[any] | None = None): # noqa: D103 update_counter.get() fig = Figure() ax = fig.subplots() diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index a21ce67dea2..7b1a0464f65 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -1,5 +1,4 @@ -""" -Mesa visualization module for creating interactive model visualizations. +"""Mesa visualization module for creating interactive model visualizations. This module provides components to create browser- and Jupyter notebook-based visualizations of Mesa models, allowing users to watch models run step-by-step and interact with model parameters. @@ -44,8 +43,7 @@ def Card( model, measures, agent_portrayal, space_drawer, dependencies, color, layout_type ): - """ - Create a card component for visualizing model space or measures. + """Create a card component for visualizing model space or measures. Args: model: The Mesa model instance @@ -93,12 +91,22 @@ def Card( def SolaraViz( model: "Model" | solara.Reactive["Model"], components: list[solara.component] | Literal["default"] = "default", - *args, play_interval=100, model_params=None, seed=0, name: str | None = None, ): + """Solara visualization component. + + Args: + model: a Model instance + components: list of solara components + play_interval: int + model_params: parameters for instantiating a model + seed: the seed for the rng + name: str + + """ update_counter.get() if components == "default": components = [components_altair.make_space_altair()] @@ -149,8 +157,7 @@ def step(): @solara.component def ModelController(model: solara.Reactive["Model"], play_interval=100): - """ - Create controls for model execution (step, play, pause, reset). + """Create controls for model execution (step, play, pause, reset). Args: model: The reactive model being visualized @@ -202,8 +209,7 @@ def do_reset(): def split_model_params(model_params): - """ - Split model parameters into user-adjustable and fixed parameters. + """Split model parameters into user-adjustable and fixed parameters. Args: model_params: Dictionary of all model parameters @@ -222,8 +228,7 @@ def split_model_params(model_params): def check_param_is_fixed(param): - """ - Check if a parameter is fixed (not user-adjustable). + """Check if a parameter is fixed (not user-adjustable). Args: param: Parameter to check @@ -241,6 +246,15 @@ def check_param_is_fixed(param): @solara.component def ModelCreator(model, model_params, seed=1): + """Helper class to create a new Model instance. + + Args: + model: model instance + model_params: model parameters + seed: the seed to use for the random number generator + + + """ user_params, fixed_params = split_model_params(model_params) reactive_seed = solara.use_reactive(seed) @@ -278,17 +292,15 @@ def create_model(): @solara.component def UserInputs(user_params, on_change=None): - """ - Initialize user inputs for configurable model parameters. + """Initialize user inputs for configurable model parameters. + Currently supports :class:`solara.SliderInt`, :class:`solara.SliderFloat`, :class:`solara.Select`, and :class:`solara.Checkbox`. Args: - user_params: Dictionary with options for the input, including label, - min and max values, and other fields specific to the input type. + user_params: Dictionary with options for the input, including label, min and max values, and other fields specific to the input type. on_change: Function to be called with (name, value) when the value of an input changes. """ - for name, options in user_params.items(): def change_handler(value, name=name): @@ -347,8 +359,7 @@ def change_handler(value, name=name): def make_text(renderer): - """ - Create a function that renders text using Markdown. + """Create a function that renders text using Markdown. Args: renderer: Function that takes a model and returns a string @@ -364,8 +375,7 @@ def function(model): def make_initial_grid_layout(layout_types): - """ - Create an initial grid layout for visualization components. + """Create an initial grid layout for visualization components. Args: layout_types: List of layout types (Space or Measure) @@ -387,6 +397,6 @@ def make_initial_grid_layout(layout_types): @solara.component -def ShowSteps(model): +def ShowSteps(model): # noqa: D103 update_counter.get() return solara.Text(f"Step: {model.steps}") diff --git a/mesa/visualization/utils.py b/mesa/visualization/utils.py index c49b35e3664..95d9a14f55c 100644 --- a/mesa/visualization/utils.py +++ b/mesa/visualization/utils.py @@ -1,7 +1,9 @@ +"""Solara related utils.""" + import solara update_counter = solara.reactive(0) -def force_update(): +def force_update(): # noqa: D103 update_counter.value += 1 diff --git a/pyproject.toml b/pyproject.toml index bc949f837a4..edee30d161d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ select = [ "UP", # upgrade "W", # style warnings "YTT", # sys.version + "D", # docstring ] # Ignore list taken from https://github.com/psf/black/blob/master/.flake8 # E203 Whitespace before ':' @@ -136,3 +137,5 @@ extend-ignore = [ "ISC001", # ruff format asks to disable this feature "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes ] +[tool.ruff.lint.pydocstyle] +convention = "google" diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb2d..0c562e6b2f8 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""init of tests.""" diff --git a/tests/read_requirements.py b/tests/read_requirements.py index 83ccfd95219..7ab70a595cf 100644 --- a/tests/read_requirements.py +++ b/tests/read_requirements.py @@ -1,3 +1,4 @@ +# noqa: D100 import toml # This file reads the pyproject.toml and prints out the diff --git a/tests/test_agent.py b/tests/test_agent.py index 787cfc0bc75..769be4ec5fc 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,3 +1,5 @@ +"""Agent.py related tests.""" + import pickle import numpy as np @@ -8,30 +10,42 @@ class TestAgent(Agent): + """Agent class for testing.""" + def get_unique_identifier(self): + """Return unique identifier for this agent.""" return self.unique_id class TestAgentDo(Agent): + """Agent class for testing.""" + def __init__( self, model, ): + """Initialize an Agent. + + Args: + model (Model): the model to which the agent belongs + + """ super().__init__(model) self.agent_set = None - def get_unique_identifier(self): + def get_unique_identifier(self): # noqa: D102 return self.unique_id - def do_add(self): + def do_add(self): # noqa: D102 agent = TestAgentDo(self.model) self.agent_set.add(agent) - def do_remove(self): + def do_remove(self): # noqa: D102 self.agent_set.remove(self) def test_agent_removal(): + """Test agent removal.""" model = Model() agent = TestAgent(model) # Check if the agent is added @@ -43,6 +57,7 @@ def test_agent_removal(): def test_agentset(): + """Test agentset class.""" # create agentset model = Model() agents = [TestAgent(model) for _ in range(10)] @@ -112,6 +127,7 @@ def test_function(agent): def test_agentset_initialization(): + """Test agentset initialization.""" model = Model() empty_agentset = AgentSet([], model) assert len(empty_agentset) == 0 @@ -122,6 +138,7 @@ def test_agentset_initialization(): def test_agentset_serialization(): + """Test pickleability of agentset.""" model = Model() agents = [TestAgent(model) for _ in range(5)] agentset = AgentSet(agents, model) @@ -136,6 +153,7 @@ def test_agentset_serialization(): def test_agent_membership(): + """Test agent membership in AgentSet.""" model = Model() agents = [TestAgent(model) for _ in range(5)] agentset = AgentSet(agents, model) @@ -145,6 +163,7 @@ def test_agent_membership(): def test_agent_add_remove_discard(): + """Test adding, removing and discarding agents from AgentSet.""" model = Model() agent = TestAgent(model) agentset = AgentSet([], model) @@ -164,6 +183,7 @@ def test_agent_add_remove_discard(): def test_agentset_get_item(): + """Test integer based access to AgentSet.""" model = Model() agents = [TestAgent(model) for _ in range(10)] agentset = AgentSet(agents, model) @@ -177,6 +197,7 @@ def test_agentset_get_item(): def test_agentset_do_str(): + """Test AgentSet.do with str.""" model = Model() agents = [TestAgent(model) for _ in range(10)] agentset = AgentSet(agents, model) @@ -211,6 +232,7 @@ def test_agentset_do_str(): def test_agentset_do_callable(): + """Test AgentSet.do with callable.""" model = Model() agents = [TestAgent(model) for _ in range(10)] agentset = AgentSet(agents, model) @@ -277,6 +299,7 @@ def remove_function(agent): def test_agentset_get(): + """Test AgentSet.get.""" model = Model() _ = [TestAgent(i, model) for i in range(10)] @@ -322,6 +345,7 @@ def test_agentset_get(): def test_agentset_agg(): + """Test agentset.agg.""" model = Model() agents = [TestAgent(model) for i in range(10)] @@ -357,6 +381,8 @@ def custom_func(values): def test_agentset_set_method(): + """Test AgentSet.set.""" + # Initialize the model and agents with and without existing attributes class TestAgentWithAttribute(Agent): def __init__(self, model, age=None): @@ -381,6 +407,7 @@ def __init__(self, model, age=None): def test_agentset_map_str(): + """Test AgentSet.map with strings.""" model = Model() agents = [TestAgent(model) for _ in range(10)] agentset = AgentSet(agents, model) @@ -393,6 +420,7 @@ def test_agentset_map_str(): def test_agentset_map_callable(): + """Test AgentSet.map with callable.""" model = Model() agents = [TestAgent(model) for _ in range(10)] agentset = AgentSet(agents, model) @@ -410,6 +438,7 @@ def test_agentset_map_callable(): def test_agentset_get_attribute(): + """Test AgentSet.get for attributes.""" model = Model() agents = [TestAgent(model) for _ in range(10)] agentset = AgentSet(agents, model) @@ -440,11 +469,15 @@ def test_agentset_get_attribute(): class OtherAgentType(Agent): + """Another Agent class for testing.""" + def get_unique_identifier(self): + """Return unique identifier.""" return self.unique_id def test_agentset_select_by_type(): + """Test AgentSet.select for agent type.""" model = Model() # Create a mix of agents of two different types test_agents = [TestAgent(model) for _ in range(4)] @@ -471,6 +504,7 @@ def test_agentset_select_by_type(): def test_agentset_shuffle(): + """Test AgentSet.shuffle.""" model = Model() test_agents = [TestAgent(model) for _ in range(12)] @@ -484,6 +518,8 @@ def test_agentset_shuffle(): def test_agentset_groupby(): + """Test AgentSet.groupby.""" + class TestAgent(Agent): def __init__(self, model): super().__init__(model) @@ -527,7 +563,9 @@ def get_unique_identifier(self): def test_oldstyle_agent_instantiation(): """Old behavior of Agent creation with unique_id and model as positional arguments. - Can be removed/updated in the future.""" + + Can be removed/updated in the future. + """ model = Model() agent = Agent("some weird unique id", model) assert isinstance(agent.unique_id, int) diff --git a/tests/test_batch_run.py b/tests/test_batch_run.py index 784c041679a..ee18d0677c9 100644 --- a/tests/test_batch_run.py +++ b/tests/test_batch_run.py @@ -1,3 +1,5 @@ +"""Test Batchrunner.""" + import mesa from mesa.agent import Agent from mesa.batchrunner import _make_model_kwargs @@ -6,7 +8,7 @@ from mesa.time import BaseScheduler -def test_make_model_kwargs(): +def test_make_model_kwargs(): # noqa: D103 assert _make_model_kwargs({"a": 3, "b": 5}) == [{"a": 3, "b": 5}] assert _make_model_kwargs({"a": 3, "b": range(3)}) == [ {"a": 3, "b": 0}, @@ -24,24 +26,26 @@ def test_make_model_kwargs(): class MockAgent(Agent): - """ - Minimalistic agent implementation for testing purposes - """ + """Minimalistic agent implementation for testing purposes.""" def __init__(self, model, val): + """Initialize a MockAgent. + + Args: + model: a model instance + val: a value for attribute + """ super().__init__(model) self.val = val self.local = 0 - def step(self): + def step(self): # noqa: D102 self.val += 1 self.local += 0.25 class MockModel(Model): - """ - Minimalistic model for testing purposes - """ + """Minimalistic model for testing purposes.""" def __init__( self, @@ -53,6 +57,17 @@ def __init__( n_agents=3, **kwargs, ): + """Initialize a MockModel. + + Args: + variable_model_param: variable model parameters + variable_agent_param: variable agent parameters + fixed_model_param: fixed model parameters + schedule: schedule instance + enable_agent_reporters: whether to enable agent reporters + n_agents: number of agents + kwargs: keyword arguments + """ super().__init__() self.schedule = BaseScheduler(self) if schedule is None else schedule self.variable_model_param = variable_model_param @@ -71,6 +86,7 @@ def __init__( self.init_agents() def init_agents(self): + """Initialize agents.""" if self.variable_agent_param is None: agent_val = 1 else: @@ -78,15 +94,15 @@ def init_agents(self): for _ in range(self.n_agents): self.schedule.add(MockAgent(self, agent_val)) - def get_local_model_param(self): + def get_local_model_param(self): # noqa: D102 return 42 - def step(self): + def step(self): # noqa: D102 self.schedule.step() self.datacollector.collect(self) -def test_batch_run(): +def test_batch_run(): # noqa: D103 result = mesa.batch_run(MockModel, {}, number_processes=2) assert result == [ { @@ -119,7 +135,7 @@ def test_batch_run(): ] -def test_batch_run_with_params(): +def test_batch_run_with_params(): # noqa: D103 mesa.batch_run( MockModel, { @@ -130,7 +146,7 @@ def test_batch_run_with_params(): ) -def test_batch_run_no_agent_reporters(): +def test_batch_run_no_agent_reporters(): # noqa: D103 result = mesa.batch_run( MockModel, {"enable_agent_reporters": False}, number_processes=2 ) @@ -146,11 +162,11 @@ def test_batch_run_no_agent_reporters(): ] -def test_batch_run_single_core(): +def test_batch_run_single_core(): # noqa: D103 mesa.batch_run(MockModel, {}, number_processes=1, iterations=6) -def test_batch_run_unhashable_param(): +def test_batch_run_unhashable_param(): # noqa: D103 result = mesa.batch_run( MockModel, { diff --git a/tests/test_cell_space.py b/tests/test_cell_space.py index 4127847d17b..6d88a90baf1 100644 --- a/tests/test_cell_space.py +++ b/tests/test_cell_space.py @@ -1,3 +1,5 @@ +"""Test cell spaces.""" + import random import pytest @@ -16,6 +18,7 @@ def test_orthogonal_grid_neumann(): + """Test orthogonal grid with von Neumann neighborhood.""" width = 10 height = 10 grid = OrthogonalVonNeumannGrid((width, height), torus=False, capacity=None) @@ -67,6 +70,7 @@ def test_orthogonal_grid_neumann(): def test_orthogonal_grid_neumann_3d(): + """Test 3D orthogonal grid with von Neumann neighborhood.""" width = 10 height = 10 depth = 10 @@ -130,6 +134,7 @@ def test_orthogonal_grid_neumann_3d(): def test_orthogonal_grid_moore(): + """Test orthogonal grid with Moore neighborhood.""" width = 10 height = 10 @@ -160,6 +165,7 @@ def test_orthogonal_grid_moore(): def test_orthogonal_grid_moore_3d(): + """Test 3D orthogonal grid with Moore neighborhood.""" width = 10 height = 10 depth = 10 @@ -199,6 +205,7 @@ def test_orthogonal_grid_moore_3d(): def test_orthogonal_grid_moore_4d(): + """Test 4D orthogonal grid with Moore neighborhood.""" width = 10 height = 10 depth = 10 @@ -243,6 +250,7 @@ def test_orthogonal_grid_moore_4d(): def test_orthogonal_grid_moore_1d(): + """Test 1D orthogonal grid with Moore neighborhood.""" width = 10 # Moore neighborhood, torus false, left edge @@ -264,6 +272,7 @@ def test_orthogonal_grid_moore_1d(): def test_cell_neighborhood(): + """Test neighborhood method of cell in different GridSpaces.""" # orthogonal grid ## von Neumann @@ -304,6 +313,7 @@ def test_cell_neighborhood(): def test_hexgrid(): + """Test HexGrid.""" width = 10 height = 10 @@ -358,6 +368,7 @@ def test_hexgrid(): def test_networkgrid(): + """Test NetworkGrid.""" import networkx as nx n = 10 @@ -374,6 +385,7 @@ def test_networkgrid(): def test_voronoigrid(): + """Test VoronoiGrid.""" points = [[0, 1], [1, 3], [1.1, 1], [1, 1]] grid = VoronoiGrid(points) @@ -396,6 +408,7 @@ def test_voronoigrid(): def test_empties_space(): + """Test empties method for Discrete Spaces.""" import networkx as nx n = 10 @@ -415,6 +428,7 @@ def test_empties_space(): def test_cell(): + """Test Cell class.""" cell1 = Cell((1,), capacity=None, random=random.Random()) cell2 = Cell((2,), capacity=None, random=random.Random()) @@ -453,6 +467,7 @@ def test_cell(): def test_cell_collection(): + """Test CellCollection.""" cell1 = Cell((1,), capacity=None, random=random.Random()) collection = CellCollection({cell1: cell1.agents}, random=random.Random()) diff --git a/tests/test_datacollector.py b/tests/test_datacollector.py index 94e7742afab..88792554737 100644 --- a/tests/test_datacollector.py +++ b/tests/test_datacollector.py @@ -1,6 +1,4 @@ -""" -Test the DataCollector -""" +"""Test the DataCollector.""" import unittest @@ -9,45 +7,37 @@ class MockAgent(Agent): - """ - Minimalistic agent for testing purposes. - """ + """Minimalistic agent for testing purposes.""" - def __init__(self, model, val=0): + def __init__(self, model, val=0): # noqa: D107 super().__init__(model) self.val = val self.val2 = val - def step(self): - """ - Increment vals by 1. - """ + def step(self): # D103 + """Increment vals by 1.""" self.val += 1 self.val2 += 1 - def double_val(self): + def double_val(self): # noqa: D102 return self.val * 2 - def write_final_values(self): - """ - Write the final value to the appropriate table. - """ + def write_final_values(self): # D103 + """Write the final value to the appropriate table.""" row = {"agent_id": self.unique_id, "final_value": self.val} self.model.datacollector.add_table_row("Final_Values", row) -def agent_function_with_params(agent, multiplier, offset): +def agent_function_with_params(agent, multiplier, offset): # noqa: D103 return (agent.val * multiplier) + offset class MockModel(Model): - """ - Minimalistic model for testing purposes. - """ + """Minimalistic model for testing purposes.""" schedule = BaseScheduler(None) - def __init__(self): + def __init__(self): # noqa: D107 super().__init__() self.schedule = BaseScheduler(self) self.model_val = 100 @@ -72,23 +62,23 @@ def __init__(self): tables={"Final_Values": ["agent_id", "final_value"]}, ) - def test_model_calc_comp(self, input1, input2): + def test_model_calc_comp(self, input1, input2): # noqa: D102 if input2 > 0: return (self.model_val * input1) / input2 else: assert ValueError return None - def step(self): + def step(self): # noqa: D102 self.schedule.step() self.datacollector.collect(self) class TestDataCollector(unittest.TestCase): + """Tests for DataCollector.""" + def setUp(self): - """ - Create the model and run it a set number of steps. - """ + """Create the model and run it a set number of steps.""" self.model = MockModel() for i in range(7): if i == 4: @@ -99,7 +89,7 @@ def setUp(self): for agent in self.model.schedule.agents: agent.write_final_values() - def step_assertion(self, model_var): + def step_assertion(self, model_var): # noqa: D102 for element in model_var: if model_var.index(element) < 4: assert element == 10 @@ -107,9 +97,7 @@ def step_assertion(self, model_var): assert element == 9 def test_model_vars(self): - """ - Test model-level variable collection. - """ + """Test model-level variable collection.""" data_collector = self.model.datacollector assert "total_agents" in data_collector.model_vars assert "model_value" in data_collector.model_vars @@ -131,9 +119,7 @@ def test_model_vars(self): assert element is None def test_agent_records(self): - """ - Test agent-level variable collection. - """ + """Test agent-level variable collection.""" data_collector = self.model.datacollector agent_table = data_collector.get_agent_vars_dataframe() @@ -168,9 +154,7 @@ def test_agent_records(self): data_collector._agent_records[8] def test_table_rows(self): - """ - Test table collection - """ + """Test table collection.""" data_collector = self.model.datacollector assert len(data_collector.tables["Final_Values"]) == 2 assert "agent_id" in data_collector.tables["Final_Values"] @@ -185,9 +169,7 @@ def test_table_rows(self): data_collector.add_table_row("Final_Values", {"final_value": 10}) def test_exports(self): - """ - Test DataFrame exports - """ + """Test DataFrame exports.""" data_collector = self.model.datacollector model_vars = data_collector.get_model_vars_dataframe() agent_vars = data_collector.get_agent_vars_dataframe() @@ -201,10 +183,12 @@ def test_exports(self): class TestDataCollectorInitialization(unittest.TestCase): - def setUp(self): + """Tests for DataCollector initialization.""" + + def setUp(self): # noqa: D102 self.model = Model() - def test_initialize_before_scheduler(self): + def test_initialize_before_scheduler(self): # noqa: D102 with self.assertRaises(RuntimeError) as cm: self.model.initialize_data_collector() self.assertEqual( @@ -212,7 +196,7 @@ def test_initialize_before_scheduler(self): "You must initialize the scheduler (self.schedule) before initializing the data collector.", ) - def test_initialize_before_agents_added_to_scheduler(self): + def test_initialize_before_agents_added_to_scheduler(self): # noqa: D102 with self.assertRaises(RuntimeError) as cm: self.model.schedule = BaseScheduler(self) self.model.initialize_data_collector() diff --git a/tests/test_devs.py b/tests/test_devs.py index 06d3629ed16..8f1dd9373fd 100644 --- a/tests/test_devs.py +++ b/tests/test_devs.py @@ -1,3 +1,5 @@ +"""Tests for experimental Simulator classes.""" + from unittest.mock import MagicMock import pytest @@ -8,6 +10,7 @@ def test_devs_simulator(): + """Tests devs simulator.""" simulator = DEVSimulator() # setup @@ -69,6 +72,7 @@ def test_devs_simulator(): def test_abm_simulator(): + """Tests abm simulator.""" simulator = ABMSimulator() # setup @@ -86,6 +90,7 @@ def test_abm_simulator(): def test_simulation_event(): + """Tests for SimulationEvent class.""" some_test_function = MagicMock() time = 10 @@ -199,6 +204,7 @@ def some_test_function(x, y): def test_eventlist(): + """Tests for EventList.""" event_list = EventList() assert len(event_list._events) == 0 diff --git a/tests/test_examples.py b/tests/test_examples.py index 25b8e0e07b9..7318665dcf8 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,3 +1,4 @@ +# noqa: D100 import contextlib import importlib import os.path @@ -5,7 +6,7 @@ import unittest -def classcase(name): +def classcase(name): # noqa: D103 return "".join(x.capitalize() for x in name.replace("-", "_").split("_")) @@ -13,8 +14,9 @@ def classcase(name): "Skipping TextExamples, because examples folder was moved. More discussion needed." ) class TestExamples(unittest.TestCase): - """ - Test examples' models. This creates a model object and iterates it through + """Test examples' models. + + This creates a model object and iterates it through some steps. The idea is to get code coverage, rather than to test the details of each example's model. """ @@ -23,7 +25,7 @@ class TestExamples(unittest.TestCase): @contextlib.contextmanager def active_example_dir(self, example): - "save and restore sys.path and sys.modules" + """Save and restore sys.path and sys.modules.""" old_sys_path = sys.path[:] old_sys_modules = sys.modules.copy() old_cwd = os.getcwd() @@ -40,7 +42,7 @@ def active_example_dir(self, example): sys.modules.update(old_sys_modules) sys.path[:] = old_sys_path - def test_examples(self): + def test_examples(self): # noqa: D102 for example in os.listdir(self.EXAMPLES): if not os.path.isdir(os.path.join(self.EXAMPLES, example)): continue diff --git a/tests/test_grid.py b/tests/test_grid.py index be68b1ffe2f..c284a1b3755 100644 --- a/tests/test_grid.py +++ b/tests/test_grid.py @@ -1,6 +1,4 @@ -""" -Test the Grid objects. -""" +"""Test the Grid objects.""" import random import unittest @@ -21,27 +19,21 @@ class MockAgent: - """ - Minimalistic agent for testing purposes. - """ + """Minimalistic agent for testing purposes.""" - def __init__(self, unique_id): + def __init__(self, unique_id): # noqa: D107 self.random = random.Random(0) self.unique_id = unique_id self.pos = None class TestSingleGrid(unittest.TestCase): - """ - Testing a non-toroidal singlegrid. - """ + """Testing a non-toroidal singlegrid.""" torus = False def setUp(self): - """ - Create a test non-toroidal grid and populate it with Mock Agents - """ + """Create a test non-toroidal grid and populate it with Mock Agents.""" # The height needs to be even to test the edge case described in PR #1517 height = 6 # height of grid width = 3 # width of grid @@ -59,24 +51,20 @@ def setUp(self): self.grid.place_agent(a, (x, y)) def test_agent_positions(self): - """ - Ensure that the agents are all placed properly. - """ + """Ensure that the agents are all placed properly.""" for agent in self.agents: x, y = agent.pos assert self.grid[x][y] == agent def test_cell_agent_reporting(self): - """ - Ensure that if an agent is in a cell, get_cell_list_contents accurately - reports that fact. - """ + """Ensure that if an agent is in a cell, get_cell_list_contents accurately reports that fact.""" for agent in self.agents: x, y = agent.pos assert agent in self.grid.get_cell_list_contents([(x, y)]) def test_listfree_cell_agent_reporting(self): - """ + """Test if agent is correctly tracked in cell. + Ensure that if an agent is in a cell, get_cell_list_contents accurately reports that fact, even when single position is not wrapped in a list. """ @@ -85,16 +73,14 @@ def test_listfree_cell_agent_reporting(self): assert agent in self.grid.get_cell_list_contents((x, y)) def test_iter_cell_agent_reporting(self): - """ - Ensure that if an agent is in a cell, iter_cell_list_contents - accurately reports that fact. - """ + """Ensure that if an agent is in a cell, iter_cell_list_contents accurately reports that fact.""" for agent in self.agents: x, y = agent.pos assert agent in self.grid.iter_cell_list_contents([(x, y)]) def test_listfree_iter_cell_agent_reporting(self): - """ + """Test if agent is correctly tracked in cell in iterator. + Ensure that if an agent is in a cell, iter_cell_list_contents accurately reports that fact, even when single position is not wrapped in a list. @@ -104,10 +90,7 @@ def test_listfree_iter_cell_agent_reporting(self): assert agent in self.grid.iter_cell_list_contents((x, y)) def test_neighbors(self): - """ - Test the base neighborhood methods on the non-toroid. - """ - + """Test the base neighborhood methods on the non-toroid.""" neighborhood = self.grid.get_neighborhood((1, 1), moore=True) assert len(neighborhood) == 8 @@ -129,7 +112,7 @@ def test_neighbors(self): neighbors = self.grid.get_neighbors((1, 3), moore=False, radius=2) assert len(neighbors) == 3 - def test_coord_iter(self): + def test_coord_iter(self): # noqa: D102 ci = self.grid.coord_iter() # no agent in first space @@ -143,7 +126,7 @@ def test_coord_iter(self): assert second[0].pos == (0, 1) assert second[1] == (0, 1) - def test_agent_move(self): + def test_agent_move(self): # noqa: D102 # get the agent at [0, 1] agent = self.agents[0] self.grid.move_agent(agent, (1, 0)) @@ -160,14 +143,14 @@ def test_agent_move(self): self.grid.move_agent(agent, [1, self.grid.height]) assert agent.pos == (1, 0) - def test_agent_remove(self): + def test_agent_remove(self): # noqa: D102 agent = self.agents[0] x, y = agent.pos self.grid.remove_agent(agent) assert agent.pos is None assert self.grid[x][y] is None - def test_swap_pos(self): + def test_swap_pos(self): # noqa: D102 # Swap agents positions agent_a, agent_b = list(filter(None, self.grid))[:2] pos_a = agent_a.pos @@ -198,17 +181,12 @@ def test_swap_pos(self): class TestSingleGridTorus(TestSingleGrid): - """ - Testing the toroidal singlegrid. - """ + """Testing the toroidal singlegrid.""" torus = True def test_neighbors(self): - """ - Test the toroidal neighborhood methods. - """ - + """Test the toroidal neighborhood methods.""" neighborhood = self.grid.get_neighborhood((1, 1), moore=True) assert len(neighborhood) == 8 @@ -240,14 +218,10 @@ def test_neighbors(self): class TestSingleGridEnforcement(unittest.TestCase): - """ - Test the enforcement in SingleGrid. - """ + """Test the enforcement in SingleGrid.""" def setUp(self): - """ - Create a test non-toroidal grid and populate it with Mock Agents - """ + """Create a test non-toroidal grid and populate it with Mock Agents.""" width = 3 height = 5 self.grid = SingleGrid(width, height, True) @@ -266,10 +240,7 @@ def setUp(self): @patch.object(MockAgent, "model", create=True) def test_enforcement(self, mock_model): - """ - Test the SingleGrid empty count and enforcement. - """ - + """Test the SingleGrid empty count and enforcement.""" assert len(self.grid.empties) == 9 a = MockAgent(100) with self.assertRaises(Exception): @@ -315,16 +286,12 @@ def test_enforcement(self, mock_model): class TestMultiGrid(unittest.TestCase): - """ - Testing a toroidal MultiGrid - """ + """Testing a toroidal MultiGrid.""" torus = True def setUp(self): - """ - Create a test non-toroidal grid and populate it with Mock Agents - """ + """Create a test non-toroidal grid and populate it with Mock Agents.""" width = 3 height = 5 self.grid = MultiGrid(width, height, self.torus) @@ -340,18 +307,13 @@ def setUp(self): self.grid.place_agent(a, (x, y)) def test_agent_positions(self): - """ - Ensure that the agents are all placed properly on the MultiGrid. - """ + """Ensure that the agents are all placed properly on the MultiGrid.""" for agent in self.agents: x, y = agent.pos assert agent in self.grid[x][y] def test_neighbors(self): - """ - Test the toroidal MultiGrid neighborhood methods. - """ - + """Test the toroidal MultiGrid neighborhood methods.""" neighborhood = self.grid.get_neighborhood((1, 1), moore=True) assert len(neighborhood) == 8 @@ -375,14 +337,10 @@ def test_neighbors(self): class TestHexSingleGrid(unittest.TestCase): - """ - Testing a hexagonal singlegrid. - """ + """Testing a hexagonal singlegrid.""" def setUp(self): - """ - Create a test non-toroidal grid and populate it with Mock Agents - """ + """Create a test non-toroidal grid and populate it with Mock Agents.""" width = 3 height = 5 self.grid = HexSingleGrid(width, height, torus=False) @@ -399,9 +357,7 @@ def setUp(self): self.grid.place_agent(a, (x, y)) def test_neighbors(self): - """ - Test the hexagonal neighborhood methods on the non-toroid. - """ + """Test the hexagonal neighborhood methods on the non-toroid.""" neighborhood = self.grid.get_neighborhood((1, 1)) assert len(neighborhood) == 6 @@ -429,14 +385,10 @@ def test_neighbors(self): class TestHexSingleGridTorus(TestSingleGrid): - """ - Testing a hexagonal toroidal singlegrid. - """ + """Testing a hexagonal toroidal singlegrid.""" def setUp(self): - """ - Create a test non-toroidal grid and populate it with Mock Agents - """ + """Create a test non-toroidal grid and populate it with Mock Agents.""" width = 3 height = 5 self.grid = HexSingleGrid(width, height, torus=True) @@ -453,9 +405,7 @@ def setUp(self): self.grid.place_agent(a, (x, y)) def test_neighbors(self): - """ - Test the hexagonal neighborhood methods on the toroid. - """ + """Test the hexagonal neighborhood methods on the toroid.""" neighborhood = self.grid.get_neighborhood((1, 1)) assert len(neighborhood) == 6 @@ -476,27 +426,27 @@ def test_neighbors(self): assert sum(x + y for x, y in neighborhood) == 45 -class TestIndexing: +class TestIndexing: # noqa: D101 # Create a grid where the content of each coordinate is a tuple of its coordinates grid = SingleGrid(3, 5, True) for _, pos in grid.coord_iter(): x, y = pos grid._grid[x][y] = pos - def test_int(self): + def test_int(self): # noqa: D102 assert self.grid[0][0] == (0, 0) - def test_tuple(self): + def test_tuple(self): # noqa: D102 assert self.grid[1, 1] == (1, 1) - def test_list(self): + def test_list(self): # noqa: D102 assert self.grid[(0, 0), (1, 1)] == [(0, 0), (1, 1)] assert self.grid[(0, 0), (5, 3)] == [(0, 0), (2, 3)] - def test_torus(self): + def test_torus(self): # noqa: D102 assert self.grid[3, 5] == (0, 0) - def test_slice(self): + def test_slice(self): # noqa: D102 assert self.grid[:, 0] == [(0, 0), (1, 0), (2, 0)] assert self.grid[::-1, 0] == [(2, 0), (1, 0), (0, 0)] assert self.grid[1, :] == [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4)] diff --git a/tests/test_import_namespace.py b/tests/test_import_namespace.py index 5e9679e6abb..449727b95d8 100644 --- a/tests/test_import_namespace.py +++ b/tests/test_import_namespace.py @@ -1,6 +1,11 @@ +"""Test if namespsaces importing work better.""" + + def test_import(): - # This tests the new, simpler Mesa namespace. See - # https://github.com/projectmesa/mesa/pull/1294. + """This tests the new, simpler Mesa namespace. + + See https://github.com/projectmesa/mesa/pull/1294. + """ import mesa from mesa.time import RandomActivation diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py index 00dd878c8e9..00e830bf645 100644 --- a/tests/test_lifespan.py +++ b/tests/test_lifespan.py @@ -1,3 +1,5 @@ +"""Test removal of agents.""" + import unittest import numpy as np @@ -8,10 +10,10 @@ class LifeTimeModel(Model): - """Simple model for running models with a finite life""" + """Simple model for running models with a finite life.""" - def __init__(self, agent_lifetime=1, n_agents=10): - super().__init__() + def __init__(self, agent_lifetime=1, n_agents=10, seed=None): # noqa: D107 + super().__init__(seed=seed) self.agent_lifetime = agent_lifetime self.n_agents = n_agents @@ -32,7 +34,7 @@ def __init__(self, agent_lifetime=1, n_agents=10): self.schedule.add(FiniteLifeAgent(self.agent_lifetime, self)) def step(self): - """Add agents back to n_agents in each step""" + """Add agents back to n_agents in each step.""" self.datacollector.collect(self) self.schedule.step() @@ -40,30 +42,31 @@ def step(self): for _ in range(self.n_agents - len(self.schedule.agents)): self.schedule.add(FiniteLifeAgent(self.agent_lifetime, self)) - def run_model(self, step_count=100): + def run_model(self, step_count=100): # noqa: D102 for _ in range(step_count): self.step() class FiniteLifeAgent(Agent): """An agent that is supposed to live for a finite number of ticks. + Also has a 10% chance of dying in each tick. """ - def __init__(self, lifetime, model): + def __init__(self, lifetime, model): # noqa: D107 super().__init__(model) self.remaining_life = lifetime self.steps = 0 self.model = model - def step(self): + def step(self): # noqa: D102 inactivated = self.inactivate() if not inactivated: self.steps += 1 # keep track of how many ticks are seen if np.random.binomial(1, 0.1) != 0: # 10% chance of dying self.model.schedule.remove(self) - def inactivate(self): + def inactivate(self): # noqa: D102 self.remaining_life -= 1 if self.remaining_life < 0: self.model.schedule.remove(self) @@ -71,18 +74,18 @@ def inactivate(self): return False -class TestAgentLifespan(unittest.TestCase): - def setUp(self): +class TestAgentLifespan(unittest.TestCase): # noqa: D101 + def setUp(self): # noqa: D102 self.model = LifeTimeModel() self.model.run_model() self.df = self.model.datacollector.get_agent_vars_dataframe() self.df = self.df.reset_index() def test_ticks_seen(self): - """Each agent should be activated no more than one time""" + """Each agent should be activated no more than one time.""" assert self.df.steps.max() == 1 - def test_agent_lifetime(self): + def test_agent_lifetime(self): # noqa: D102 lifetimes = self.df.groupby(["AgentID"]).agg({"Step": len}) assert lifetimes.Step.max() == 2 diff --git a/tests/test_model.py b/tests/test_model.py index 016185e6084..c7d17bf806d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,8 +1,11 @@ +"""Tests for model.py.""" + from mesa.agent import Agent, AgentSet from mesa.model import Model def test_model_set_up(): + """Test Model initialization.""" model = Model() assert model.running is True assert model.schedule is None @@ -12,6 +15,8 @@ def test_model_set_up(): def test_running(): + """Test Model is running.""" + class TestModel(Model): steps = 0 @@ -26,6 +31,7 @@ def step(self): def test_seed(seed=23): + """Test initialization of model with specific seed.""" model = Model(seed=seed) assert model._seed == seed model2 = Model(seed=seed + 1) @@ -34,6 +40,7 @@ def test_seed(seed=23): def test_reset_randomizer(newseed=42): + """Test resetting the random seed on the model.""" model = Model() oldseed = model._seed model.reset_randomizer() @@ -43,6 +50,8 @@ def test_reset_randomizer(newseed=42): def test_agent_types(): + """Test Mode.agent_types property.""" + class TestAgent(Agent): pass @@ -53,6 +62,8 @@ class TestAgent(Agent): def test_agents_by_type(): + """Test getting agents by type from Model.""" + class Wolf(Agent): pass diff --git a/tests/test_scaffold.py b/tests/test_scaffold.py index 627552eb0eb..4d711b8b5a4 100644 --- a/tests/test_scaffold.py +++ b/tests/test_scaffold.py @@ -1,3 +1,4 @@ +# noqa: D100 import os import unittest @@ -7,15 +8,13 @@ class ScaffoldTest(unittest.TestCase): - """ - Test mesa project scaffolding command - """ + """Test mesa project scaffolding command.""" @classmethod - def setUpClass(cls): + def setUpClass(cls): # noqa: D102 cls.runner = CliRunner() - def test_scaffold_creates_project_dir(self): + def test_scaffold_creates_project_dir(self): # noqa: D102 with self.runner.isolated_filesystem(): assert not os.path.isdir("example_project") self.runner.invoke(cli, ["startproject", "--no-input"]) diff --git a/tests/test_solara_viz.py b/tests/test_solara_viz.py index cd5ffe1b5dd..798777bb5eb 100644 --- a/tests/test_solara_viz.py +++ b/tests/test_solara_viz.py @@ -1,3 +1,5 @@ +"""Test Solara visualizations.""" + import unittest from unittest.mock import Mock @@ -9,8 +11,8 @@ from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs -class TestMakeUserInput(unittest.TestCase): - def test_unsupported_type(self): +class TestMakeUserInput(unittest.TestCase): # noqa: D101 + def test_unsupported_type(self): # noqa: D102 @solara.component def Test(user_params): UserInputs(user_params) @@ -24,7 +26,7 @@ def Test(user_params): with self.assertRaisesRegex(ValueError, "not a supported input type"): solara.render(Test({"mock": {}}), handle_error=False) - def test_slider_int(self): + def test_slider_int(self): # noqa: D102 @solara.component def Test(user_params): UserInputs(user_params) @@ -47,7 +49,7 @@ def Test(user_params): assert slider_int.max == options["max"] assert slider_int.step == options["step"] - def test_checkbox(self): + def test_checkbox(self): # noqa: D102 @solara.component def Test(user_params): UserInputs(user_params) @@ -61,7 +63,7 @@ def Test(user_params): assert checkbox.label == options["label"] def test_label_fallback(self): - """name should be used as fallback label""" + """Name should be used as fallback label.""" @solara.component def Test(user_params): @@ -83,7 +85,7 @@ def Test(user_params): assert slider_int.step is None -def test_call_space_drawer(mocker): +def test_call_space_drawer(mocker): # noqa: D103 mock_space_matplotlib = mocker.patch( "mesa.visualization.components.matplotlib.SpaceMatplotlib" ) @@ -122,7 +124,7 @@ def test_call_space_drawer(mocker): ) -def test_slider(): +def test_slider(): # noqa: D103 slider_float = Slider("Agent density", 0.8, 0.1, 1.0, 0.1) assert slider_float.is_float_slider assert slider_float.value == 0.8 diff --git a/tests/test_space.py b/tests/test_space.py index 2da9ee2acb7..d8d537bcf31 100644 --- a/tests/test_space.py +++ b/tests/test_space.py @@ -1,3 +1,5 @@ +"""Test spaces.""" + import unittest import networkx as nx @@ -26,20 +28,14 @@ @pytest.mark.skip(reason="a perf test will slow down the CI") class TestSpacePerformance(unittest.TestCase): - """ - Testing adding many agents for a continuous space. - """ + """Testing adding many agents for a continuous space.""" def setUp(self): - """ - Create a test space and populate with Mock Agents. - """ + """Create a test space and populate with Mock Agents.""" self.space = ContinuousSpace(10, 10, True, -10, -10) def test_agents_add_many(self): - """ - Add many agents - """ + """Add many agents.""" positions = np.random.rand(TEST_AGENTS_PERF, 2) for i in range(TEST_AGENTS_PERF): a = MockAgent(i) @@ -48,14 +44,10 @@ def test_agents_add_many(self): class TestSpaceToroidal(unittest.TestCase): - """ - Testing a toroidal continuous space. - """ + """Testing a toroidal continuous space.""" def setUp(self): - """ - Create a test space and populate with Mock Agents. - """ + """Create a test space and populate with Mock Agents.""" self.space = ContinuousSpace(70, 20, True, -30, -30) self.agents = [] for i, pos in enumerate(TEST_AGENTS): @@ -64,25 +56,19 @@ def setUp(self): self.space.place_agent(a, pos) def test_agent_positions(self): - """ - Ensure that the agents are all placed properly. - """ + """Ensure that the agents are all placed properly.""" for i, pos in enumerate(TEST_AGENTS): a = self.agents[i] assert a.pos == pos def test_agent_matching(self): - """ - Ensure that the agents are all placed and indexed properly. - """ + """Ensure that the agents are all placed and indexed properly.""" for i, agent in self.space._index_to_agent.items(): assert agent.pos == tuple(self.space._agent_points[i, :]) assert i == self.space._agent_to_index[agent] def test_distance_calculations(self): - """ - Test toroidal distance calculations. - """ + """Test toroidal distance calculations.""" pos_1 = (-30, -30) pos_2 = (70, 20) assert self.space.get_distance(pos_1, pos_2) == 0 @@ -98,7 +84,7 @@ def test_distance_calculations(self): pos_7 = (21, -5) assert self.space.get_distance(pos_6, pos_7) == np.sqrt(49**2 + 24**2) - def test_heading(self): + def test_heading(self): # noqa: D102 pos_1 = (-30, -30) pos_2 = (70, 20) self.assertEqual((0, 0), self.space.get_heading(pos_1, pos_2)) @@ -108,9 +94,7 @@ def test_heading(self): self.assertEqual((10, 0), self.space.get_heading(pos_1, pos_2)) def test_neighborhood_retrieval(self): - """ - Test neighborhood retrieval - """ + """Test neighborhood retrieval.""" neighbors_1 = self.space.get_neighbors((-20, -20), 1) assert len(neighbors_1) == 2 @@ -121,9 +105,7 @@ def test_neighborhood_retrieval(self): assert len(neighbors_3) == 1 def test_bounds(self): - """ - Test positions outside of boundary - """ + """Test positions outside of boundary.""" boundary_agents = [] for i, pos in enumerate(OUTSIDE_POSITIONS): a = MockAgent(len(self.agents) + i) @@ -141,14 +123,10 @@ def test_bounds(self): class TestSpaceNonToroidal(unittest.TestCase): - """ - Testing a toroidal continuous space. - """ + """Testing a toroidal continuous space.""" def setUp(self): - """ - Create a test space and populate with Mock Agents. - """ + """Create a test space and populate with Mock Agents.""" self.space = ContinuousSpace(70, 20, False, -30, -30) self.agents = [] for i, pos in enumerate(TEST_AGENTS): @@ -157,31 +135,24 @@ def setUp(self): self.space.place_agent(a, pos) def test_agent_positions(self): - """ - Ensure that the agents are all placed properly. - """ + """Ensure that the agents are all placed properly.""" for i, pos in enumerate(TEST_AGENTS): a = self.agents[i] assert a.pos == pos def test_agent_matching(self): - """ - Ensure that the agents are all placed and indexed properly. - """ + """Ensure that the agents are all placed and indexed properly.""" for i, agent in self.space._index_to_agent.items(): assert agent.pos == tuple(self.space._agent_points[i, :]) assert i == self.space._agent_to_index[agent] def test_distance_calculations(self): - """ - Test toroidal distance calculations. - """ - + """Test toroidal distance calculations.""" pos_2 = (70, 20) pos_3 = (-30, -20) assert self.space.get_distance(pos_2, pos_3) == 107.70329614269008 - def test_heading(self): + def test_heading(self): # noqa: D102 pos_1 = (-30, -30) pos_2 = (70, 20) self.assertEqual((100, 50), self.space.get_heading(pos_1, pos_2)) @@ -191,9 +162,7 @@ def test_heading(self): self.assertEqual((-90, 0), self.space.get_heading(pos_1, pos_2)) def test_neighborhood_retrieval(self): - """ - Test neighborhood retrieval - """ + """Test neighborhood retrieval.""" neighbors_1 = self.space.get_neighbors((-20, -20), 1) assert len(neighbors_1) == 2 @@ -204,9 +173,7 @@ def test_neighborhood_retrieval(self): assert len(neighbors_3) == 0 def test_bounds(self): - """ - Test positions outside of boundary - """ + """Test positions outside of boundary.""" for i, pos in enumerate(OUTSIDE_POSITIONS): a = MockAgent(len(self.agents) + i) with self.assertRaises(Exception): @@ -220,14 +187,10 @@ def test_bounds(self): class TestSpaceAgentMapping(unittest.TestCase): - """ - Testing a continuous space for agent mapping during removal. - """ + """Testing a continuous space for agent mapping during removal.""" def setUp(self): - """ - Create a test space and populate with Mock Agents. - """ + """Create a test space and populate with Mock Agents.""" self.space = ContinuousSpace(70, 50, False, -30, -30) self.agents = [] for i, pos in enumerate(REMOVAL_TEST_AGENTS): @@ -236,9 +199,7 @@ def setUp(self): self.space.place_agent(a, pos) def test_remove_first(self): - """ - Test removing the first entry - """ + """Test removing the first entry.""" agent_to_remove = self.agents[0] self.space.remove_agent(agent_to_remove) for i, agent in self.space._index_to_agent.items(): @@ -250,9 +211,7 @@ def test_remove_first(self): self.space.remove_agent(agent_to_remove) def test_remove_last(self): - """ - Test removing the last entry - """ + """Test removing the last entry.""" agent_to_remove = self.agents[-1] self.space.remove_agent(agent_to_remove) for i, agent in self.space._index_to_agent.items(): @@ -264,9 +223,7 @@ def test_remove_last(self): self.space.remove_agent(agent_to_remove) def test_remove_middle(self): - """ - Test removing a middle entry - """ + """Test removing a middle entry.""" agent_to_remove = self.agents[3] self.space.remove_agent(agent_to_remove) for i, agent in self.space._index_to_agent.items(): @@ -278,28 +235,28 @@ def test_remove_middle(self): self.space.remove_agent(agent_to_remove) -class TestPropertyLayer(unittest.TestCase): - def setUp(self): +class TestPropertyLayer(unittest.TestCase): # noqa: D101 + def setUp(self): # noqa: D102 self.layer = PropertyLayer("test_layer", 10, 10, 0, dtype=int) # Initialization Test - def test_initialization(self): + def test_initialization(self): # noqa: D102 self.assertEqual(self.layer.name, "test_layer") self.assertEqual(self.layer.width, 10) self.assertEqual(self.layer.height, 10) self.assertTrue(np.array_equal(self.layer.data, np.zeros((10, 10)))) # Set Cell Test - def test_set_cell(self): + def test_set_cell(self): # noqa: D102 self.layer.set_cell((5, 5), 1) self.assertEqual(self.layer.data[5, 5], 1) # Set Cells Tests - def test_set_cells_no_condition(self): + def test_set_cells_no_condition(self): # noqa: D102 self.layer.set_cells(2) np.testing.assert_array_equal(self.layer.data, np.full((10, 10), 2)) - def test_set_cells_with_condition(self): + def test_set_cells_with_condition(self): # noqa: D102 self.layer.set_cell((5, 5), 1) def condition(x): @@ -311,7 +268,7 @@ def condition(x): # Check if the sum is correct self.assertEqual(np.sum(self.layer.data), 3 * 99 + 1) - def test_set_cells_with_random_condition(self): + def test_set_cells_with_random_condition(self): # noqa: D102 # Probability for a cell to be updated update_probability = 0.5 @@ -336,73 +293,73 @@ def condition(val): assert expected_min <= true_count <= expected_max # Modify Cell Test - def test_modify_cell_lambda(self): + def test_modify_cell_lambda(self): # noqa: D102 self.layer.data = np.zeros((10, 10)) self.layer.modify_cell((2, 2), lambda x: x + 5) self.assertEqual(self.layer.data[2, 2], 5) - def test_modify_cell_ufunc(self): + def test_modify_cell_ufunc(self): # noqa: D102 self.layer.data = np.ones((10, 10)) self.layer.modify_cell((3, 3), np.add, 4) self.assertEqual(self.layer.data[3, 3], 5) - def test_modify_cell_invalid_operation(self): + def test_modify_cell_invalid_operation(self): # noqa: D102 with self.assertRaises(ValueError): self.layer.modify_cell((1, 1), np.add) # Missing value for ufunc # Modify Cells Test - def test_modify_cells_lambda(self): + def test_modify_cells_lambda(self): # noqa: D102 self.layer.data = np.zeros((10, 10)) self.layer.modify_cells(lambda x: x + 2) np.testing.assert_array_equal(self.layer.data, np.full((10, 10), 2)) - def test_modify_cells_ufunc(self): + def test_modify_cells_ufunc(self): # noqa: D102 self.layer.data = np.ones((10, 10)) self.layer.modify_cells(np.multiply, 3) np.testing.assert_array_equal(self.layer.data, np.full((10, 10), 3)) - def test_modify_cells_invalid_operation(self): + def test_modify_cells_invalid_operation(self): # noqa: D102 with self.assertRaises(ValueError): self.layer.modify_cells(np.add) # Missing value for ufunc # Aggregate Property Test - def test_aggregate_property_lambda(self): + def test_aggregate_property_lambda(self): # noqa: D102 self.layer.data = np.arange(100).reshape(10, 10) result = self.layer.aggregate_property(lambda x: np.sum(x)) self.assertEqual(result, np.sum(np.arange(100))) - def test_aggregate_property_ufunc(self): + def test_aggregate_property_ufunc(self): # noqa: D102 self.layer.data = np.full((10, 10), 2) result = self.layer.aggregate_property(np.mean) self.assertEqual(result, 2) # Edge Case: Negative or Zero Dimensions - def test_initialization_negative_dimensions(self): + def test_initialization_negative_dimensions(self): # noqa: D102 with self.assertRaises(ValueError): PropertyLayer("test_layer", -10, 10, 0, dtype=int) - def test_initialization_zero_dimensions(self): + def test_initialization_zero_dimensions(self): # noqa: D102 with self.assertRaises(ValueError): PropertyLayer("test_layer", 0, 10, 0, dtype=int) # Edge Case: Out-of-Bounds Cell Access - def test_set_cell_out_of_bounds(self): + def test_set_cell_out_of_bounds(self): # noqa: D102 with self.assertRaises(IndexError): self.layer.set_cell((10, 10), 1) - def test_modify_cell_out_of_bounds(self): + def test_modify_cell_out_of_bounds(self): # noqa: D102 with self.assertRaises(IndexError): self.layer.modify_cell((10, 10), lambda x: x + 5) # Edge Case: Selecting Cells with Complex Conditions - def test_select_cells_complex_condition(self): + def test_select_cells_complex_condition(self): # noqa: D102 self.layer.data = np.random.rand(10, 10) selected = self.layer.select_cells(lambda x: (x > 0.5) & (x < 0.75)) for c in selected: self.assertTrue(0.5 < self.layer.data[c] < 0.75) # More edge cases - def test_set_cells_with_numpy_ufunc(self): + def test_set_cells_with_numpy_ufunc(self): # noqa: D102 # Set some cells to a specific value self.layer.data[0:5, 0:5] = 5 @@ -419,26 +376,26 @@ def test_set_cells_with_numpy_ufunc(self): unchanged_cells = self.layer.data[5:, 5:] np.testing.assert_array_equal(unchanged_cells, np.zeros((5, 5))) - def test_modify_cell_boundary_condition(self): + def test_modify_cell_boundary_condition(self): # noqa: D102 self.layer.data = np.zeros((10, 10)) self.layer.modify_cell((0, 0), lambda x: x + 5) self.layer.modify_cell((9, 9), lambda x: x + 5) self.assertEqual(self.layer.data[0, 0], 5) self.assertEqual(self.layer.data[9, 9], 5) - def test_aggregate_property_std_dev(self): + def test_aggregate_property_std_dev(self): # noqa: D102 self.layer.data = np.arange(100).reshape(10, 10) result = self.layer.aggregate_property(np.std) self.assertAlmostEqual(result, np.std(np.arange(100)), places=5) - def test_data_type_consistency(self): + def test_data_type_consistency(self): # noqa: D102 self.layer.data = np.zeros((10, 10), dtype=int) self.layer.set_cell((5, 5), 5.5) self.assertIsInstance(self.layer.data[5, 5], self.layer.data.dtype.type) -class TestSingleGrid(unittest.TestCase): - def setUp(self): +class TestSingleGrid(unittest.TestCase): # noqa: D101 + def setUp(self): # noqa: D102 self.space = SingleGrid(50, 50, False) self.agents = [] for i, pos in enumerate(TEST_AGENTS_GRID): @@ -447,14 +404,12 @@ def setUp(self): self.space.place_agent(a, pos) def test_agent_positions(self): - """ - Ensure that the agents are all placed properly. - """ + """Ensure that the agents are all placed properly.""" for i, pos in enumerate(TEST_AGENTS_GRID): a = self.agents[i] assert a.pos == pos - def test_remove_agent(self): + def test_remove_agent(self): # noqa: D102 for i, pos in enumerate(TEST_AGENTS_GRID): a = self.agents[i] assert a.pos == pos @@ -463,7 +418,7 @@ def test_remove_agent(self): assert a.pos is None assert self.space[pos[0]][pos[1]] is None - def test_empty_cells(self): + def test_empty_cells(self): # noqa: D102 if self.space.exists_empty_cells(): for i, pos in enumerate(list(self.space.empties)): a = MockAgent(-i) @@ -472,7 +427,7 @@ def test_empty_cells(self): self.space.move_to_empty(a) def test_empty_mask_consistency(self): - # Check that the empty mask is consistent with the empties set + """Check that the empty mask is consistent with the empties set.""" empty_mask = self.space.empty_mask empties = self.space.empties for i in range(self.space.width): @@ -481,7 +436,7 @@ def test_empty_mask_consistency(self): empties_value = (i, j) in empties assert mask_value == empties_value - def move_agent(self): + def move_agent(self): # noqa: D102 agent_number = 0 initial_pos = TEST_AGENTS_GRID[agent_number] final_pos = (7, 7) @@ -496,20 +451,20 @@ def move_agent(self): assert self.space[initial_pos[0]][initial_pos[1]] is None assert self.space[final_pos[0]][final_pos[1]] == _agent - def test_move_agent_random_selection(self): + def test_move_agent_random_selection(self): # noqa: D102 agent = self.agents[0] possible_positions = [(10, 10), (20, 20), (30, 30)] self.space.move_agent_to_one_of(agent, possible_positions, selection="random") assert agent.pos in possible_positions - def test_move_agent_closest_selection(self): + def test_move_agent_closest_selection(self): # noqa: D102 agent = self.agents[0] agent.pos = (5, 5) possible_positions = [(6, 6), (10, 10), (20, 20)] self.space.move_agent_to_one_of(agent, possible_positions, selection="closest") assert agent.pos == (6, 6) - def test_move_agent_closest_selection_multiple(self): + def test_move_agent_closest_selection_multiple(self): # noqa: D102 random_locations = [] agent = self.agents[0] agent.pos = (5, 5) @@ -526,7 +481,7 @@ def test_move_agent_closest_selection_multiple(self): non_random_locations = [random_locations[0]] * repetititions assert random_locations != non_random_locations - def test_move_agent_invalid_selection(self): + def test_move_agent_invalid_selection(self): # noqa: D102 agent = self.agents[0] possible_positions = [(10, 10), (20, 20), (30, 30)] with self.assertRaises(ValueError): @@ -534,16 +489,14 @@ def test_move_agent_invalid_selection(self): agent, possible_positions, selection="invalid_option" ) - def test_distance_squared(self): + def test_distance_squared(self): # noqa: D102 pos1 = (3, 4) pos2 = (0, 0) expected_distance_squared = 3**2 + 4**2 assert self.space._distance_squared(pos1, pos2) == expected_distance_squared def test_iter_cell_list_contents(self): - """ - Test neighborhood retrieval - """ + """Test neighborhood retrieval.""" cell_list_1 = list(self.space.iter_cell_list_contents(TEST_AGENTS_GRID[0])) assert len(cell_list_1) == 1 @@ -563,8 +516,8 @@ def test_iter_cell_list_contents(self): assert len(cell_list_4) == 1 -class TestSingleGridTorus(unittest.TestCase): - def setUp(self): +class TestSingleGridTorus(unittest.TestCase): # noqa: D101 + def setUp(self): # noqa: D102 self.space = SingleGrid(50, 50, True) # Torus is True here self.agents = [] for i, pos in enumerate(TEST_AGENTS_GRID): @@ -572,13 +525,13 @@ def setUp(self): self.agents.append(a) self.space.place_agent(a, pos) - def test_move_agent_random_selection(self): + def test_move_agent_random_selection(self): # noqa: D102 agent = self.agents[0] possible_positions = [(49, 49), (1, 1), (25, 25)] self.space.move_agent_to_one_of(agent, possible_positions, selection="random") assert agent.pos in possible_positions - def test_move_agent_closest_selection(self): + def test_move_agent_closest_selection(self): # noqa: D102 agent = self.agents[0] agent.pos = (0, 0) possible_positions = [(3, 3), (49, 49), (25, 25)] @@ -586,7 +539,7 @@ def test_move_agent_closest_selection(self): # Expecting (49, 49) to be the closest in a torus grid assert agent.pos == (49, 49) - def test_move_agent_invalid_selection(self): + def test_move_agent_invalid_selection(self): # noqa: D102 agent = self.agents[0] possible_positions = [(10, 10), (20, 20), (30, 30)] with self.assertRaises(ValueError): @@ -594,14 +547,14 @@ def test_move_agent_invalid_selection(self): agent, possible_positions, selection="invalid_option" ) - def test_move_agent_empty_list(self): + def test_move_agent_empty_list(self): # noqa: D102 agent = self.agents[0] possible_positions = [] agent.pos = (3, 3) self.space.move_agent_to_one_of(agent, possible_positions, selection="random") assert agent.pos == (3, 3) - def test_move_agent_empty_list_warning(self): + def test_move_agent_empty_list_warning(self): # noqa: D102 agent = self.agents[0] possible_positions = [] # Should assert RuntimeWarning @@ -610,7 +563,7 @@ def test_move_agent_empty_list_warning(self): agent, possible_positions, selection="random", handle_empty="warning" ) - def test_move_agent_empty_list_error(self): + def test_move_agent_empty_list_error(self): # noqa: D102 agent = self.agents[0] possible_positions = [] with self.assertRaises(ValueError): @@ -618,15 +571,15 @@ def test_move_agent_empty_list_error(self): agent, possible_positions, selection="random", handle_empty="error" ) - def test_distance_squared_torus(self): + def test_distance_squared_torus(self): # noqa: D102 pos1 = (0, 0) pos2 = (49, 49) expected_distance_squared = 1**2 + 1**2 # In torus, these points are close assert self.space._distance_squared(pos1, pos2) == expected_distance_squared -class TestSingleGridWithPropertyGrid(unittest.TestCase): - def setUp(self): +class TestSingleGridWithPropertyGrid(unittest.TestCase): # noqa: D101 + def setUp(self): # noqa: D102 self.grid = SingleGrid(10, 10, False) self.property_layer1 = PropertyLayer("layer1", 10, 10, 0, dtype=int) self.property_layer2 = PropertyLayer("layer2", 10, 10, 1.0, dtype=float) @@ -634,32 +587,32 @@ def setUp(self): self.grid.add_property_layer(self.property_layer2) # Test adding and removing property layers - def test_add_property_layer(self): + def test_add_property_layer(self): # noqa: D102 self.assertIn("layer1", self.grid.properties) self.assertIn("layer2", self.grid.properties) - def test_remove_property_layer(self): + def test_remove_property_layer(self): # noqa: D102 self.grid.remove_property_layer("layer1") self.assertNotIn("layer1", self.grid.properties) - def test_add_property_layer_mismatched_dimensions(self): + def test_add_property_layer_mismatched_dimensions(self): # noqa: D102 with self.assertRaises(ValueError): self.grid.add_property_layer(PropertyLayer("layer3", 5, 5, 0, dtype=int)) - def test_add_existing_property_layer(self): + def test_add_existing_property_layer(self): # noqa: D102 with self.assertRaises(ValueError): self.grid.add_property_layer(self.property_layer1) - def test_remove_nonexistent_property_layer(self): + def test_remove_nonexistent_property_layer(self): # noqa: D102 with self.assertRaises(ValueError): self.grid.remove_property_layer("nonexistent_layer") # Test getting masks - def test_get_empty_mask(self): + def test_get_empty_mask(self): # noqa: D102 empty_mask = self.grid.empty_mask self.assertTrue(np.all(empty_mask == np.ones((10, 10), dtype=bool))) - def test_get_empty_mask_with_agent(self): + def test_get_empty_mask_with_agent(self): # noqa: D102 agent = MockAgent(0) self.grid.place_agent(agent, (4, 6)) @@ -669,7 +622,7 @@ def test_get_empty_mask_with_agent(self): self.assertTrue(np.all(empty_mask == expected_mask)) - def test_get_neighborhood_mask(self): + def test_get_neighborhood_mask(self): # noqa: D102 agent = MockAgent(0) agent2 = MockAgent(1) self.grid.place_agent(agent, (5, 5)) @@ -681,14 +634,14 @@ def test_get_neighborhood_mask(self): self.assertTrue(np.all(neighborhood_mask == expected_mask)) # Test selecting and moving to cells based on multiple conditions - def test_select_cells_by_properties(self): + def test_select_cells_by_properties(self): # noqa: D102 def condition(x): return x == 0 selected_cells = self.grid.select_cells({"layer1": condition}) self.assertEqual(len(selected_cells), 100) - def test_select_cells_by_properties_return_mask(self): + def test_select_cells_by_properties_return_mask(self): # noqa: D102 def condition(x): return x == 0 @@ -696,7 +649,7 @@ def condition(x): self.assertTrue(isinstance(selected_mask, np.ndarray)) self.assertTrue(selected_mask.all()) - def test_move_agent_to_cell_by_properties(self): + def test_move_agent_to_cell_by_properties(self): # noqa: D102 agent = MockAgent(1) self.grid.place_agent(agent, (5, 5)) conditions = {"layer1": lambda x: x == 0} @@ -705,7 +658,7 @@ def test_move_agent_to_cell_by_properties(self): # Agent should move, since none of the cells match the condition self.assertNotEqual(agent.pos, (5, 5)) - def test_move_agent_no_eligible_cells(self): + def test_move_agent_no_eligible_cells(self): # noqa: D102 agent = MockAgent(3) self.grid.place_agent(agent, (5, 5)) conditions = {"layer1": lambda x: x != 0} @@ -714,12 +667,12 @@ def test_move_agent_no_eligible_cells(self): self.assertEqual(agent.pos, (5, 5)) # Test selecting and moving to cells based on extreme values - def test_select_extreme_value_cells(self): + def test_select_extreme_value_cells(self): # noqa: D102 self.grid.properties["layer2"].set_cell((3, 1), 1.1) target_cells = self.grid.select_cells(extreme_values={"layer2": "highest"}) self.assertIn((3, 1), target_cells) - def test_select_extreme_value_cells_return_mask(self): + def test_select_extreme_value_cells_return_mask(self): # noqa: D102 self.grid.properties["layer2"].set_cell((3, 1), 1.1) target_mask = self.grid.select_cells( extreme_values={"layer2": "highest"}, return_list=False @@ -727,7 +680,7 @@ def test_select_extreme_value_cells_return_mask(self): self.assertTrue(isinstance(target_mask, np.ndarray)) self.assertTrue(target_mask[3, 1]) - def test_move_agent_to_extreme_value_cell(self): + def test_move_agent_to_extreme_value_cell(self): # noqa: D102 agent = MockAgent(2) self.grid.place_agent(agent, (5, 5)) self.grid.properties["layer2"].set_cell((3, 1), 1.1) @@ -736,7 +689,7 @@ def test_move_agent_to_extreme_value_cell(self): self.assertEqual(agent.pos, (3, 1)) # Test using masks - def test_select_cells_by_properties_with_empty_mask(self): + def test_select_cells_by_properties_with_empty_mask(self): # noqa: D102 self.grid.place_agent( MockAgent(0), (5, 5) ) # Placing an agent to ensure some cells are not empty @@ -750,7 +703,7 @@ def condition(x): (5, 5), selected_cells ) # (5, 5) should not be in the selection as it's not empty - def test_select_cells_by_properties_with_neighborhood_mask(self): + def test_select_cells_by_properties_with_neighborhood_mask(self): # noqa: D102 neighborhood_mask = self.grid.get_neighborhood_mask((5, 5), True, False, 1) def condition(x): @@ -771,7 +724,7 @@ def condition(x): ] # Cells in the neighborhood of (5, 5) self.assertCountEqual(selected_cells, expected_selection) - def test_move_agent_to_cell_by_properties_with_empty_mask(self): + def test_move_agent_to_cell_by_properties_with_empty_mask(self): # noqa: D102 agent = MockAgent(1) self.grid.place_agent(agent, (5, 5)) self.grid.place_agent( @@ -785,7 +738,7 @@ def test_move_agent_to_cell_by_properties_with_empty_mask(self): agent.pos, (4, 5) ) # Agent should not move to (4, 5) as it's not empty - def test_move_agent_to_cell_by_properties_with_neighborhood_mask(self): + def test_move_agent_to_cell_by_properties_with_neighborhood_mask(self): # noqa: D102 agent = MockAgent(1) self.grid.place_agent(agent, (5, 5)) neighborhood_mask = self.grid.get_neighborhood_mask((5, 5), True, False, 1) @@ -797,7 +750,7 @@ def test_move_agent_to_cell_by_properties_with_neighborhood_mask(self): ) # Agent should move within the neighborhood # Test invalid inputs - def test_invalid_property_name_in_conditions(self): + def test_invalid_property_name_in_conditions(self): # noqa: D102 def condition(x): return x == 0 @@ -805,7 +758,7 @@ def condition(x): self.grid.select_cells(conditions={"nonexistent_layer": condition}) # Test if coordinates means the same between the grid and the property layer - def test_property_layer_coordinates(self): + def test_property_layer_coordinates(self): # noqa: D102 agent = MockAgent(0) correct_pos = (1, 8) incorrect_pos = (8, 1) @@ -827,14 +780,14 @@ def test_property_layer_coordinates(self): self.assertNotEqual(incorrect_grid_value, agent_grid_value) # Test selecting cells with only_empty parameter - def test_select_cells_only_empty(self): + def test_select_cells_only_empty(self): # noqa: D102 self.grid.place_agent(MockAgent(0), (5, 5)) # Occupying a cell selected_cells = self.grid.select_cells(only_empty=True) self.assertNotIn( (5, 5), selected_cells ) # The occupied cell should not be selected - def test_select_cells_only_empty_with_conditions(self): + def test_select_cells_only_empty_with_conditions(self): # noqa: D102 self.grid.place_agent(MockAgent(1), (5, 5)) self.grid.properties["layer1"].set_cell((5, 5), 2) self.grid.properties["layer1"].set_cell((6, 6), 2) @@ -847,7 +800,7 @@ def condition(x): self.assertNotIn((5, 5), selected_cells) # Test selecting cells with multiple extreme values - def test_select_cells_multiple_extreme_values(self): + def test_select_cells_multiple_extreme_values(self): # noqa: D102 self.grid.properties["layer1"].set_cell((1, 1), 3) self.grid.properties["layer1"].set_cell((2, 2), 3) self.grid.properties["layer2"].set_cell((2, 2), 0.5) @@ -861,13 +814,11 @@ def test_select_cells_multiple_extreme_values(self): self.assertEqual(len(selected_cells), 1) -class TestSingleNetworkGrid(unittest.TestCase): +class TestSingleNetworkGrid(unittest.TestCase): # noqa D101 GRAPH_SIZE = 10 def setUp(self): - """ - Create a test network grid and populate with Mock Agents. - """ + """Create a test network grid and populate with Mock Agents.""" G = nx.cycle_graph(TestSingleNetworkGrid.GRAPH_SIZE) # noqa: N806 self.space = NetworkGrid(G) self.agents = [] @@ -877,22 +828,21 @@ def setUp(self): self.space.place_agent(a, pos) def test_agent_positions(self): - """ - Ensure that the agents are all placed properly. - """ + """Ensure that the agents are all placed properly.""" for i, pos in enumerate(TEST_AGENTS_NETWORK_SINGLE): a = self.agents[i] assert a.pos == pos - def test_get_neighborhood(self): + def test_get_neighborhood(self): # noqa: D102 assert len(self.space.get_neighborhood(0, include_center=True)) == 3 assert len(self.space.get_neighborhood(0, include_center=False)) == 2 assert len(self.space.get_neighborhood(2, include_center=True, radius=3)) == 7 assert len(self.space.get_neighborhood(2, include_center=False, radius=3)) == 6 def test_get_neighbors(self): - """ - Test the get_neighbors method with varying radius and include_center values. Note there are agents on node 0, 1 and 5. + """Test the get_neighbors method with varying radius and include_center values. + + Note there are agents on node 0, 1 and 5. """ # Test with default radius (1) and include_center = False neighbors_default = self.space.get_neighbors(0, include_center=False) @@ -932,7 +882,7 @@ def test_get_neighbors(self): f"Should have {expected_count_radius_2_include_center} neighbors (including center) with radius 2", ) - def test_move_agent(self): + def test_move_agent(self): # noqa: D102 initial_pos = 1 agent_number = 1 final_pos = TestSingleNetworkGrid.GRAPH_SIZE - 1 @@ -947,7 +897,7 @@ def test_move_agent(self): assert _agent not in self.space.G.nodes[initial_pos]["agent"] assert _agent in self.space.G.nodes[final_pos]["agent"] - def test_remove_agent(self): + def test_remove_agent(self): # noqa: D102 for i, pos in enumerate(TEST_AGENTS_NETWORK_SINGLE): a = self.agents[i] assert a.pos == pos @@ -956,17 +906,17 @@ def test_remove_agent(self): assert a.pos is None assert a not in self.space.G.nodes[pos]["agent"] - def test_is_cell_empty(self): + def test_is_cell_empty(self): # noqa: D102 assert not self.space.is_cell_empty(0) assert self.space.is_cell_empty(TestSingleNetworkGrid.GRAPH_SIZE - 1) - def test_get_cell_list_contents(self): + def test_get_cell_list_contents(self): # noqa: D102 assert self.space.get_cell_list_contents([0]) == [self.agents[0]] assert self.space.get_cell_list_contents( list(range(TestSingleNetworkGrid.GRAPH_SIZE)) ) == [self.agents[0], self.agents[1], self.agents[2]] - def test_get_all_cell_contents(self): + def test_get_all_cell_contents(self): # noqa: D102 assert self.space.get_all_cell_contents() == [ self.agents[0], self.agents[1], @@ -974,13 +924,11 @@ def test_get_all_cell_contents(self): ] -class TestMultipleNetworkGrid(unittest.TestCase): +class TestMultipleNetworkGrid(unittest.TestCase): # noqa: D101 GRAPH_SIZE = 3 def setUp(self): - """ - Create a test network grid and populate with Mock Agents. - """ + """Create a test network grid and populate with Mock Agents.""" G = nx.complete_graph(TestMultipleNetworkGrid.GRAPH_SIZE) # noqa: N806 self.space = NetworkGrid(G) self.agents = [] @@ -990,14 +938,12 @@ def setUp(self): self.space.place_agent(a, pos) def test_agent_positions(self): - """ - Ensure that the agents are all placed properly. - """ + """Ensure that the agents are all placed properly.""" for i, pos in enumerate(TEST_AGENTS_NETWORK_MULTIPLE): a = self.agents[i] assert a.pos == pos - def test_get_neighbors(self): + def test_get_neighbors(self): # noqa: D102 assert ( len(self.space.get_neighborhood(0, include_center=True)) == TestMultipleNetworkGrid.GRAPH_SIZE @@ -1007,7 +953,7 @@ def test_get_neighbors(self): == TestMultipleNetworkGrid.GRAPH_SIZE - 1 ) - def test_move_agent(self): + def test_move_agent(self): # noqa: D102 initial_pos = 1 agent_number = 1 final_pos = 0 @@ -1028,12 +974,12 @@ def test_move_agent(self): assert len(self.space.G.nodes[initial_pos]["agent"]) == 1 assert len(self.space.G.nodes[final_pos]["agent"]) == 2 - def test_is_cell_empty(self): + def test_is_cell_empty(self): # noqa: D102 assert not self.space.is_cell_empty(0) assert not self.space.is_cell_empty(1) assert self.space.is_cell_empty(2) - def test_get_cell_list_contents(self): + def test_get_cell_list_contents(self): # noqa: D102 assert self.space.get_cell_list_contents([0]) == [self.agents[0]] assert self.space.get_cell_list_contents([1]) == [ self.agents[1], @@ -1043,7 +989,7 @@ def test_get_cell_list_contents(self): list(range(TestMultipleNetworkGrid.GRAPH_SIZE)) ) == [self.agents[0], self.agents[1], self.agents[2]] - def test_get_all_cell_contents(self): + def test_get_all_cell_contents(self): # noqa: D102 assert self.space.get_all_cell_contents() == [ self.agents[0], self.agents[1], diff --git a/tests/test_time.py b/tests/test_time.py index b21b7b9555d..3d0b1f38f11 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -1,6 +1,4 @@ -""" -Test the advanced schedulers. -""" +"""Test the advanced schedulers.""" import unittest from unittest import TestCase, mock @@ -22,53 +20,49 @@ class MockAgent(Agent): - """ - Minimalistic agent for testing purposes. - """ + """Minimalistic agent for testing purposes.""" - def __init__(self, model): + def __init__(self, model): # noqa: D107 super().__init__(model) self.steps = 0 self.advances = 0 - def kill_other_agent(self): + def kill_other_agent(self): # noqa: D102 for agent in self.model.schedule.agents: if agent is not self: agent.remove() - def stage_one(self): + def stage_one(self): # noqa: D102 if self.model.enable_kill_other_agent: self.kill_other_agent() self.model.log.append(f"{self.unique_id}_1") - def stage_two(self): + def stage_two(self): # noqa: D102 self.model.log.append(f"{self.unique_id}_2") - def advance(self): + def advance(self): # noqa: D102 self.advances += 1 - def step(self): + def step(self): # noqa: D102 if self.model.enable_kill_other_agent: self.kill_other_agent() self.steps += 1 self.model.log.append(self.unique_id) -class MockModel(Model): +class MockModel(Model): # noqa: D101 def __init__(self, shuffle=False, activation=STAGED, enable_kill_other_agent=False): - """ - Creates a Model instance with a schedule + """Creates a Model instance with a schedule. Args: - shuffle (Bool): whether or not to instantiate a scheduler - with shuffling. - This option is only used for + shuffle (Bool): whether to instantiate a scheduler + with shuffling. This option is only used for StagedActivation schedulers. - activation (str): which kind of scheduler to use. 'random' creates a RandomActivation scheduler. 'staged' creates a StagedActivation scheduler. The default scheduler is a BaseScheduler. + enable_kill_other_agent (bool): whether to enable killing of other agents """ super().__init__() self.log = [] @@ -94,34 +88,27 @@ def __init__(self, shuffle=False, activation=STAGED, enable_kill_other_agent=Fal agent = MockAgent(self) self.schedule.add(agent) - def step(self): + def step(self): # noqa: D102 self.schedule.step() - def model_stage(self): + def model_stage(self): # noqa: D102 self.log.append("model_stage") class TestStagedActivation(TestCase): - """ - Test the staged activation. - """ + """Test the staged activation.""" expected_output = ["1_1", "1_1", "model_stage", "1_2", "1_2"] def test_no_shuffle(self): - """ - Testing the staged activation without shuffling. - """ - + """Testing the staged activation without shuffling.""" model = MockModel(shuffle=False) model.step() model.step() assert all(i == j for i, j in zip(model.log[:5], model.log[5:])) def test_shuffle(self): - """ - Test the staged activation with shuffling - """ + """Test the staged activation with shuffling.""" model = MockModel(shuffle=True) model.step() for output in self.expected_output[:2]: @@ -130,7 +117,7 @@ def test_shuffle(self): assert output in model.log[3:] assert self.expected_output[2] == model.log[2] - def test_shuffle_shuffles_agents(self): + def test_shuffle_shuffles_agents(self): # noqa: D102 model = MockModel(shuffle=True) model.random = mock.Mock() assert model.random.shuffle.call_count == 0 @@ -138,9 +125,7 @@ def test_shuffle_shuffles_agents(self): assert model.random.shuffle.call_count == 1 def test_remove(self): - """ - Test the staged activation can remove an agent - """ + """Test the staged activation can remove an agent.""" model = MockModel(shuffle=True) agents = list(model.schedule._agents) agent = agents[0] @@ -148,15 +133,12 @@ def test_remove(self): assert agent not in model.schedule.agents def test_intrastep_remove(self): - """ - Test the staged activation can remove an agent in a - step of another agent so that the one removed doesn't step. - """ + """Test removing an agent in a step of another agent so that the one removed doesn't step.""" model = MockModel(shuffle=True, enable_kill_other_agent=True) model.step() assert len(model.log) == 3 - def test_add_existing_agent(self): + def test_add_existing_agent(self): # noqa: D102 model = MockModel() agent = model.schedule.agents[0] with self.assertRaises(Exception): @@ -164,11 +146,9 @@ def test_add_existing_agent(self): class TestRandomActivation(TestCase): - """ - Test the random activation. - """ + """Test the random activation.""" - def test_init(self): + def test_init(self): # noqa: D102 model = Model() agents = [MockAgent(model) for _ in range(10)] @@ -176,18 +156,14 @@ def test_init(self): assert all(agent in scheduler.agents for agent in agents) def test_random_activation_step_shuffles(self): - """ - Test the random activation step - """ + """Test the random activation step.""" model = MockModel(activation=RANDOM) model.random = mock.Mock() model.schedule.step() assert model.random.shuffle.call_count == 1 def test_random_activation_step_increments_step_and_time_counts(self): - """ - Test the random activation step increments step and time counts - """ + """Test the random activation step increments step and time counts.""" model = MockModel(activation=RANDOM) assert model.schedule.steps == 0 assert model.schedule.time == 0 @@ -196,9 +172,7 @@ def test_random_activation_step_increments_step_and_time_counts(self): assert model.schedule.time == 1 def test_random_activation_step_steps_each_agent(self): - """ - Test the random activation step causes each agent to step - """ + """Test the random activation step causes each agent to step.""" model = MockModel(activation=RANDOM) model.step() agent_steps = [i.steps for i in model.schedule.agents] @@ -206,15 +180,12 @@ def test_random_activation_step_steps_each_agent(self): assert all(x == 1 for x in agent_steps) def test_intrastep_remove(self): - """ - Test the random activation can remove an agent in a - step of another agent so that the one removed doesn't step. - """ + """Test removal an agent in astep of another agent so that the one removed doesn't step.""" model = MockModel(activation=RANDOM, enable_kill_other_agent=True) model.step() assert len(model.log) == 1 - def test_get_agent_keys(self): + def test_get_agent_keys(self): # noqa: D102 model = MockModel(activation=RANDOM) keys = model.schedule.get_agent_keys() @@ -225,7 +196,7 @@ def test_get_agent_keys(self): agent_ids = {agent.unique_id for agent in model.agents} assert all(entry in agent_ids for entry in keys) - def test_not_sequential(self): + def test_not_sequential(self): # noqa: D102 model = MockModel(activation=RANDOM) # Create 10 agents for _ in range(10): @@ -249,14 +220,10 @@ def test_not_sequential(self): class TestSimultaneousActivation(TestCase): - """ - Test the simultaneous activation. - """ + """Test the simultaneous activation.""" def test_simultaneous_activation_step_steps_and_advances_each_agent(self): - """ - Test the simultaneous activation step causes each agent to step - """ + """Test the simultaneous activation step causes each agent to step.""" model = MockModel(activation=SIMULTANEOUS) model.step() # one step for each of 2 agents @@ -267,13 +234,13 @@ def test_simultaneous_activation_step_steps_and_advances_each_agent(self): class TestRandomActivationByType(TestCase): - """ - Test the random activation by type. + """Test the random activation by type. + TODO implement at least 2 types of agents, and test that step_type only does step for one type of agents, not the entire agents. """ - def test_init(self): + def test_init(self): # noqa: D102 model = Model() agents = [MockAgent(model) for _ in range(10)] agents += [Agent(model) for _ in range(10)] @@ -282,18 +249,14 @@ def test_init(self): assert all(agent in scheduler.agents for agent in agents) def test_random_activation_step_shuffles(self): - """ - Test the random activation by type step - """ + """Test the random activation by type step.""" model = MockModel(activation=RANDOM_BY_TYPE) model.random = mock.Mock() model.schedule.step() assert model.random.shuffle.call_count == 2 def test_random_activation_step_increments_step_and_time_counts(self): - """ - Test the random activation by type step increments step and time counts - """ + """Test the random activation by type step increments step and time counts.""" model = MockModel(activation=RANDOM_BY_TYPE) assert model.schedule.steps == 0 assert model.schedule.time == 0 @@ -302,10 +265,7 @@ def test_random_activation_step_increments_step_and_time_counts(self): assert model.schedule.time == 1 def test_random_activation_step_steps_each_agent(self): - """ - Test the random activation by type step causes each agent to step - """ - + """Test the random activation by type step causes each agent to step.""" model = MockModel(activation=RANDOM_BY_TYPE) model.step() agent_steps = [i.steps for i in model.schedule.agents] @@ -313,10 +273,7 @@ def test_random_activation_step_steps_each_agent(self): assert all(x == 1 for x in agent_steps) def test_random_activation_counts(self): - """ - Test the random activation by type step causes each agent to step - """ - + """Test the random activation by type step causes each agent to step.""" model = MockModel(activation=RANDOM_BY_TYPE) agent_types = model.agent_types