-
Notifications
You must be signed in to change notification settings - Fork 1
/
jpc.py
277 lines (202 loc) · 8.61 KB
/
jpc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
#!/usr/bin/env python3
'''Computes the Joint Policy Correlation matrix for a set of trained policies'''
import argparse
from collections import defaultdict
import io
import matplotlib.pyplot as plt
import numpy as np
import os
import os.path
import traceback
import yaml
import torch
from torch.multiprocessing import Pool
from interactive_agents.envs import get_env_class
from interactive_agents.sampling import sample, FrozenPolicy
# TODO: The JPC is computed in several different scripts; it would be good to consolidate these implementations
# TODO: This script may be deprecated; compare it with the other two JPC scripts
def print_error(error):
    """Print the traceback of ``error``, limited to five stack entries.

    Used in this file as the ``error_callback`` of tasks submitted to the
    worker pool.

    Args:
        error: the exception instance to report.
    """
    exc_type = type(error)
    exc_traceback = error.__traceback__
    traceback.print_exception(exc_type, error, exc_traceback, limit=5)
def parse_args(argv=None):
    """Parse the command-line options of the JPC script.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, so existing callers that
            pass nothing are unaffected.

    Returns:
        The populated ``argparse.Namespace``.
    """
    # NOTE: The first positional parameter of ArgumentParser is `prog`, not
    # `description`; passing the summary positionally replaces the program
    # name in the usage line, so it is given by keyword here.
    parser = argparse.ArgumentParser(
        description="Computes the Joint Policy Correlation matrix for a set of trained policies")

    parser.add_argument("path", type=str, help="path to directory containing training results")
    parser.add_argument("num_seeds", type=int, help="number of seeds used in the experiment")
    parser.add_argument("-o", "--output-path", type=str, default=None,
                        help="directory in which we should save matrix (defaults to experiment directory)")
    parser.add_argument("-f", "--filename", type=str, default="jpc",
                        help="filename for saved matrix")
    parser.add_argument("-n", "--num-cpus", type=int, default=1,
                        help="the number of parallel worker processes to launch")
    parser.add_argument("-e", "--num-episodes", type=int, default=100,
                        help="the number of episodes to run for each policy combination")
    parser.add_argument("-m", "--map", nargs="+",
                        help="alternating agent ID / policy ID pairs selecting the policy file "
                             "loaded for each agent (defaults to the policy named after each agent ID)")
    parser.add_argument("--title", type=str, default="Joint Policy Correlation",
                        help="title for figure")
    parser.add_argument("--min", type=float, help="min payoff value (for image rendering)")
    parser.add_argument("--max", type=float, help="max payoff value (for image rendering)")
    parser.add_argument("-d", "--display", action="store_true", help="display JPC matrix when ready")

    return parser.parse_args(argv)
def plot_matrix(matrix, path, title, min, max, size=300, disp=False):
    """Render a payoff matrix as a heat map and save the figure to ``path``.

    Args:
        matrix: 2-D array of payoffs; must provide ``min()``, ``max()`` and
            ``shape``. The same ticks are used on both axes, so the matrix is
            assumed to be square.
        path: destination passed to ``plt.savefig``.
        title: figure title.
        min: lower bound of the colour scale, or ``None`` to use the matrix minimum.
        max: upper bound of the colour scale, or ``None`` to use the matrix maximum.
        size: extent of the image along each axis, in data coordinates.
        disp: when true, show the figure and block until it is closed.
    """
    # `min` / `max` shadow the builtins but are part of the keyword interface,
    # so they are copied into locals instead of being rebound.
    vmin = matrix.min() if min is None else min
    vmax = matrix.max() if max is None else max

    # Scale range to cut off dark reds
    vmax += 0.15 * (vmax - vmin)

    cm = plt.get_cmap("jet")

    # One tick per seed, centred on the corresponding cell
    num_cells = matrix.shape[0]
    tick_space = size / num_cells
    ticks = [(idx + 0.5) * tick_space for idx in range(num_cells)]
    labels = list(range(num_cells))

    # Generate figure
    plt.clf()
    im = plt.imshow(matrix,
                    cmap=cm,
                    vmin=vmin,
                    vmax=vmax,
                    extent=(0, size, 0, size))
    plt.colorbar(im)

    plt.xticks(ticks, labels=labels)
    # imshow uses origin='upper' by default, so row 0 is drawn at the top of
    # the extent while tick positions grow upwards from 0; the y labels are
    # reversed so that each label sits next to the row it names.
    plt.yticks(ticks, labels=labels[::-1])

    ax = plt.gca()
    ax.grid(which='minor', color='k', linestyle='-', linewidth=2)

    plt.title(title, fontsize=14)
    plt.xlabel("seeds", fontsize=16)
    plt.ylabel("seeds", fontsize=16)

    plt.savefig(path, bbox_inches="tight")

    if disp:
        plt.show(block=True)
