Python 3.9 does not have the zip strict keyword yet (#168)
Fix strict flag
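Background: the `strict=` keyword for `zip()` only exists on Python 3.10+ (PEP 618); on Python 3.9 passing it raises a `TypeError`, so this commit drops the flag everywhere and adds flake8-bugbear's B905 check to the ignore list. Below is a minimal sketch of the incompatibility and of one version-agnostic workaround; the `zip_strict` helper is hypothetical and not part of this repository, which instead falls back to plain `zip()` and accepts silent truncation on mismatched lengths.

```python
import sys
from itertools import zip_longest

# Python 3.9:  zip([1, 2], [3, 4], strict=True) raises TypeError.
# Python 3.10+: the same call works and raises ValueError on a length mismatch.

_SENTINEL = object()


def zip_strict(*iterables):
    """Hypothetical backport of zip(..., strict=True) for Python 3.9."""
    if sys.version_info >= (3, 10):
        yield from zip(*iterables, strict=True)
        return
    for combo in zip_longest(*iterables, fillvalue=_SENTINEL):
        if _SENTINEL in combo:
            raise ValueError("zip_strict() arguments have different lengths")
        yield combo


# Equal lengths pair up; unequal lengths fail loudly instead of truncating.
assert list(zip_strict("ab", [1, 2])) == [("a", 1), ("b", 2)]
```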
chrismatix authored Oct 19, 2023
1 parent 9d3fa62 commit df8fc26
Showing 9 changed files with 57 additions and 85 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
additional_dependencies: [flake8-bugbear]
args: [
"--show-source",
"--ignore=E203,E266,E501,W503,F403,F401,B008,E712",
"--ignore=E203,E266,E501,W503,F403,F401,B008,B905,E712",
"--max-line-length=100",
"--max-complexity=18",
"--select=B,C,E,F,W,T4,B9"]
2 changes: 1 addition & 1 deletion pax/agents/ppo/batched_envs.py
@@ -30,7 +30,7 @@ def step(self, actions: jnp.ndarray) -> TimeStep:
rewards = []
observations = []
discounts = []
- for env, action in zip(self.envs, actions, strict=True):
+ for env, action in zip(self.envs, actions):
t = env.step(int(action))
if t.step_type == 2:
t_reset = env.reset()
3 changes: 1 addition & 2 deletions pax/conf/experiment/rice/shaper_v_ppo.yaml
@@ -23,10 +23,9 @@ popsize: 1000
num_envs: 1
num_opps: 1
num_outer_steps: 200
- num_inner_steps: 20
+ num_inner_steps: 200
num_iters: 1500
num_devices: 1
- num_steps: 200


# PPO agent parameters
12 changes: 4 additions & 8 deletions pax/runners/runner_eval.py
@@ -144,7 +144,7 @@ def _inner_rollout(carry, unused):

a1_actions = []
new_a1_memories = []
- for _obs, _mem in zip(obs1, a1_mem, strict=True):
+ for _obs, _mem in zip(obs1, a1_mem):
a1_action, a1_state, new_a1_memory = agent1.batch_policy(
a1_state,
_obs,
@@ -155,7 +155,7 @@ def _inner_rollout(carry, unused):

a2_actions = []
new_a2_memories = []
- for _obs, _mem in zip(obs2, a2_mem, strict=True):
+ for _obs, _mem in zip(obs2, a2_mem):
a2_action, a2_state, new_a2_memory = agent2.batch_policy(
a2_state,
_obs,
@@ -192,7 +192,6 @@ def _inner_rollout(carry, unused):
rewards[: self.args.agent1_roles],
new_a1_memories,
a1_mem,
- strict=True,
)
]
a2_trajectories = [
@@ -211,7 +210,6 @@ def _inner_rollout(carry, unused):
rewards[self.args.agent1_roles :],
new_a2_memories,
a2_mem,
- strict=True,
)
]

@@ -267,9 +265,7 @@ def _outer_rollout(carry, unused):
a2_metrics = {}
else:
new_a2_memories = []
- for _obs, mem, traj in zip(
- obs2, a2_mem, stack[1], strict=True
- ):
+ for _obs, mem, traj in zip(obs2, a2_mem, stack[1]):
a2_state, a2_mem, a2_metrics = agent2.batch_update(
traj,
_obs,
@@ -512,7 +508,7 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers):
agent2._logger.metrics | flattened_metrics
)

- for watcher, agent in zip(watchers, agents, strict=True):
+ for watcher, agent in zip(watchers, agents):
watcher(agent)
wandb.log(
{
33 changes: 16 additions & 17 deletions pax/runners/runner_evo.py
@@ -195,14 +195,20 @@ def _inner_rollout(carry, unused):
env_rng = rngs[:, :, :, 0, :]
rngs = rngs[:, :, :, 3, :]

- a1, a1_state, new_a1_mem = agent1.batch_policy(
- a1_state,
- obs1,
- a1_mem,
- )
+ a1_actions = []
+ new_a1_memories = []
+ for _obs, _mem in zip(obs1, a1_mem):
+ a1_action, a1_state, new_a1_memory = agent1.batch_policy(
+ a1_state,
+ _obs,
+ _mem,
+ )
+ a1_actions.append(a1_action)
+ new_a1_memories.append(new_a1_memory)

a2_actions = []
new_a2_memories = []
- for _obs, _mem in zip(obs2, a2_mem, strict=True):
+ for _obs, _mem in zip(obs2, a2_mem):
a2_action, a2_state, new_a2_memory = agent2.batch_policy(
a2_state,
_obs,
@@ -211,7 +217,7 @@ def _inner_rollout(carry, unused):
a2_actions.append(a2_action)
new_a2_memories.append(new_a2_memory)

- actions = jnp.asarray([a1, *a2_actions])[agent_order]
+ actions = jnp.asarray([*a1_actions, *a2_actions])[agent_order]
obs, env_state, rewards, done, info = env.step(
env_rng,
env_state,
@@ -248,7 +254,6 @@ def _inner_rollout(carry, unused):
rewards[1:],
new_a2_memories,
a2_mem,
- strict=True,
)
]

@@ -299,9 +304,7 @@ def _outer_rollout(carry, unused):

# update second agent
new_a2_memories = []
- for _obs, mem, traj in zip(
- obs2, a2_mem, trajectories[1], strict=True
- ):
+ for _obs, mem, traj in zip(obs2, a2_mem, trajectories[1]):
a2_state, a2_mem, a2_metrics = agent2.batch_update(
traj,
_obs,
@@ -370,8 +373,6 @@ def _rollout(
if args.shuffle_players:
agent_order = jax.random.permutation(_rng_run, agent_order)

- inv_agent_order = jnp.argsort(agent_order)
- obs = jnp.asarray(obs)[inv_agent_order]
# run trials
vals, stack = jax.lax.scan(
_outer_rollout,
@@ -681,9 +682,7 @@ def run_loop(
wandb_log.update(env_stats)
# loop through population
for idx, (overall_fitness, gen_fitness) in enumerate(
- zip(
- log["top_fitness"], log["top_gen_fitness"], strict=True
- )
+ zip(log["top_fitness"], log["top_gen_fitness"])
):
wandb_log[
f"train/fitness/top_overall_agent_{idx + 1}"
@@ -699,7 +698,7 @@ def run_loop(
)

agent2._logger.metrics.update(flattened_metrics)
- for watcher, agent in zip(watchers, agents, strict=True):
+ for watcher, agent in zip(watchers, agents):
watcher(agent)
wandb_log = jax.tree_util.tree_map(
lambda x: x.item() if isinstance(x, jax.Array) else x,
20 changes: 6 additions & 14 deletions pax/runners/runner_evo_multishaper.py
@@ -635,23 +635,21 @@ def run_loop(
# Tell
fitness_re = [
fit_shaper.apply(x, fitness)
- for x, fitness in zip(xs, shapers_fitness, strict=True)
+ for x, fitness in zip(xs, shapers_fitness)
]

if self.args.es.mean_reduce:
fitness_re = [fit_re - fit_re.mean() for fit_re in fitness_re]
evo_states = [
strategy.tell(x, fit_re, evo_state, es_params)
- for x, fit_re, evo_state in zip(
- xs, fitness_re, evo_states, strict=True
- )
+ for x, fit_re, evo_state in zip(xs, fitness_re, evo_states)
]

# Logging
logs = [
es_log.update(log, x, fitness)
for es_log, log, x, fitness in zip(
- es_logging, logs, xs, shapers_fitness, strict=True
+ es_logging, logs, xs, shapers_fitness
)
]
# Saving
@@ -727,9 +725,7 @@ def run_loop(
]
rewards_strs = shaper_rewards_strs + target_rewards_strs
rewards_val = shaper_rewards_val + target_rewards_val
- rewards_dict = dict(
- zip(rewards_strs, rewards_val, strict=True)
- )
+ rewards_dict = dict(zip(rewards_strs, rewards_val))

shaper_fitness_str = [
"train/fitness/shaper_" + str(i)
@@ -748,9 +744,7 @@ def run_loop(
fitness_strs = shaper_fitness_str + target_fitness_str
fitness_vals = shaper_fitness_val + target_fitness_val

- fitness_dict = dict(
- zip(fitness_strs, fitness_vals, strict=True)
- )
+ fitness_dict = dict(zip(fitness_strs, fitness_vals))

shaper_welfare = float(
sum([reward.mean() for reward in shapers_rewards])
@@ -801,9 +795,7 @@ def run_loop(

# other player metrics
# metrics [outer_timesteps, num_opps]
- for agent, metrics in zip(
- agents[1:], targets_metrics, strict=True
- ):
+ for agent, metrics in zip(agents[1:], targets_metrics):
flattened_metrics = jax.tree_util.tree_map(
lambda x: jnp.sum(jnp.mean(x, 1)), metrics
)
2 changes: 1 addition & 1 deletion pax/runners/runner_weight_sharing.py
@@ -104,7 +104,7 @@ def _inner_rollout(carry, unused) -> Tuple[Tuple, List[Sample]]:
memory.hidden,
)
for observation, action, reward, memory, new_memory in zip(
- obs, actions, rewards, memories, new_memories, strict=True
+ obs, actions, rewards, memories, new_memories
)
]

66 changes: 26 additions & 40 deletions pax/watchers/__init__.py
@@ -42,22 +42,18 @@ def policy_logger(agent) -> dict:
] # [layer_name]['w']
log_pi = nn.softmax(weights)
probs = {
"policy/" + str(s): p[0] for (s, p) in zip(State, log_pi, strict=True)
"policy/" + str(s): p[0] for (s, p) in zip(State, log_pi)
} # probability of cooperating is p[0]
return probs


def value_logger(agent) -> dict:
weights = agent.critic_optimizer.target["Dense_0"]["kernel"]
values = {
f"value/{str(s)}.cooperate": p[0]
for (s, p) in zip(State, weights, strict=True)
f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights)
}
values.update(
- {
- f"value/{str(s)}.defect": p[1]
- for (s, p) in zip(State, weights, strict=True)
- }
+ {f"value/{str(s)}.defect": p[1] for (s, p) in zip(State, weights)}
)
return values

@@ -70,12 +66,12 @@ def policy_logger_dqn(agent) -> None:
target_steps = agent.target_step_updates
probs = {
f"policy/player_{str(pid)}/{str(s)}.cooperate": p[0]
- for (s, p) in zip(State, pi, strict=True)
+ for (s, p) in zip(State, pi)
}
probs.update(
{
f"policy/player_{str(pid)}/{str(s)}.defect": p[1]
- for (s, p) in zip(State, pi, strict=True)
+ for (s, p) in zip(State, pi)
}
)
probs.update({"policy/target_step_updates": target_steps})
@@ -88,12 +84,12 @@ def value_logger_dqn(agent) -> dict:
target_steps = agent.target_step_updates
values = {
f"value/player_{str(pid)}/{str(s)}.cooperate": p[0]
- for (s, p) in zip(State, weights, strict=True)
+ for (s, p) in zip(State, weights)
}
values.update(
{
f"value/player_{str(pid)}/{str(s)}.defect": p[1]
- for (s, p) in zip(State, weights, strict=True)
+ for (s, p) in zip(State, weights)
}
)
values.update({"value/target_step_updates": target_steps})
@@ -104,10 +100,7 @@ def policy_logger_ppo(agent: PPO) -> dict:
weights = agent._state.params["categorical_value_head/~/linear"]["w"]
pi = nn.softmax(weights)
sgd_steps = agent._total_steps / agent._num_steps
- probs = {
- f"policy/{str(s)}.cooperate": p[0]
- for (s, p) in zip(State, pi, strict=True)
- }
+ probs = {f"policy/{str(s)}.cooperate": p[0] for (s, p) in zip(State, pi)}
probs.update({"policy/total_steps": sgd_steps})
return probs

@@ -118,8 +111,7 @@ def value_logger_ppo(agent: PPO) -> dict:
] # 5 x 1 matrix
sgd_steps = agent._total_steps / agent._num_steps
probs = {
f"value/{str(s)}.cooperate": p[0]
for (s, p) in zip(State, weights, strict=True)
f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights)
}
probs.update({"value/total_steps": sgd_steps})
return probs
@@ -222,7 +214,7 @@ def policy_logger_naive(agent) -> None:
sgd_steps = agent._total_steps / agent._num_steps
probs = {
f"policy/{str(s)}/{agent.player_id}.cooperate": p[0]
- for (s, p) in zip(State, pi, strict=True)
+ for (s, p) in zip(State, pi)
}
probs.update({"policy/total_steps": sgd_steps})
return probs
@@ -508,10 +500,10 @@ def generate_grouped_combs_strs(num_players):
]

visitation_dict = (
- dict(zip(visitation_strs, state_freq, strict=True))
- | dict(zip(prob_strs, state_probs, strict=True))
- | dict(zip(grouped_visitation_strs, grouped_state_freq, strict=True))
- | dict(zip(grouped_prob_strs, grouped_state_probs, strict=True))
+ dict(zip(visitation_strs, state_freq))
+ | dict(zip(prob_strs, state_probs))
+ | dict(zip(grouped_visitation_strs, grouped_state_freq))
+ | dict(zip(grouped_prob_strs, grouped_state_probs))
)
return visitation_dict

@@ -807,20 +799,14 @@ def third_party_punishment_visitation(
total_punishment_str = "total_punishment"

visitation_dict = (
- dict(zip(all_game_visitation_strs, action_freq, strict=True))
- | dict(zip(all_game_prob_strs, action_probs, strict=True))
- | dict(
- zip(pl1_v_pl2_visitation_strs, pl1_v_pl2_action_freq, strict=True)
- )
- | dict(zip(pl1_v_pl2_prob_strs, pl1_v_pl2_action_probs, strict=True))
- | dict(
- zip(pl1_v_pl3_visitation_strs, pl1_v_pl3_action_freq, strict=True)
- )
- | dict(zip(pl1_v_pl3_prob_strs, pl1_v_pl3_action_probs, strict=True))
- | dict(
- zip(pl2_v_pl3_visitation_strs, pl2_v_pl3_action_freq, strict=True)
- )
- | dict(zip(pl2_v_pl3_prob_strs, pl2_v_pl3_action_probs, strict=True))
+ dict(zip(all_game_visitation_strs, action_freq))
+ | dict(zip(all_game_prob_strs, action_probs))
+ | dict(zip(pl1_v_pl2_visitation_strs, pl1_v_pl2_action_freq))
+ | dict(zip(pl1_v_pl2_prob_strs, pl1_v_pl2_action_probs))
+ | dict(zip(pl1_v_pl3_visitation_strs, pl1_v_pl3_action_freq))
+ | dict(zip(pl1_v_pl3_prob_strs, pl1_v_pl3_action_probs))
+ | dict(zip(pl2_v_pl3_visitation_strs, pl2_v_pl3_action_freq))
+ | dict(zip(pl2_v_pl3_prob_strs, pl2_v_pl3_action_probs))
| {pl1_total_defects_prob_str: pl1_total_defects_prob}
| {pl2_total_defects_prob_str: pl2_total_defects_prob}
| {pl3_total_defects_prob_str: pl3_total_defects_prob}
@@ -1006,10 +992,10 @@ def third_party_random_visitation(
)

visitation_dict = (
- dict(zip(game_prob_strs, action_probs, strict=True))
- | dict(zip(game1_prob_strs, pl1_v_pl2_action_probs, strict=True))
- | dict(zip(game2_prob_strs, pl2_v_pl3_action_probs, strict=True))
- | dict(zip(game3_prob_strs, pl3_v_pl1_action_probs, strict=True))
+ dict(zip(game_prob_strs, action_probs))
+ | dict(zip(game1_prob_strs, pl1_v_pl2_action_probs))
+ | dict(zip(game2_prob_strs, pl2_v_pl3_action_probs))
+ | dict(zip(game3_prob_strs, pl3_v_pl1_action_probs))
| {game_selected_punish_str: game_selected_punish}
| {pl1_defects_prob_str: pl1_defect_prob}
| {pl2_defects_prob_str: pl2_defect_prob}
2 changes: 1 addition & 1 deletion pax/watchers/fishery.py
@@ -63,7 +63,7 @@ def fishery_eval_stats(traj1: NamedTuple, traj2: NamedTuple) -> dict:

stock_obs = traj1.observations[..., 0].squeeze().tolist()
stock_table = wandb.Table(
- data=[[x, y] for (x, y) in zip(ep_length, stock_obs, strict=True)],
+ data=[[x, y] for (x, y) in zip(ep_length, stock_obs)],
columns=["step", "stock"],
)
# Plot the stock in a separate graph
