
Python 3.9 does not have the zip strict keyword yet #168

Merged
merged 1 commit on Oct 19, 2023
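Background: the strict= keyword for zip() was only added in Python 3.10 (PEP 618), so on Python 3.9 any zip(..., strict=True) call fails at runtime with TypeError. Dropping the keyword restores 3.9 compatibility, but it also drops the length check, so mismatched iterables are silently truncated. A rough sketch of both behaviours, plus a hypothetical zip_strict helper (illustrative only, not something this repo defines) for code that still wants the check on 3.9:

```python
# Sketch only: why zip(strict=True) breaks on Python 3.9 and what is lost
# by removing it. The zip_strict helper below is hypothetical.
import sys
from itertools import zip_longest

envs = ["env0", "env1", "env2"]
actions = [0, 1]  # deliberately mismatched lengths

# On Python 3.9 the next line raises:
#   TypeError: zip() takes no keyword arguments
# pairs = list(zip(envs, actions, strict=True))

# Plain zip() (what this PR falls back to) silently truncates to the
# shorter iterable, so 'env2' gets no action and no error is raised:
print(list(zip(envs, actions)))  # [('env0', 0), ('env1', 1)]

def zip_strict(*iterables):
    """Length-checking zip that also works on Python 3.9."""
    if sys.version_info >= (3, 10):
        return zip(*iterables, strict=True)
    sentinel = object()
    def generator():
        for tup in zip_longest(*iterables, fillvalue=sentinel):
            if sentinel in tup:
                raise ValueError("zip_strict() arguments have unequal lengths")
            yield tup
    return generator()
```

The diff below takes the simpler route of just removing the strict= arguments.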
Changes from all commits
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
additional_dependencies: [flake8-bugbear]
args: [
"--show-source",
"--ignore=E203,E266,E501,W503,F403,F401,B008,E712",
"--ignore=E203,E266,E501,W503,F403,F401,B008,B905,E712",
"--max-line-length=100",
"--max-complexity=18",
"--select=B,C,E,F,W,T4,B9"]
2 changes: 1 addition & 1 deletion pax/agents/ppo/batched_envs.py
@@ -30,7 +30,7 @@ def step(self, actions: jnp.ndarray) -> TimeStep:
rewards = []
observations = []
discounts = []
-for env, action in zip(self.envs, actions, strict=True):
+for env, action in zip(self.envs, actions):
t = env.step(int(action))
if t.step_type == 2:
t_reset = env.reset()
3 changes: 1 addition & 2 deletions pax/conf/experiment/rice/shaper_v_ppo.yaml
@@ -23,10 +23,9 @@ popsize: 1000
num_envs: 1
num_opps: 1
num_outer_steps: 200
-num_inner_steps: 20
+num_inner_steps: 200
num_iters: 1500
num_devices: 1
-num_steps: 200


# PPO agent parameters
12 changes: 4 additions & 8 deletions pax/runners/runner_eval.py
@@ -144,7 +144,7 @@ def _inner_rollout(carry, unused):

a1_actions = []
new_a1_memories = []
-for _obs, _mem in zip(obs1, a1_mem, strict=True):
+for _obs, _mem in zip(obs1, a1_mem):
a1_action, a1_state, new_a1_memory = agent1.batch_policy(
a1_state,
_obs,
@@ -155,7 +155,7 @@ def _inner_rollout(carry, unused):

a2_actions = []
new_a2_memories = []
-for _obs, _mem in zip(obs2, a2_mem, strict=True):
+for _obs, _mem in zip(obs2, a2_mem):
a2_action, a2_state, new_a2_memory = agent2.batch_policy(
a2_state,
_obs,
@@ -192,7 +192,6 @@ def _inner_rollout(carry, unused):
rewards[: self.args.agent1_roles],
new_a1_memories,
a1_mem,
-strict=True,
)
]
a2_trajectories = [
@@ -211,7 +210,6 @@ def _inner_rollout(carry, unused):
rewards[self.args.agent1_roles :],
new_a2_memories,
a2_mem,
-strict=True,
)
]

@@ -267,9 +265,7 @@ def _outer_rollout(carry, unused):
a2_metrics = {}
else:
new_a2_memories = []
-for _obs, mem, traj in zip(
-obs2, a2_mem, stack[1], strict=True
-):
+for _obs, mem, traj in zip(obs2, a2_mem, stack[1]):
a2_state, a2_mem, a2_metrics = agent2.batch_update(
traj,
_obs,
@@ -512,7 +508,7 @@ def run_loop(self, env, env_params, agents, num_episodes, watchers):
agent2._logger.metrics | flattened_metrics
)

-for watcher, agent in zip(watchers, agents, strict=True):
+for watcher, agent in zip(watchers, agents):
watcher(agent)
wandb.log(
{
33 changes: 16 additions & 17 deletions pax/runners/runner_evo.py
@@ -195,14 +195,20 @@ def _inner_rollout(carry, unused):
env_rng = rngs[:, :, :, 0, :]
rngs = rngs[:, :, :, 3, :]

-a1, a1_state, new_a1_mem = agent1.batch_policy(
-a1_state,
-obs1,
-a1_mem,
-)
+a1_actions = []
+new_a1_memories = []
+for _obs, _mem in zip(obs1, a1_mem):
+a1_action, a1_state, new_a1_memory = agent1.batch_policy(
+a1_state,
+_obs,
+_mem,
+)
+a1_actions.append(a1_action)
+new_a1_memories.append(new_a1_memory)

a2_actions = []
new_a2_memories = []
-for _obs, _mem in zip(obs2, a2_mem, strict=True):
+for _obs, _mem in zip(obs2, a2_mem):
a2_action, a2_state, new_a2_memory = agent2.batch_policy(
a2_state,
_obs,
@@ -211,7 +217,7 @@ def _inner_rollout(carry, unused):
a2_actions.append(a2_action)
new_a2_memories.append(new_a2_memory)

-actions = jnp.asarray([a1, *a2_actions])[agent_order]
+actions = jnp.asarray([*a1_actions, *a2_actions])[agent_order]
obs, env_state, rewards, done, info = env.step(
env_rng,
env_state,
@@ -248,7 +254,6 @@ def _inner_rollout(carry, unused):
rewards[1:],
new_a2_memories,
a2_mem,
-strict=True,
)
]

@@ -299,9 +304,7 @@ def _outer_rollout(carry, unused):

# update second agent
new_a2_memories = []
-for _obs, mem, traj in zip(
-obs2, a2_mem, trajectories[1], strict=True
-):
+for _obs, mem, traj in zip(obs2, a2_mem, trajectories[1]):
a2_state, a2_mem, a2_metrics = agent2.batch_update(
traj,
_obs,
@@ -370,8 +373,6 @@ def _rollout(
if args.shuffle_players:
agent_order = jax.random.permutation(_rng_run, agent_order)

-inv_agent_order = jnp.argsort(agent_order)
-obs = jnp.asarray(obs)[inv_agent_order]
# run trials
vals, stack = jax.lax.scan(
_outer_rollout,
@@ -681,9 +682,7 @@ def run_loop(
wandb_log.update(env_stats)
# loop through population
for idx, (overall_fitness, gen_fitness) in enumerate(
-zip(
-log["top_fitness"], log["top_gen_fitness"], strict=True
-)
+zip(log["top_fitness"], log["top_gen_fitness"])
):
wandb_log[
f"train/fitness/top_overall_agent_{idx + 1}"
@@ -699,7 +698,7 @@ def run_loop(
)

agent2._logger.metrics.update(flattened_metrics)
-for watcher, agent in zip(watchers, agents, strict=True):
+for watcher, agent in zip(watchers, agents):
watcher(agent)
wandb_log = jax.tree_util.tree_map(
lambda x: x.item() if isinstance(x, jax.Array) else x,
20 changes: 6 additions & 14 deletions pax/runners/runner_evo_multishaper.py
@@ -635,23 +635,21 @@ def run_loop(
# Tell
fitness_re = [
fit_shaper.apply(x, fitness)
-for x, fitness in zip(xs, shapers_fitness, strict=True)
+for x, fitness in zip(xs, shapers_fitness)
]

if self.args.es.mean_reduce:
fitness_re = [fit_re - fit_re.mean() for fit_re in fitness_re]
evo_states = [
strategy.tell(x, fit_re, evo_state, es_params)
-for x, fit_re, evo_state in zip(
-xs, fitness_re, evo_states, strict=True
-)
+for x, fit_re, evo_state in zip(xs, fitness_re, evo_states)
]

# Logging
logs = [
es_log.update(log, x, fitness)
for es_log, log, x, fitness in zip(
-es_logging, logs, xs, shapers_fitness, strict=True
+es_logging, logs, xs, shapers_fitness
)
]
# Saving
@@ -727,9 +725,7 @@ def run_loop(
]
rewards_strs = shaper_rewards_strs + target_rewards_strs
rewards_val = shaper_rewards_val + target_rewards_val
-rewards_dict = dict(
-zip(rewards_strs, rewards_val, strict=True)
-)
+rewards_dict = dict(zip(rewards_strs, rewards_val))

shaper_fitness_str = [
"train/fitness/shaper_" + str(i)
@@ -748,9 +744,7 @@ def run_loop(
fitness_strs = shaper_fitness_str + target_fitness_str
fitness_vals = shaper_fitness_val + target_fitness_val

-fitness_dict = dict(
-zip(fitness_strs, fitness_vals, strict=True)
-)
+fitness_dict = dict(zip(fitness_strs, fitness_vals))

shaper_welfare = float(
sum([reward.mean() for reward in shapers_rewards])
@@ -801,9 +795,7 @@ def run_loop(

# other player metrics
# metrics [outer_timesteps, num_opps]
-for agent, metrics in zip(
-agents[1:], targets_metrics, strict=True
-):
+for agent, metrics in zip(agents[1:], targets_metrics):
flattened_metrics = jax.tree_util.tree_map(
lambda x: jnp.sum(jnp.mean(x, 1)), metrics
)
2 changes: 1 addition & 1 deletion pax/runners/runner_weight_sharing.py
@@ -104,7 +104,7 @@ def _inner_rollout(carry, unused) -> Tuple[Tuple, List[Sample]]:
memory.hidden,
)
for observation, action, reward, memory, new_memory in zip(
-obs, actions, rewards, memories, new_memories, strict=True
+obs, actions, rewards, memories, new_memories
)
]

66 changes: 26 additions & 40 deletions pax/watchers/__init__.py
@@ -42,22 +42,18 @@ def policy_logger(agent) -> dict:
] # [layer_name]['w']
log_pi = nn.softmax(weights)
probs = {
"policy/" + str(s): p[0] for (s, p) in zip(State, log_pi, strict=True)
"policy/" + str(s): p[0] for (s, p) in zip(State, log_pi)
} # probability of cooperating is p[0]
return probs


def value_logger(agent) -> dict:
weights = agent.critic_optimizer.target["Dense_0"]["kernel"]
values = {
f"value/{str(s)}.cooperate": p[0]
for (s, p) in zip(State, weights, strict=True)
f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights)
}
values.update(
-{
-f"value/{str(s)}.defect": p[1]
-for (s, p) in zip(State, weights, strict=True)
-}
+{f"value/{str(s)}.defect": p[1] for (s, p) in zip(State, weights)}
)
return values

@@ -70,12 +66,12 @@ def policy_logger_dqn(agent) -> None:
target_steps = agent.target_step_updates
probs = {
f"policy/player_{str(pid)}/{str(s)}.cooperate": p[0]
-for (s, p) in zip(State, pi, strict=True)
+for (s, p) in zip(State, pi)
}
probs.update(
{
f"policy/player_{str(pid)}/{str(s)}.defect": p[1]
-for (s, p) in zip(State, pi, strict=True)
+for (s, p) in zip(State, pi)
}
)
probs.update({"policy/target_step_updates": target_steps})
@@ -88,12 +84,12 @@ def value_logger_dqn(agent) -> dict:
target_steps = agent.target_step_updates
values = {
f"value/player_{str(pid)}/{str(s)}.cooperate": p[0]
-for (s, p) in zip(State, weights, strict=True)
+for (s, p) in zip(State, weights)
}
values.update(
{
f"value/player_{str(pid)}/{str(s)}.defect": p[1]
-for (s, p) in zip(State, weights, strict=True)
+for (s, p) in zip(State, weights)
}
)
values.update({"value/target_step_updates": target_steps})
@@ -104,10 +100,7 @@ def policy_logger_ppo(agent: PPO) -> dict:
weights = agent._state.params["categorical_value_head/~/linear"]["w"]
pi = nn.softmax(weights)
sgd_steps = agent._total_steps / agent._num_steps
-probs = {
-f"policy/{str(s)}.cooperate": p[0]
-for (s, p) in zip(State, pi, strict=True)
-}
+probs = {f"policy/{str(s)}.cooperate": p[0] for (s, p) in zip(State, pi)}
probs.update({"policy/total_steps": sgd_steps})
return probs

Expand All @@ -118,8 +111,7 @@ def value_logger_ppo(agent: PPO) -> dict:
] # 5 x 1 matrix
sgd_steps = agent._total_steps / agent._num_steps
probs = {
f"value/{str(s)}.cooperate": p[0]
for (s, p) in zip(State, weights, strict=True)
f"value/{str(s)}.cooperate": p[0] for (s, p) in zip(State, weights)
}
probs.update({"value/total_steps": sgd_steps})
return probs
@@ -222,7 +214,7 @@ def policy_logger_naive(agent) -> None:
sgd_steps = agent._total_steps / agent._num_steps
probs = {
f"policy/{str(s)}/{agent.player_id}.cooperate": p[0]
-for (s, p) in zip(State, pi, strict=True)
+for (s, p) in zip(State, pi)
}
probs.update({"policy/total_steps": sgd_steps})
return probs
@@ -508,10 +500,10 @@ def generate_grouped_combs_strs(num_players):
]

visitation_dict = (
-dict(zip(visitation_strs, state_freq, strict=True))
-| dict(zip(prob_strs, state_probs, strict=True))
-| dict(zip(grouped_visitation_strs, grouped_state_freq, strict=True))
-| dict(zip(grouped_prob_strs, grouped_state_probs, strict=True))
+dict(zip(visitation_strs, state_freq))
+| dict(zip(prob_strs, state_probs))
+| dict(zip(grouped_visitation_strs, grouped_state_freq))
+| dict(zip(grouped_prob_strs, grouped_state_probs))
)
return visitation_dict

Expand Down Expand Up @@ -807,20 +799,14 @@ def third_party_punishment_visitation(
total_punishment_str = "total_punishment"

visitation_dict = (
-dict(zip(all_game_visitation_strs, action_freq, strict=True))
-| dict(zip(all_game_prob_strs, action_probs, strict=True))
-| dict(
-zip(pl1_v_pl2_visitation_strs, pl1_v_pl2_action_freq, strict=True)
-)
-| dict(zip(pl1_v_pl2_prob_strs, pl1_v_pl2_action_probs, strict=True))
-| dict(
-zip(pl1_v_pl3_visitation_strs, pl1_v_pl3_action_freq, strict=True)
-)
-| dict(zip(pl1_v_pl3_prob_strs, pl1_v_pl3_action_probs, strict=True))
-| dict(
-zip(pl2_v_pl3_visitation_strs, pl2_v_pl3_action_freq, strict=True)
-)
-| dict(zip(pl2_v_pl3_prob_strs, pl2_v_pl3_action_probs, strict=True))
+dict(zip(all_game_visitation_strs, action_freq))
+| dict(zip(all_game_prob_strs, action_probs))
+| dict(zip(pl1_v_pl2_visitation_strs, pl1_v_pl2_action_freq))
+| dict(zip(pl1_v_pl2_prob_strs, pl1_v_pl2_action_probs))
+| dict(zip(pl1_v_pl3_visitation_strs, pl1_v_pl3_action_freq))
+| dict(zip(pl1_v_pl3_prob_strs, pl1_v_pl3_action_probs))
+| dict(zip(pl2_v_pl3_visitation_strs, pl2_v_pl3_action_freq))
+| dict(zip(pl2_v_pl3_prob_strs, pl2_v_pl3_action_probs))
| {pl1_total_defects_prob_str: pl1_total_defects_prob}
| {pl2_total_defects_prob_str: pl2_total_defects_prob}
| {pl3_total_defects_prob_str: pl3_total_defects_prob}
@@ -1006,10 +992,10 @@ def third_party_random_visitation(
)

visitation_dict = (
-dict(zip(game_prob_strs, action_probs, strict=True))
-| dict(zip(game1_prob_strs, pl1_v_pl2_action_probs, strict=True))
-| dict(zip(game2_prob_strs, pl2_v_pl3_action_probs, strict=True))
-| dict(zip(game3_prob_strs, pl3_v_pl1_action_probs, strict=True))
+dict(zip(game_prob_strs, action_probs))
+| dict(zip(game1_prob_strs, pl1_v_pl2_action_probs))
+| dict(zip(game2_prob_strs, pl2_v_pl3_action_probs))
+| dict(zip(game3_prob_strs, pl3_v_pl1_action_probs))
| {game_selected_punish_str: game_selected_punish}
| {pl1_defects_prob_str: pl1_defect_prob}
| {pl2_defects_prob_str: pl2_defect_prob}
2 changes: 1 addition & 1 deletion pax/watchers/fishery.py
@@ -63,7 +63,7 @@ def fishery_eval_stats(traj1: NamedTuple, traj2: NamedTuple) -> dict:

stock_obs = traj1.observations[..., 0].squeeze().tolist()
stock_table = wandb.Table(
-data=[[x, y] for (x, y) in zip(ep_length, stock_obs, strict=True)],
+data=[[x, y] for (x, y) in zip(ep_length, stock_obs)],
columns=["step", "stock"],
)
# Plot the stock in a separate graph