def load_populations(path, policy_map, num_seeds):
    """Load the saved TorchScript policies for every seed of an experiment.

    Args:
        path: experiment directory containing ``config.yaml`` and the
            ``seed_<n>/policies`` sub-directories.
        policy_map: flat list alternating agent IDs and policy IDs
            (``[agent, policy, agent, policy, ...]``), or ``None`` to load,
            for every key of the environment's ``observation_spaces``, the
            policy with the same ID. Purely numeric agent IDs are converted
            to ``int``.
        num_seeds: seeds ``0 .. num_seeds - 1`` are searched; a seed whose
            ``policies`` directory does not exist is skipped.

    Returns:
        A ``(populations, trainer_config)`` tuple where ``populations`` maps
        seed -> {agent ID: loaded model} and ``trainer_config`` is the
        ``"config"`` section of the experiment configuration.

    Raises:
        ValueError: if ``policy_map`` has an odd number of entries, or if
            ``config.yaml`` is missing from ``path``.
        FileNotFoundError: if a seed directory exists but lacks one of the
            mapped policy files.
    """
    # Fail fast on a malformed map instead of hitting an IndexError while pairing
    if policy_map is not None and len(policy_map) % 2 != 0:
        raise ValueError(
            f"policy map must be a list of agent/policy pairs, got {len(policy_map)} values")

    populations = defaultdict(dict)

    config_path = os.path.join(path, "config.yaml")
    if not os.path.isfile(config_path):
        raise ValueError(f"Config File: '{config_path}' not defined")

    # NOTE(review): FullLoader is not meant for untrusted input; consider
    # yaml.safe_load if experiment directories can come from other parties.
    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    if "trainer" not in config:  # NOTE: When would this be needed?
        config = next(iter(config.values()))

    trainer_config = config.get("config", {})

    policy_for_agent = {}
    if policy_map is None:
        # Default: one policy per agent, stored under the agent's own ID
        env_name = trainer_config.get("env")
        env_config = trainer_config.get("env_config", {})
        env_config = trainer_config.get("env_eval_config", env_config)
        env_cls = get_env_class(env_name)
        env = env_cls(env_config, spec_only=True)

        for policy_id in env.observation_spaces.keys():
            policy_for_agent[policy_id] = policy_id
    else:
        for agent_id, policy_id in zip(policy_map[0::2], policy_map[1::2]):
            if agent_id.isnumeric():
                agent_id = int(agent_id)
            policy_for_agent[agent_id] = policy_id

    for seed in range(num_seeds):
        sub_path = os.path.join(path, f"seed_{seed}/policies")
        if not os.path.isdir(sub_path):
            continue

        print(f"\nloading path: {sub_path}")
        for agent_id, policy_id in policy_for_agent.items():
            policy_path = os.path.join(sub_path, f"{policy_id}.pt")
            print(f"loading: {policy_path}")

            if not os.path.isfile(policy_path):
                raise FileNotFoundError(f"seed '{seed}' does not define policy '{policy_id}'")

            populations[seed][agent_id] = torch.jit.load(policy_path)

    return populations, trainer_config
def evaluate(env_cls, env_config, models, num_episodes, max_steps):
    """Sample episodes for one combination of policies and return the statistics.

    Args:
        env_cls: environment class, instantiated as ``env_cls(env_config)``.
        env_config: configuration passed to the environment constructor.
        models: mapping of agent ID -> model. A value may also be an
            ``io.BytesIO`` holding a serialized TorchScript model, in which
            case it is loaded here before use.
        num_episodes: number of episodes forwarded to ``sample``.
        max_steps: step limit forwarded to ``sample``.

    Returns:
        Whatever ``sample(...).statistics()`` returns.
    """
    def _as_model(candidate):
        # Serialized models are rewound and deserialized; others pass through
        if isinstance(candidate, io.BytesIO):
            candidate.seek(0)
            return torch.jit.load(candidate)
        return candidate

    # Build environment instance
    environment = env_cls(env_config)

    # Instantiate policies
    frozen = {
        agent_id: FrozenPolicy(_as_model(model))
        for agent_id, model in models.items()
    }

    batch = sample(environment, frozen, num_episodes, max_steps)
    return batch.statistics()
