forked from batuhan3526/AirSim-PyTorch-Drone-DDQN-Agent
-
Notifications
You must be signed in to change notification settings - Fork 2
/
inference.py
140 lines (107 loc) · 4.55 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import time
import numpy as np
import cv2
from DQRN_net import QNetwork
from env import DroneEnv
import torch
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def transformToTensor(img):
    """Convert one observation into a float32 batch tensor on ``device``.

    The input is copied into a new ``float32`` tensor allocated directly on
    ``device``, and two leading singleton axes are prepended. An input of
    shape ``S`` therefore becomes a tensor of shape ``(1, 1, *S)``.

    Args:
        img: Array-like observation (e.g. a NumPy array). Presumably a
            single-channel image from ``DroneEnv`` -- TODO confirm.

    Returns:
        torch.Tensor: ``float32`` tensor of shape ``(1, 1, *img.shape)``.
    """
    # torch.tensor always copies, so later mutation of ``img`` by the caller
    # cannot leak into the tensor. Allocating straight on ``device`` avoids
    # the extra CPU -> device copy of the legacy ``FloatTensor(...).to(...)``.
    tensor = torch.tensor(img, dtype=torch.float32, device=device)
    # Prepend batch and channel axes, e.g. (H, W) -> (1, 1, H, W).
    return tensor[None, None]
def test_DQN(checkpoint_dict, max_steps):
    """Run one greedy evaluation episode of a single-frame Q-network.

    The function loads the network weights from ``checkpoint_dict`` and plays
    one episode in a fresh ``DroneEnv``, always taking the arg-max Q-value
    action. When the episode ends it does three things:

    * prints a summary line,
    * appends that line to ``tests.txt``,
    * writes the collected frames to an AVI file under ``videos/``.

    Args:
        checkpoint_dict: Mapping with keys ``"state_dict"`` (weights for
            ``QNetwork``), ``"episode"`` and ``"steps_done"``. The latter two
            are used only for logging and for the video file name.
        max_steps: Maximum number of environment steps. The episode is cut
            off once this many steps have been taken, even if the
            environment has not reported ``done``.
    """
    import os  # only needed to build the video path

    model_dict = checkpoint_dict["state_dict"]
    episode = checkpoint_dict["episode"]
    # Step counter stored in the checkpoint, not the length of this episode.
    steps_done = checkpoint_dict["steps_done"]

    test_network = QNetwork()
    test_network.load_state_dict(model_dict)
    # transformToTensor places inputs on ``device``, so the weights must live
    # there too (previously they stayed on the CPU and broke CUDA runs).
    test_network.to(device)
    # Put layers such as dropout / batch-norm into inference mode.
    test_network.eval()
    # NOTE(review): test_DQRN feeds this same QNetwork a stacked frame
    # sequence, while this function feeds it a single frame. Confirm that
    # QNetwork accepts both input layouts.

    start = time.perf_counter()
    steps = 0
    score = 0
    image_array = []
    env = DroneEnv()
    state, next_state_image = env.reset()
    image_array.append(next_state_image)

    done = False
    while not done:
        state = transformToTensor(state)
        # Action selection needs no gradients. Without no_grad, every step
        # would build and retain an autograd graph.
        with torch.no_grad():
            q_values = test_network(state)
        action = int(np.argmax(q_values.squeeze().cpu().numpy()))
        next_state, reward, done, next_state_image = env.step(action)
        image_array.append(next_state_image)

        state = next_state
        steps += 1
        score += reward
        # Count the step before checking, so that at most ``max_steps`` steps
        # run. The old pre-increment check allowed max_steps + 1 steps.
        if steps >= max_steps:
            done = 1

    print("----------------------------------------------------------------------------------------")
    print("TEST, reward: {}, score: {}, total steps: {}".format(
        reward, score, steps_done))
    with open('tests.txt', 'a') as log_file:
        log_file.write("TEST, reward: {}, score: {}, total steps: {}\n".format(
            reward, score, steps_done))
    stopWatch = time.perf_counter() - start
    print("Test is done, test time: ", stopWatch)

    # Convert images to video.
    # NOTE(review): frames are assumed to be 256x144 (width x height) --
    # confirm against the image size produced by DroneEnv.
    frameSize = (256, 144)
    # os.path.join keeps the old "videos\\..." path on Windows and also works
    # elsewhere. makedirs ensures the output directory exists.
    os.makedirs("videos", exist_ok=True)
    video_path = os.path.join(
        "videos", "test_video_episode_{}_score_{}.avi".format(episode, score))
    video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'DIVX'), 7, frameSize)
    try:
        for img in image_array:
            video.write(img)
    finally:
        # Always flush and close the file, even if a write fails.
        video.release()
