Skip to content

Commit

Permalink
update boltzmann rl for Mesa 3
Browse files Browse the repository at this point in the history
- updated boltzmann_rl for Mesa 3.0
- creating duplicate agents for some reason; need to reset the unique_id iterator
  • Loading branch information
tpike3 committed Nov 14, 2024
1 parent 877d9ee commit 1a36a75
Showing 1 changed file with 20 additions and 24 deletions.
44 changes: 20 additions & 24 deletions rl/boltzmann_money/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,16 @@
# Import necessary libraries
import numpy as np
import seaborn as sns
from mesa_models.boltzmann_wealth_model.model import (
BoltzmannWealthModel,
MoneyAgent,
compute_gini,
)
from mesa.examples.basic.boltzmann_wealth_model.agents import MoneyAgent
from mesa.examples.basic.boltzmann_wealth_model.model import BoltzmannWealth

NUM_AGENTS = 10


# Define the agent class
class MoneyAgentRL(MoneyAgent):
def __init__(self, unique_id, model):
super().__init__(unique_id, model)
def __init__(self, model):
super().__init__(model)
self.wealth = np.random.randint(1, NUM_AGENTS)

def move(self, action):
Expand Down Expand Up @@ -74,45 +71,46 @@ def take_money(self):

def step(self):
# Get the action for the agent
action = self.model.action_dict[self.unique_id]
# TODO: figure out why agents are being made twice
action = self.model.action_dict[self.unique_id - 11]
# Move the agent based on the action
self.move(action)
# Take money from other agents in the same cell
self.take_money()


# Define the model class
class BoltzmannWealthModelRL(BoltzmannWealthModel, gymnasium.Env):
def __init__(self, N, width, height):
super().__init__(N, width, height)
class BoltzmannWealthModelRL(BoltzmannWealth, gymnasium.Env):
def __init__(self, n, width, height):
super().__init__(n, width, height)
# Define the observation and action space for the RL model
# The observation space is the wealth of each agent and their position
self.observation_space = gymnasium.spaces.Box(low=0, high=10 * N, shape=(N, 3))
self.observation_space = gymnasium.spaces.Box(low=0, high=10 * n, shape=(n, 3))
# The action space is a MultiDiscrete space with 5 possible actions for each agent
self.action_space = gymnasium.spaces.MultiDiscrete([5] * N)
self.action_space = gymnasium.spaces.MultiDiscrete([5] * n)
self.is_visualize = False

def step(self, action):
self.action_dict = action
# Perform one step of the model
self.schedule.step()
self.agents.shuffle_do("step")
# Collect data for visualization
self.datacollector.collect(self)
# Compute the new Gini coefficient
new_gini = compute_gini(self)
new_gini = self.compute_gini()
# Compute the reward based on the change in Gini coefficient
reward = self.calculate_reward(new_gini)
self.prev_gini = new_gini
# Get the observation for the RL model
obs = self._get_obs()
if self.schedule.time > 5 * NUM_AGENTS:
if self.time > 5 * NUM_AGENTS:
# Terminate the episode if the model has run for a certain number of timesteps
done = True
reward = -1
elif new_gini < 0.1:
# Terminate the episode if the Gini coefficient is below a certain threshold
done = True
reward = 50 / self.schedule.time
reward = 50 / self.time
else:
done = False
info = {}
Expand Down Expand Up @@ -142,20 +140,18 @@ def reset(self, *, seed=None, options=None):
self.visualize()
super().reset()
self.grid = mesa.space.MultiGrid(self.grid.width, self.grid.height, True)
self.schedule = mesa.time.RandomActivation(self)
self.remove_all_agents()
for i in range(self.num_agents):
# Create MoneyAgentRL instances and add them to the schedule
a = MoneyAgentRL(i, self)
self.schedule.add(a)
a = MoneyAgentRL(self)
x = self.random.randrange(self.grid.width)
y = self.random.randrange(self.grid.height)
self.grid.place_agent(a, (x, y))
self.prev_gini = compute_gini(self)
self.prev_gini = self.compute_gini()
return self._get_obs(), {}

def _get_obs(self):
# The observation is the wealth of each agent and their position
obs = []
for a in self.schedule.agents:
obs.append([a.wealth, *list(a.pos)])
obs = [[a.wealth, *a.pos] for a in self.agents]
obs = np.array(obs)
return np.array(obs)

0 comments on commit 1a36a75

Please sign in to comment.