#!/usr/bin/env python3
'''Visualizes (and optionally records) rollouts of a collection of policy checkpoints'''
import argparse
import os
import os.path

import gym
import torch
import yaml

from interactive_agents.envs import get_env_class, VisualizeGym
from interactive_agents.sampling import FrozenPolicy


def parse_args():
    parser = argparse.ArgumentParser("Visualizes a set of trained policies")
    parser.add_argument("path", type=str,
                        help="path to directory containing the policy checkpoints")
    parser.add_argument("-e", "--num-episodes", type=int, default=100,
                        help="the number of episodes to run (default: 100)")
    parser.add_argument("-s", "--max-steps", type=int, default=1000,
                        help="the maximum number of steps per episode (default: 1000)")
    parser.add_argument("-m", "--map", nargs="+",
                        help="the mapping from agents to policies")
    parser.add_argument("--seed", type=int, default=0,
                        help="the random seed of the training run to load (default: 0)")
    parser.add_argument("-r", "--record", type=str,
                        help="the path to save recorded videos (no recording if not provided)")
    parser.add_argument("--headless", action="store_true",
                        help="do not display the visualization (record only, for headless environments)")
    parser.add_argument("--speed", type=float, default=1,
                        help="the speed at which to play the visualization (in steps per second)")

    return parser.parse_args()
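

# Example invocation (paths and IDs here are illustrative, not from the repo):
#   python visualize.py results/my_experiment --seed 0 -e 10 \
#       --map 0 policy_a 1 policy_b --record videos/
# "--map" takes a flat list of alternating agent and policy IDs, which the
# main block below parses into a dict such as {0: "policy_a", 1: "policy_b"}.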


# TODO: Move policy loading code to the main library, so we don't need to
# reproduce it for every script
def load_experiment(path, seed, policy_map):
    policies = {}

    config_path = os.path.join(path, "config.yaml")
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Config file '{config_path}' not found")

    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # NOTE: Needed because most config files are actually dictionaries containing multiple named configs
    if "trainer" not in config:
        config = list(config.values())[0]

    # Load the environment, preferring the evaluation config if one is defined
    trainer_config = config.get("config", {})
    env_name = trainer_config.get("env")
    env_config = trainer_config.get("env_config", {})
    env_config = trainer_config.get("env_eval_config", env_config)

    env_cls = get_env_class(env_name)
    env = env_cls(env_config)

    # If no mapping from agents to policies was given, assume a 1-1 mapping between agents and policies
    if policy_map is None:
        policy_map = {}
        for agent_id in env.observation_space.keys():
            policy_map[agent_id] = agent_id

    # Load the policy directory for the desired random seed
    sub_path = os.path.join(path, f"seed_{seed}/policies")
    if not os.path.isdir(sub_path):
        raise FileNotFoundError(f"Directory '{sub_path}' not found")

    for agent_id, policy_id in policy_map.items():
        policy_path = os.path.join(sub_path, f"{policy_id}.pt")

        if os.path.isfile(policy_path):
            model = torch.jit.load(policy_path)
            policies[agent_id] = FrozenPolicy(model)
        else:
            raise FileNotFoundError(f"seed '{seed}' does not define policy '{policy_id}'")

    return policies, env
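

# NOTE: the checkpoint layout assumed by load_experiment above (inferred from
# the paths it constructs) is:
#   <path>/config.yaml
#   <path>/seed_<seed>/policies/<policy_id>.pt  (one TorchScript module per policy)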


if __name__ == '__main__':
    args = parse_args()

    # Parse the agent-to-policy mapping if provided as a command line argument
    if args.map is not None:
        policy_map = {}
        for idx in range(0, len(args.map), 2):
            agent_id = args.map[idx]
            policy_id = args.map[idx + 1]

            # NOTE: This is a hack due to the fact that most environments use integer agent IDs
            if agent_id.isnumeric():
                agent_id = int(agent_id)

            policy_map[agent_id] = policy_id

        policy_fn = lambda id: policy_map[id]
    else:
        policy_map = None
        policy_fn = lambda id: id

    # Load policies from the experiment directory
    print(f"Loading policies from: {args.path}")
    policies, env = load_experiment(args.path, args.seed, policy_map)
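
    # `policies` maps agent IDs to FrozenPolicy wrappers around the loaded
    # TorchScript models (see load_experiment above).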

    # If the environment doesn't support visualization directly, wrap it with the gym visualizer
    if not hasattr(env, "visualize"):
        if isinstance(env, gym.Env):
            env = VisualizeGym(env)
        else:
            raise NotImplementedError("Environment does not support visualization")
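
    # VisualizeGym provides the `visualize` interface used below for plain
    # gym.Env instances that don't implement it themselves.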

    # Launch visualization
    env.visualize(policies=policies,
                  policy_fn=policy_fn,
                  max_episodes=args.num_episodes,
                  max_steps=args.max_steps,
                  speed=args.speed,
                  record_path=args.record,
                  headless=args.headless)