def permutations(num_agents, num_populations):
    """Yield every assignment of a population index to each agent.

    Each yielded value is a new list of length ``num_agents`` whose entries
    lie in ``range(num_populations)``. Assignments are produced in counting
    order with the first position varying fastest.
    """
    total = num_populations ** num_agents
    for code in range(total):
        # Decode `code` as a base-`num_populations` number, lowest digit first
        digits = []
        remainder = code
        for _ in range(num_agents):
            remainder, digit = divmod(remainder, num_populations)
            digits.append(digit)
        yield digits
def cross_evaluate(populations, config, num_cpus, num_episodes):
    """Evaluate every combination of per-agent policies across populations.

    Args:
        populations: mapping of population ID -> {agent ID: model}.
        config: trainer configuration; the ``"env"``, ``"env_config"`` and
            ``"max_steps"`` (default 100) entries are read.
        num_cpus: when greater than 1, evaluations are dispatched to a pool
            of that many worker processes; otherwise they run in-process.
        num_episodes: number of episodes per policy combination.

    Returns:
        A NumPy array with one axis per agent, each of length
        ``len(populations)``. The entry at index ``(p0, p1, ...)`` is the
        ``"reward_mean"`` statistic obtained when agent ``k`` uses the policy
        of the ``pk``-th population (in ``populations`` iteration order).
    """
    max_steps = config.get("max_steps", 100)
    env_name = config.get("env")
    # NOTE(review): load_populations prefers "env_eval_config" when building
    # its default policy map, while evaluation uses "env_config" - confirm
    # that this difference is intended.
    env_config = config.get("env_config", {})
    env_cls = get_env_class(env_name)
    env = env_cls(env_config, spec_only=True)

    agent_ids = list(env.observation_spaces.keys())
    population_ids = list(populations.keys())
    num_agents = len(agent_ids)
    num_populations = len(population_ids)

    parallel = num_cpus > 1
    serialized = {}

    def _serialized(population_id, agent_id):
        # Serialize each torch policy once and reuse the buffer for every
        # combination it takes part in
        key = (population_id, agent_id)
        if key not in serialized:
            buffer = io.BytesIO()
            torch.jit.save(populations[population_id][agent_id], buffer)
            serialized[key] = buffer
        return serialized[key]

    returns = np.zeros(tuple([num_populations] * num_agents))
    pool = Pool(num_cpus) if parallel else None
    try:
        pending = {}
        for permutation in permutations(num_agents, num_populations):
            idx = tuple(permutation)

            # Permutation entries are positions, not population IDs: translate
            # them so that non-contiguous IDs (e.g. a skipped seed) still work
            members = [(agent_ids[a], population_ids[p])
                       for a, p in enumerate(permutation)]

            if parallel:
                models = {agent_id: _serialized(population_id, agent_id)
                          for agent_id, population_id in members}
                pending[idx] = pool.apply_async(
                    evaluate,
                    (env_cls, env_config, models, num_episodes, max_steps),
                    error_callback=print_error)
            else:
                models = {agent_id: populations[population_id][agent_id]
                          for agent_id, population_id in members}
                stats = evaluate(env_cls, env_config, models, num_episodes, max_steps)
                returns[idx] = stats["reward_mean"]

        for idx, task in pending.items():
            stats = task.get()
            returns[idx] = stats["reward_mean"]
    finally:
        # Every result has been collected (or an error is propagating), so the
        # workers can be stopped instead of being left running
        if pool is not None:
            pool.terminate()
            pool.join()

    return returns
if __name__ == '__main__':
    # Script entry point: load the policies, cross-evaluate them, then save
    # the resulting tensor (and, for two agents, an image of the matrix)
    args = parse_args()

    # Limit CPU parallelism
    torch.set_num_threads(args.num_cpus)

    print(f"Loading policies from: {args.path}")
    populations, config = load_populations(args.path, args.map, args.num_seeds)

    print("Evaluating Policies")
    jpc = cross_evaluate(populations, config, args.num_cpus, args.num_episodes)

    print("\nJPC Tensor:")
    print(jpc)

    # Results go to the experiment directory unless an output path was given
    output_dir = args.path if args.output_path is None else args.output_path
    if output_dir:
        # Create the directory so the evaluation results are not lost at save time
        os.makedirs(output_dir, exist_ok=True)

    matrix_path = os.path.join(output_dir, args.filename + ".npy")
    image_path = os.path.join(output_dir, args.filename + ".png")

    print(f"\nwriting JPC tensor to: {matrix_path}")
    np.save(matrix_path, jpc, allow_pickle=False)

    # Only a two-dimensional tensor is rendered as an image
    if len(jpc.shape) == 2:
        print(f"\nrendering JPC tensor to: {image_path}")
        plot_matrix(
            jpc,
            image_path,
            title=args.title,
            min=args.min,
            max=args.max,
            disp=args.display)