-
Notifications
You must be signed in to change notification settings - Fork 1
/
eval_random.py
111 lines (89 loc) · 3.73 KB
/
eval_random.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from envs import EinsteinWuerfeltNichtEnv, MinimaxEnv
from classical_policies import RandomAgent
import numpy as np
from tqdm import tqdm
from constants import ClassicalPolicy
from statsmodels.stats.proportion import proportion_confint
from tqdm import trange
import argparse
import gymnasium as gym
from gymnasium.envs.registration import register
# Register the Einstein Würfelt Nicht environment with Gymnasium so it can be
# created via gym.make('EWN-v0', ...) below. Runs as a module-level side effect.
register(
    id='EWN-v0',
    entry_point='envs:EinsteinWuerfeltNichtEnv'
)
def evaluation(env, model, render_last, eval_num=100) -> np.ndarray:
    """Roll out ``model`` in ``env`` for ``eval_num`` episodes and collect rewards.

    Episode ``i`` is reset with ``seed=i``, so runs are reproducible and the
    seed doubles as the episode index.

    Args:
        env: Gymnasium environment (``reset``/``step`` follow the Gymnasium
            5-tuple API).
        model: Policy exposing ``predict(obs, deterministic=...)`` returning
            ``(action, state)`` (Stable-Baselines3-style interface).
        render_last: If True, replay the last seed again and render each step.
        eval_num: Number of evaluation rollouts.

    Returns:
        ``np.ndarray`` of shape ``(eval_num,)`` with the terminal reward of
        each episode.
    """
    score = np.zeros(eval_num)
    # Run eval_num rollouts, one per seed.
    for seed in trange(eval_num):
        obs, info = env.reset(seed=seed)
        done = False
        reward = 0
        while not done:
            action, _state = model.predict(obs, deterministic=True)
            # Gymnasium step returns (obs, reward, terminated, truncated, info).
            # The original discarded `truncated`, which can loop forever on
            # time-limited envs; treat either flag as end of episode.
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
        # The episode number is the same as the seed.
        score[seed] = reward
    # Optionally replay and render the last rollout.
    if render_last:
        print("Rendering last rollout")
        obs, info = env.reset(seed=eval_num - 1)
        env.render()
        done = False
        while not done:
            action, _state = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            env.render()
    return score
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the random-agent evaluation script.

    Returns:
        argparse.Namespace with all evaluation / environment settings; the
        whole namespace is forwarded to ``gym.make`` as env kwargs in main.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate random agent', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--num', type=int, default=100,
                        help='Number of rollouts')
    parser.add_argument('--render_last', action='store_true',
                        help='Render last rollout', default=False)
    parser.add_argument('--cube_layer', type=int, default=3,
                        help='Number of cube layers')
    parser.add_argument('--board_size', type=int, default=5,
                        help='Board size')
    # Help text previously said 'Board size' — a copy-paste error.
    parser.add_argument('--significance_level', type=float, default=0.05,
                        help='Significance level for the win-rate confidence interval')
    parser.add_argument('--opponent_policy', type=ClassicalPolicy.from_string, default=ClassicalPolicy.random, choices=list(ClassicalPolicy),
                        help='Opponent policy')
    parser.add_argument(
        '--max_depth',
        type=int,
        default=3,
        help='Max depth for minimax')
    parser.add_argument('--model_folder', type=str, default='alpha_zero_models',
                        help='folder of model')
    parser.add_argument('--model_name', type=str, default='checkpoint_100.pth.tar',
                        help='name of model')
    parser.add_argument('--num_simulations', type=int, default=10,
                        help='Number of simulations per env')
    return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
env = gym.make(
'EWN-v0',
# render_mode='human',
**args.__dict__
)
agent = RandomAgent(
env=env,
)
eval_num = args.num
score = evaluation(env, agent, args.render_last, eval_num)
print("Avg_score: ", np.mean(score))
winrate: float = np.count_nonzero(score > 0) / eval_num
print("Avg win rate: ", winrate)
# print("Avg_highest:", np.sum(highest) / eval_num)
#calculate (1-alpha)% confidence interval with {win_count} successes in {num_simulations} trials
print(f'The {1-args.significance_level} confidence interval: {proportion_confint(count=winrate, nobs=eval_num, alpha=args.significance_level)}')
print(f"Counts: (Total of {eval_num} rollouts)")