diff --git a/pax/experiment.py b/pax/experiment.py index 4dd9553d..c8942880 100644 --- a/pax/experiment.py +++ b/pax/experiment.py @@ -629,10 +629,11 @@ def ppo_log(agent): "coin_game", "Cournot", "Fishery", + "Rice-N", + "C-Rice-N", "InTheMatrix", "iterated_matrix_game", "iterated_nplayer_tensor_game", - "CournotGame", "Fishery" ]: policy = policy_logger_ppo(agent) value = value_logger_ppo(agent) diff --git a/pax/runners/runner_evo.py b/pax/runners/runner_evo.py index af85b7b5..252a6676 100644 --- a/pax/runners/runner_evo.py +++ b/pax/runners/runner_evo.py @@ -452,7 +452,7 @@ def _rollout( 2, ), ) - elif args.env_id == "Rice-N": + elif args.env_id in ["Rice-N", "C-Rice-N"]: env_stats = rice_stats([traj_1] + traj_2, args.num_players, args.has_mediator) else: env_stats = {} diff --git a/pax/runners/runner_weight_sharing.py b/pax/runners/runner_weight_sharing.py index b9798906..93cfbc07 100644 --- a/pax/runners/runner_weight_sharing.py +++ b/pax/runners/runner_weight_sharing.py @@ -177,7 +177,7 @@ def _rollout( for traj in trajectories: rewards.append(jnp.where(num_episodes != 0, jnp.sum(traj.rewards) / num_episodes, 0)) env_stats = {} - if args.env_id == "Rice-N": + if args.env_id in ["Rice-N", "C-Rice-N"]: env_stats = rice_stats(trajectories, args.num_players, args.has_mediator) return (