Skip to content

Commit

Permalink
Update Epstein civil violence RL example for Mesa 3
Browse files Browse the repository at this point in the history
- Update the Epstein civil violence RL example for Mesa 3.0
  • Loading branch information
tpike3 committed Nov 14, 2024
1 parent 30a3475 commit 877d9ee
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 20 deletions.
3 changes: 1 addition & 2 deletions rl/epstein_civil_violence/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from mesa.examples.advanced.epstein_civil_violence.agents import Citizen, Cop

from .utility import move
from utility import move


class CitizenRL(Citizen):
Expand Down
25 changes: 11 additions & 14 deletions rl/epstein_civil_violence/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import gymnasium as gym
import mesa
import numpy as np
from agent import CitizenRL, CopRL
from mesa.examples.advanced.epstein_civil_violence.model import EpsteinCivilViolence
from ray.rllib.env import MultiAgentEnv

from .agent import CitizenRL, CopRL
from .utility import create_intial_agents, grid_to_observation
from utility import create_intial_agents, grid_to_observation


class EpsteinCivilViolenceRL(EpsteinCivilViolence, MultiAgentEnv):
Expand Down Expand Up @@ -88,7 +87,7 @@ def step(self, action_dict):
self.action_dict = action_dict

# Step the model
self.schedule.step()
self.agents.shuffle_do("step")
self.datacollector.collect(self)

# Calculate rewards
Expand All @@ -104,10 +103,10 @@ def step(self, action_dict):
] # Get the values from the observation grid for the neighborhood cells

# RL specific outputs for the environment
done = {a.unique_id: False for a in self.schedule.agents}
truncated = {a.unique_id: False for a in self.schedule.agents}
done = {a.unique_id: False for a in self.agents}
truncated = {a.unique_id: False for a in self.agents}
truncated["__all__"] = np.all(list(truncated.values()))
if self.schedule.time > self.max_iters:
if self.time > self.max_iters:
done["__all__"] = True
else:
done["__all__"] = False
Expand All @@ -116,7 +115,7 @@ def step(self, action_dict):

def cal_reward(self):
rewards = {}
for agent in self.schedule.agents:
for agent in self.agents:
if isinstance(agent, CopRL):
if agent.arrest_made:
# Cop is rewarded for making an arrest
Expand Down Expand Up @@ -149,19 +148,17 @@ def reset(self, *, seed=None, options=None):
"""

super().reset()
# Using base scheduler to maintain the order of agents
self.schedule = mesa.time.BaseScheduler(self)
self.grid = mesa.space.SingleGrid(self.width, self.height, torus=True)
create_intial_agents(self, CitizenRL, CopRL)
grid_to_observation(self, CitizenRL)
# Initialize action dictionary with no action
self.action_dict = {a.unique_id: (0, 0) for a in self.schedule.agents}
self.action_dict = {a.unique_id: (0, 0) for a in self.agents}
# Update neighbors for observation space
for agent in self.schedule.agents:
for agent in self.agents:
agent.update_neighbors()
self.schedule.step()
self.agents.shuffle_do("step")
observation = {}
for agent in self.schedule.agents:
for agent in self.agents:
observation[agent.unique_id] = [
self.obs_grid[neighbor[0]][neighbor[1]]
for neighbor in agent.neighborhood
Expand Down
3 changes: 1 addition & 2 deletions rl/epstein_civil_violence/train_config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os

from model import EpsteinCivilViolenceRL
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import PolicySpec

from .model import EpsteinCivilViolenceRL


# Configuration for the PPO algorithm
# You can change the configuration as per your requirements
Expand Down
4 changes: 2 additions & 2 deletions rl/epstein_civil_violence/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def create_intial_agents(self, CitizenRL, CopRL):
# Initializing cops then citizens
# This ensures cops act out their step before citizens
for cop in cops:
self.schedule.add(cop)
self.add(cop)
for citizen in citizens:
self.schedule.add(citizen)
self.add(citizen)


def grid_to_observation(self, CitizenRL):
Expand Down

0 comments on commit 877d9ee

Please sign in to comment.