def test_DQRN(checkpoint_dict, max_steps, num_frames=7):
    """Run one greedy evaluation episode of the frame-sequence Q-network.

    The function loads the network weights from ``checkpoint_dict`` and plays
    one episode in a fresh ``DroneEnv``. At every step it stacks the last
    ``num_frames`` observations (including the current one), feeds them to
    the network and takes the arg-max Q-value action. When the episode ends
    it does three things:

    * prints a summary line,
    * appends that line to ``tests.txt``,
    * writes the collected frames to an AVI file under ``videos/``.

    Args:
        checkpoint_dict: Mapping with keys ``"state_dict"`` (weights for
            ``QNetwork``), ``"episode"`` and ``"steps_done"``. The latter two
            are used only for logging and for the video file name.
        max_steps: Maximum number of environment steps before the episode
            is cut off, even if the environment has not reported ``done``.
        num_frames: Length of the sliding observation window. At the start
            of the episode the window is filled with copies of the first
            observation.

    Raises:
        ValueError: If ``num_frames`` is smaller than 1.
    """
    import os
    from collections import deque

    if num_frames < 1:
        raise ValueError("num_frames must be >= 1, got {}".format(num_frames))

    model_dict = checkpoint_dict["state_dict"]
    episode = checkpoint_dict["episode"]
    # Step counter stored in the checkpoint, not the length of this episode.
    steps_done = checkpoint_dict["steps_done"]

    test_network = QNetwork()
    test_network.load_state_dict(model_dict)
    test_network.to(device)
    # Put layers such as dropout / batch-norm into inference mode.
    test_network.eval()

    env = DroneEnv()
    start = time.perf_counter()
    steps = 0
    score = 0
    image_array = []
    state, next_state_image = env.reset()
    image_array.append(next_state_image)

    # maxlen keeps the window at exactly ``num_frames`` entries. The old code
    # trimmed to a hard-coded 7, regardless of num_frames.
    state_sequence = deque(maxlen=num_frames)
    done = False
    while not done:
        if isinstance(state, np.ndarray):
            state = transformToTensor(state)
        if not state_sequence:
            # Seed the window by repeating the first observation.
            state_sequence.extend([state] * num_frames)
        else:
            # Add the *current* observation before acting. Previously it was
            # appended only after env.step(), so the network always acted on
            # a window lagging one step behind.
            state_sequence.append(state)

        # Stacking on dim=1 gives the same layout as the old
        # stack(dim=0).permute(1, 0, 2, 3, 4), as a contiguous tensor.
        # For (1, 1, H, W) frames from transformToTensor the result is
        # (1, num_frames, 1, H, W).
        frames = torch.stack(list(state_sequence), dim=1)
        # Action selection needs no autograd graph.
        with torch.no_grad():
            q_values = test_network(frames)
        action = int(np.argmax(q_values.squeeze().cpu().numpy()))

        next_state, reward, done, next_state_image = env.step(action)
        image_array.append(next_state_image)

        state = next_state
        steps += 1
        score += reward
        # Count the step before checking, so that at most ``max_steps`` steps
        # run. The old pre-increment check allowed max_steps + 1 steps.
        if steps >= max_steps:
            done = 1

    print("----------------------------------------------------------------------------------------")
    print("TEST, reward: {}, score: {}, total steps: {}".format(
        reward, score, steps_done))
    with open('tests.txt', 'a') as log_file:
        log_file.write("TEST, reward: {}, score: {}, total steps: {}\n".format(
            reward, score, steps_done))
    stopWatch = time.perf_counter() - start
    print("Test is done, test time: ", stopWatch)

    # Convert images to video.
    # NOTE(review): frames are assumed to be 256x144 (width x height) --
    # confirm against the image size produced by DroneEnv.
    frameSize = (256, 144)
    # os.path.join keeps the old "videos\\..." path on Windows and also works
    # elsewhere. makedirs ensures the output directory exists.
    os.makedirs("videos", exist_ok=True)
    video_path = os.path.join(
        "videos", "test_video_episode_{}_score_{}.avi".format(episode, score))
    video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'DIVX'), 7, frameSize)
    try:
        for img in image_array:
            video.write(img)
    finally:
        # Always flush and close the file, even if a write fails.
        video.release()
if __name__ == "__main__":
    # map_location remaps stored tensors onto ``device``. Without it, a
    # checkpoint saved on a GPU fails to load on a CPU-only machine.
    checkpoint = torch.load("EPISODE76.pt", map_location=device)
    # 34 is the max_steps cap for this evaluation episode.
    test_DQRN(checkpoint, 34)