Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
chrismatix committed Sep 28, 2023
1 parent 1a4d194 commit 46fb894
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion pax/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,10 +629,11 @@ def ppo_log(agent):
"coin_game",
"Cournot",
"Fishery",
"Rice-N",
"C-Rice-N",
"InTheMatrix",
"iterated_matrix_game",
"iterated_nplayer_tensor_game",
"CournotGame", "Fishery"
]:
policy = policy_logger_ppo(agent)
value = value_logger_ppo(agent)
Expand Down
2 changes: 1 addition & 1 deletion pax/runners/runner_evo.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def _rollout(
2,
),
)
elif args.env_id == "Rice-N":
elif args.env_id in ["Rice-N", "C-Rice-N"]:
env_stats = rice_stats([traj_1] + traj_2, args.num_players, args.has_mediator)
else:
env_stats = {}
Expand Down
2 changes: 1 addition & 1 deletion pax/runners/runner_weight_sharing.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _rollout(
for traj in trajectories:
rewards.append(jnp.where(num_episodes != 0, jnp.sum(traj.rewards) / num_episodes, 0))
env_stats = {}
if args.env_id == "Rice-N":
if args.env_id in ["Rice-N", "C-Rice-N"]:
env_stats = rice_stats(trajectories, args.num_players, args.has_mediator)

return (
Expand Down

0 comments on commit 46fb894

Please sign in to comment.