Commit
Added ppo_trxl to README.md, fixed enjoy and ppo_trxl for MiniGrid, added proper rendering to ProofofMemory-v0, updated docs for training and enjoying MiniGrid and ProofofMemory-v0
MarcoMeter committed Sep 17, 2024
1 parent 40cb39f commit 2545864
Showing 5 changed files with 141 additions and 47 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -140,6 +140,7 @@ You may also use a prebuilt development environment hosted in Gitpod:
| | [`ppo_atari_multigpu.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_multigpu.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_multigpupy)
| | [`ppo_pettingzoo_ma_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_pettingzoo_ma_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_pettingzoo_ma_ataripy)
| | [`ppo_continuous_action_isaacgym.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_continuous_action_isaacgym/ppo_continuous_action_isaacgym.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_action_isaacgympy)
+| | [`ppo_trxl.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_trxl/ppo_trxl.py), [docs](https://docs.cleanrl.dev/rl-algorithms/ppo-trxl/)
|[Deep Q-Learning (DQN)](https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf) | [`dqn.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy) |
| | [`dqn_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_atari.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_ataripy) |
| | [`dqn_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn_jax.py), [docs](https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_jaxpy) |
5 changes: 4 additions & 1 deletion cleanrl/ppo_trxl/enjoy.py
@@ -84,5 +84,8 @@ class Args:
    done = termination or truncation
    t += 1

-print(f"Episode return: {info['reward']}, Episode length: {info['length']}")
+if "r" in info["episode"].keys():
+    print(f"Episode return: {info['episode']['r'][0]}, Episode length: {info['episode']['l'][0]}")
+else:
+    print(f"Episode return: {info['reward']}, Episode length: {info['length']}")
env.close()
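For context on this change: MiniGrid runs report episode statistics under `info["episode"]` (as populated by `gym.wrappers.RecordEpisodeStatistics`), while `pom_env.py` (updated below) fills custom `reward` and `length` keys, so the script now handles both layouts. A minimal sketch with made-up values (the dict shapes here are illustrative assumptions, not taken from the repository):

```python
# Illustrative info dicts; values and exact shapes are assumptions.
info_stats = {"episode": {"r": [0.95], "l": [12]}}           # RecordEpisodeStatistics-style
info_custom = {"episode": {}, "reward": 0.95, "length": 12}  # custom keys as in pom_env.py

for info in (info_stats, info_custom):
    if "r" in info["episode"].keys():
        print(f"Episode return: {info['episode']['r'][0]}, Episode length: {info['episode']['l'][0]}")
    else:
        print(f"Episode return: {info['reward']}, Episode length: {info['length']}")
```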
124 changes: 82 additions & 42 deletions cleanrl/ppo_trxl/pom_env.py
@@ -1,10 +1,7 @@
-import os
-import time
-
import gymnasium as gym
import numpy as np
from gymnasium import spaces
-from reprint import output
+import pygame

gym.register(
    id="ProofofMemory-v0",
@@ -27,32 +24,35 @@ class PoMEnv(gym.Env):
    To further challenge the agent, the step_size can be decreased.
    """

-    metadata = {"render_modes": ["human"], "render_fps": 1}
+    metadata = {"render_modes": ["human", "rgb_array", "debug_rgb_array"], "render_fps": 4}

    def __init__(self, render_mode="human"):
        self._freeze = True
        self._step_size = 0.2
        self._min_steps = int(1.0 / self._step_size) + 1
        self._time_penalty = 0.1
        self._num_show_steps = 2
-        self._op = None
+        self.render_mode = render_mode
        glob = False

        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

        # Create an array with possible positions
        # Valid local positions are one tick away from 0.0 or between -0.4 and 0.4
        # Valid global positions are between -1 + step_size and 1 - step_size
        num_steps = int(0.4 / self._step_size)
        lower = min(-2.0 * self._step_size, -num_steps * self._step_size) if not glob else -1 + self._step_size
        upper = max(3.0 * self._step_size, self._step_size, (num_steps + 1) * self._step_size) if not glob else 1
        self.possible_positions = np.arange(lower, upper, self._step_size).clip(-1 + self._step_size, 1 - self._step_size)
        self.possible_positions = list(map(lambda x: round(x, 2), self.possible_positions))  # fix floating point errors

+        # Pygame-related attributes for rendering
+        self.window = None
+        self.clock = None
+        self.width = 400
+        self.height = 80
+        self.cell_width = self.width / (2 * int(1 / self._step_size) + 1)

    def step(self, action):
        action = action[0]
        reward = 0.0
        done = False

@@ -86,13 +86,20 @@ def step(self, action):
            done = True
        else:
            reward -= self._time_penalty
+        self.rewards.append(reward)

+        if done:
+            info = {"reward": sum(self.rewards), "length": len(self.rewards)}
+        else:
+            info = {}
+
        self._step_count += 1

-        return obs, reward, done, False, {}
+        return obs, reward, done, False, info

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
+        self.rewards = []
        self._position = np.random.choice(self.possible_positions)
        self._step_count = 0
        goals = np.asarray([-1.0, 1.0])
@@ -101,44 +108,77 @@ def reset(self, *, seed=None, options=None):
        return obs, {}

    def render(self):
-        if self._op is None:
-            self.init_render = False
-            self._op = output()
-            self._op = self._op.warped_obj
-            os.system("cls||clear")
-
-            for _ in range(6):
-                self._op.append("#")
-
-        num_grids = 2 * int(1 / self._step_size) + 1
-        agent_grid = int(num_grids / 2 + self._position / self._step_size) + 1
-        self._op[1] = "######" * num_grids + "#"
-        self._op[2] = "#     " * num_grids + "#"
-        field = [*("#     " * agent_grid)[:-3], *"A  ", *("#     " * (num_grids - agent_grid)), "#"]
-        if field[3] != "A":
-            field[3] = "+" if self._goals[0] > 0 else "-"
-        if field[-4] != "A":
-            field[-4] = "+" if self._goals[1] > 0 else "-"
-        self._op[3] = "".join(field)
-        self._op[4] = "#     " * num_grids + "#"
-        self._op[5] = "######" * num_grids + "#"
-
-        self._op[6] = "Goals are shown: " + str(self._num_show_steps > self._step_count)
-
-        time.sleep(1.0)
+        if self.render_mode not in self.metadata["render_modes"]:
+            return
+
+        # Initialize Pygame
+        if not pygame.get_init():
+            pygame.init()
+        if self.window is None and self.render_mode == "human":
+            pygame.display.init()
+            self.window = pygame.display.set_mode((self.width, self.height))
+            pygame.display.set_caption("Proof of Memory Environment")
+        if self.clock is None and self.render_mode == "human":
+            self.clock = pygame.time.Clock()
+
+        # Create surface
+        canvas = pygame.Surface((self.width, self.height))
+        canvas.fill((255, 255, 255))  # Fill the background with white
+
+        # Draw grid
+        num_cells = 2 * int(1 / self._step_size) + 1
+        for i in range(num_cells):
+            x = i * self.cell_width
+            pygame.draw.rect(canvas, (200, 200, 200), pygame.Rect(x, 0, self.cell_width, self.height), 1)
+
+        # Draw agent
+        agent_pos = int((self._position + 1) / self._step_size)
+        agent_x = agent_pos * self.cell_width + self.cell_width / 2
+        pygame.draw.circle(canvas, (0, 0, 255), (agent_x, self.height / 2), 15)
+
+        # Draw goals
+        show_goals = self._num_show_steps > self._step_count
+        if show_goals:
+            left_goal_color = (0, 255, 0) if self._goals[0] > 0 else (255, 0, 0)
+            pygame.draw.rect(canvas, left_goal_color, pygame.Rect(0, 0, self.cell_width, self.height))
+            right_goal_color = (0, 255, 0) if self._goals[1] > 0 else (255, 0, 0)
+            pygame.draw.rect(canvas, right_goal_color, pygame.Rect(self.width - self.cell_width, 0, self.cell_width, self.height))
+        else:
+            pygame.draw.rect(canvas, (200, 200, 200), pygame.Rect(0, 0, self.cell_width, self.height))
+            pygame.draw.rect(canvas, (200, 200, 200), pygame.Rect(self.width - self.cell_width, 0, self.cell_width, self.height))
+
+        # Render text information
+        font = pygame.font.SysFont(None, 24)
+        text = font.render(f"Goals are shown: {show_goals}", True, (0, 0, 0))
+        canvas.blit(text, (10, 10))
+
+        if self.render_mode == "human":
+            self.window.blit(canvas, (0, 0))
+            pygame.display.flip()
+            self.clock.tick(self.metadata["render_fps"])
+        elif self.render_mode in ["rgb_array", "debug_rgb_array"]:
+            return np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))

    def close(self):
-        if self._op is not None:
-            self._op.clear()
-            self._op = None
+        if self.window is not None:
+            pygame.display.quit()
+            pygame.quit()
+            self.window = None
+            self.clock = None


if __name__ == "__main__":
-    env = gym.make("ProofofMemory-v0")
+    env = PoMEnv(render_mode="human")
    o, _ = env.reset()
-    env.render()
+    img = env.render()
    done = False
    rewards = []
+    const_action = 1
    while not done:
-        o, r, done, _, _ = env.step(1)
-        env.render()
+        o, r, done, _, _ = env.step(const_action)
        rewards.append(r)
+        img = env.render()
    print(f"Total reward: {sum(rewards)}, Steps: {len(rewards)}")
    env.close()
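With the pygame renderer in place, frames can also be captured off-screen through the new `rgb_array` mode. A small usage sketch (assuming `pom_env.py` is importable; the expected shape follows from the `height = 80` and `width = 400` attributes above):

```python
import numpy as np

from pom_env import PoMEnv  # assumes pom_env.py is on the import path

env = PoMEnv(render_mode="rgb_array")
obs, _ = env.reset()
frame = env.render()  # returns a (height, width, 3) array instead of opening a window
assert isinstance(frame, np.ndarray) and frame.shape == (80, 400, 3)
env.close()
```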
6 changes: 5 additions & 1 deletion cleanrl/ppo_trxl/ppo_trxl.py
@@ -103,9 +103,13 @@ class Args:


def make_env(env_id, idx, capture_video, run_name, render_mode="debug_rgb_array"):
+    if "MiniGrid" in env_id:
+        if render_mode == "debug_rgb_array":
+            render_mode = "rgb_array"
    def thunk():
        if "MiniGrid" in env_id:
-            env = gym.make(env_id, agent_view_size=3, tile_size=28)
+            env = gym.make(env_id, agent_view_size=3, tile_size=28,
+                           render_mode="rgb_array" if render_mode == "debug_rgb_array" else render_mode)
            env = ImgObsWrapper(RGBImgPartialObsWrapper(env, tile_size=28))
            env = gym.wrappers.TimeLimit(env, 96)
        else:
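A quick usage sketch of the patched `make_env` (illustrative assumptions: the `minigrid` package is installed, and the observation shape follows from `agent_view_size=3` times `tile_size=28`, i.e. 84x84 RGB):

```python
import gymnasium as gym

# Hypothetical smoke test for the make_env defined above.
envs = gym.vector.SyncVectorEnv(
    [make_env("MiniGrid-MemoryS9-v0", i, False, "demo-run") for i in range(4)]
)
obs, _ = envs.reset(seed=1)
print(obs.shape)  # expected: (4, 84, 84, 3)
envs.close()
```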
52 changes: 49 additions & 3 deletions docs/rl-algorithms/ppo-trxl.md
@@ -70,7 +70,7 @@ Below is our single-file implementation of PPO-TrXL:

### Implementation details

-Most details are derived from [ppo.py](/rl-algorithms/ppo#ppopy). These are additional or differing details:
+Most details are derived from [`ppo.py`](/rl-algorithms/ppo#ppopy). These are additional or differing details:

1. The policy and value function share parameters.
2. Multi-head attention is implemented so that all heads share parameters.
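To illustrate the second point, below is a minimal sketch of parameter sharing across attention heads (a simplified stand-in, not the exact module from `ppo_trxl.py`): the input is split into heads first, and a single `head_size x head_size` projection is applied to every head, so additional heads add no additional projection weights.

```python
import torch
import torch.nn as nn


class SharedHeadAttention(nn.Module):
    """Sketch: every head applies the same query/key/value projections."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_size = embed_dim // num_heads
        # One (head_size x head_size) projection each, shared by all heads.
        self.queries = nn.Linear(self.head_size, self.head_size, bias=False)
        self.keys = nn.Linear(self.head_size, self.head_size, bias=False)
        self.values = nn.Linear(self.head_size, self.head_size, bias=False)
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        b, t_q, _ = query.shape
        t_k = key.shape[1]
        # Split into heads first, then apply the shared projections per head.
        q = self.queries(query.view(b, t_q, self.num_heads, self.head_size))
        k = self.keys(key.view(b, t_k, self.num_heads, self.head_size))
        v = self.values(value.view(b, t_k, self.num_heads, self.head_size))
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / self.head_size**0.5
        attention = torch.softmax(scores, dim=-1)
        out = torch.einsum("bhqk,bkhd->bqhd", attention, v).reshape(b, t_q, -1)
        return self.fc_out(out)
```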
@@ -83,7 +83,7 @@

### Experiment results

-Note: When training on potentially endless episodes, the cached hidden states demand a large GPU memory. To reproduce the following experiments a minimum of 40GB is required.
+Note: When training on potentially endless episodes, the cached hidden states demand a large amount of GPU memory. To reproduce the following experiments, a minimum of 40GB is required. One workaround is to cache the hidden states in the buffer at a lower precision, such as bfloat16. This is under examination for future updates.
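A rough sketch of that workaround (an illustrative assumption, not what `ppo_trxl.py` currently does): allocate the hidden-state buffer in `torch.bfloat16`, halving its footprint relative to `float32`, and upcast each memory window when it is fed back into the transformer.

```python
import torch

# Hypothetical buffer layout: (num_envs, max_episode_steps, num_layers, dim).
memories = torch.zeros(16, 512, 3, 384, dtype=torch.bfloat16)  # half the bytes of float32


def write_memory(memories: torch.Tensor, t: int, hidden: torch.Tensor) -> None:
    memories[:, t] = hidden.to(torch.bfloat16)  # downcast on write


def read_window(memories: torch.Tensor, t: int, memory_length: int) -> torch.Tensor:
    start = max(0, t - memory_length)
    return memories[:, start:t].float()  # upcast on read for the forward pass
```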

| | PPO-TrXL |
|:-----------------------------|:------------|
@@ -106,9 +106,53 @@ Tracked experiments:

<iframe src="https://api.wandb.ai/links/m-pleines/wo9m43hv" style="width:100%; height:500px" title="CleanRL-s-PPO-TrXL"></iframe>


### Hyperparameters

Memory Gym Environments

Please refer to the defaults in [`ppo_trxl.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_trxl/ppo_trxl.py) and the individual modifications found in [`benchmark/ppo_trxl.sh`](https://github.com/vwxyzjn/cleanrl/blob/master/benchmark/ppo_trxl.sh).

ProofofMemory-v0
```bash
python ppo_trxl.py \
--env_id ProofofMemory-v0 \
--total_timesteps 25000 \
--num_envs 16 \
--num_steps 128 \
--num_minibatches 8 \
--update_epochs 4 \
--trxl_num_layers 4 \
--trxl_num_heads 1 \
--trxl_dim 64 \
--trxl_memory_length 16 \
--trxl_positional_encoding none \
--vf_coef 0.1 \
--max_grad_norm 0.5 \
--init_lr 3.0e-4 \
--init_ent_coef 0.001 \
--clip_coef 0.2
```

MiniGrid-MemoryS9-v0
```bash
python ppo_trxl.py \
--env_id MiniGrid-MemoryS9-v0 \
--total_timesteps 2048000 \
--num_envs 16 \
--num_steps 256 \
--trxl_num_layers 2 \
--trxl_num_heads 4 \
--trxl_dim 256 \
--trxl_memory_length 64 \
--max_grad_norm 0.25 \
    --anneal_steps 4096000 \
--clip_coef 0.2
```

### Enjoy pre-trained models

-Use [cleanrl/ppo_trxl/enjoy.py](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_trxl/enjoy.py) to watch pre-trained agents.
+Use [`cleanrl/ppo_trxl/enjoy.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_trxl/enjoy.py) to watch pre-trained agents.
You can retrieve pre-trained models from [huggingface](https://huggingface.co/LilHairdy/cleanrl_memory_gym).


@@ -117,6 +161,8 @@ Run models from the hub:
python cleanrl/ppo_trxl/enjoy.py --hub --name Endless-MortarMayhem-v0_12.nn
python cleanrl/ppo_trxl/enjoy.py --hub --name Endless-MysterPath-v0_11.nn
python cleanrl/ppo_trxl/enjoy.py --hub --name Endless-SearingSpotlights-v0_30.nn
python cleanrl/ppo_trxl/enjoy.py --hub --name MiniGrid-MemoryS9-v0_10.nn
python cleanrl/ppo_trxl/enjoy.py --hub --name ProofofMemory-v0_1.nn
```

