-
Notifications
You must be signed in to change notification settings - Fork 2
/
example.py
27 lines (23 loc) · 1018 Bytes
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from epstein_civil_violence.model import EpsteinCivilViolence_RL
from epstein_civil_violence.server import run_model
from epstein_civil_violence.train_config import config
from train import train_model
def _episode_over(flag):
    """Return True when a ``terminated``/``truncated`` value from ``env.step`` ends the episode.

    NOTE(review): ``env.step`` is called with a dict of per-agent actions, which
    suggests a multi-agent (RLlib-style) env whose done flags come back as dicts
    such as ``{"__all__": bool, ...}``. A non-empty dict is always truthy, so a
    bare ``if terminated or truncated`` would reset after every step. Confirm
    the env's return format; plain booleans are still handled as before.
    """
    if isinstance(flag, dict):
        return bool(flag.get("__all__", False))
    return bool(flag)


# Load the environment and reset it with a fixed seed.
env = EpsteinCivilViolence_RL()
observation, info = env.reset(seed=42)

# Run the environment for 10 steps using random actions.
for _ in range(10):
    # One sampled action per scheduled agent, keyed by the agent's unique_id.
    action_dict = {
        agent.unique_id: env.action_space.sample()
        for agent in env.schedule.agents
    }
    observation, reward, terminated, truncated, info = env.step(action_dict)
    # Start a fresh episode only once the env reports the episode is over.
    if _episode_over(terminated) or _episode_over(truncated):
        observation, info = env.reset()
# Directory passed to training as ``checkpoint_dir`` and to ``run_model`` as
# ``path``; a single constant keeps the two from drifting apart.
CHECKPOINT_DIR = 'checkpoints'

# Train a model for a single iteration, writing results to results.txt.
train_model(
    config,
    num_iterations=1,
    result_path='results.txt',
    checkpoint_dir=CHECKPOINT_DIR,
)

# Run the trained model from its checkpoints and visualize it.
server = run_model(path=CHECKPOINT_DIR)
# Alternatively, run the pre-trained checkpoints present in the model folder:
# server = run_model(path='model/epstein_civil_violence')
server.port = 6005
server.launch(open_browser=True)