-
Notifications
You must be signed in to change notification settings - Fork 1
/
Main_training.py
50 lines (39 loc) · 1.39 KB
/
Main_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy as np
from Agent import Agent
from utils import plotLearning
from Environment import Env
if __name__ == '__main__':
    # Train a Double-DQN agent and plot the learning curve.
    #
    # NOTE(review): the original read `env = Env.reset()` — calling reset() on
    # the class — but env.reset()/env.step() below require an instance, so
    # instantiate the environment here. Confirm against Environment.Env's API.
    env = Env()
    num_games = 250
    load_checkpoint = False

    # Hyperparameters follow the original script; input_dims=[8] / n_actions=4
    # presumably match a LunarLander-style observation/action space — verify.
    agent = Agent(gamma=0.99, epsilon=1.0, lr=5e-4,
                  input_dims=[8], n_actions=4, mem_size=100000, eps_min=0.01,
                  batch_size=64, eps_dec=1e-3, replace=100)

    if load_checkpoint:
        agent.load_models()

    filename = 'DDQN.png'   # output path for the learning-curve plot
    scores = []             # total reward per episode
    eps_history = []        # epsilon value at the end of each episode
    n_steps = 0

    for i in range(num_games):
        done = False
        observation = env.reset()
        score = 0
        while not done:
            action = agent.choose_action(observation)
            observation_, reward, done, info = env.step(action)
            score += reward
            # Store the terminal flag as an int, as the replay buffer expects.
            agent.store_transition(observation, action,
                                   reward, observation_, int(done))
            agent.learn()
            observation = observation_

        scores.append(score)
        eps_history.append(agent.epsilon)

        # Running mean of the last 100 episodes. The original sliced
        # scores[max(0, i-100):(i+1)], which spans 101 entries — off by one.
        avg_score = np.mean(scores[-100:])
        print('episode: ', i, 'score %.1f ' % score,
              ' average score %.1f' % avg_score,
              'epsilon %.2f' % agent.epsilon)

        # Periodic checkpoint every 10 episodes (skip episode 0).
        if i > 0 and i % 10 == 0:
            agent.save_models()

    # 1-based episode indices for the x-axis of the plot.
    x = [episode + 1 for episode in range(num_games)]
    plotLearning(x, scores, eps_history, filename)