dqn_replay_dataset.py
import gzip
from pathlib import Path
from typing import List, Tuple, Union

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from utils.pytorch_utils import *
class DQNReplayDataset(Dataset):
    """
    A dataset of observations from a single checkpoint of a single game.

    It stores a tensor of shape (dataset_size, h, w) and, given an index i,
    returns a slice starting at i and ending at i plus a number of frames:
    (slice_size, h, w). The slice size should match the number of frames
    stacked during the RL phase.

    In add-adjacent mode the dataset returns three stacked slices: the
    observation before i, the observation at i, and the observation after i:
    (3, slice_size, h, w).
    """
    def __init__(
        self,
        data_path: Path,
        game: str,
        checkpoint: int,
        frames: int,
        max_size: int,
        transform,
        actions: bool = False,
        start_index: int = 0,
    ) -> None:
        self.actions = None
        self.start_index = start_index
        filename = data_path / f"{game}/observation_{checkpoint}.gz"
        print(f"Loading {filename}")
        zip_file = gzip.GzipFile(filename=filename)
        loaded_data = np.load(zip_file)
        # Keep at most `max_size` observations starting at `start_index`;
        # np.copy detaches the slice so the full array can be freed below.
        loaded_data_capped = np.copy(
            loaded_data[self.start_index : self.start_index + max_size]
        )
        print(f"Using {loaded_data.size * loaded_data.itemsize} bytes")
        print(f"Shape {loaded_data.shape}")
        self.observation = torch.from_numpy(loaded_data_capped)
        del loaded_data, zip_file, loaded_data_capped
        if actions:
            actions_filename = data_path / f"{game}/action_{checkpoint}.gz"
            actions_zip_file = gzip.GzipFile(filename=actions_filename)
            actions_loaded_data = np.load(actions_zip_file)
            actions_data_capped = np.copy(
                actions_loaded_data[self.start_index : self.start_index + max_size]
            )
            self.actions = torch.from_numpy(actions_data_capped)
        self.size = min(self.observation.shape[0], max_size)
        self.game = game
        self.frames = frames
        # A slice of `frames` observations must fit entirely inside the data,
        # so only the first `size - frames + 1` start indices are valid.
        self.effective_size = self.size - self.frames + 1
        self.transform = transform
    def __len__(self) -> int:
        return self.effective_size

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        time_ind = index % self.effective_size
        # Guard against actions not being loaded (actions=False): return a
        # placeholder action of -1 instead of indexing into None.
        res_action = (
            self.actions[time_ind] if self.actions is not None else torch.tensor(-1)
        )
        if self.frames <= 1:
            curr_obs = self.observation[time_ind]
        else:
            curr_obs = self.observation[time_ind : time_ind + self.frames]
        return self.transform(curr_obs), res_action
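
# A usage sketch (not from the original file): the arguments assume the
# Atari DQN Replay layout, with one observation_<ckpt>.gz / action_<ckpt>.gz
# pair per checkpoint under <data_path>/<game>/; the path, game, and
# transform below are hypothetical placeholders.
#
#     dataset = DQNReplayDataset(
#         data_path=Path("data/dqn_replay"),   # hypothetical root
#         game="Pong",
#         checkpoint=1,
#         frames=4,                            # match the RL frame stack
#         max_size=100_000,
#         transform=lambda x: x.float() / 255.0,
#         actions=True,
#     )
#     obs, action = dataset[0]                 # obs has shape (4, h, w)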
class MultiDQNReplayDataset(Dataset):
    """
    The concatenation of several DQNReplayDataset instances: one dataset per
    (game, checkpoint) pair, covering several checkpoints of several games.
    """
    def __init__(
        self,
        data_path: Path,
        games: Union[List[str], str],
        checkpoints: List[int],
        frames: int,
        max_size: int,
        transform,
        actions: bool = True,
        start_index: int = 0,
    ) -> None:
        self.actions = actions
        self.n_checkpoints_per_game = len(checkpoints)
        if isinstance(games, str):
            games = [games]
        self.datasets = [
            DQNReplayDataset(
                data_path,
                game,
                ckpt,
                frames,
                max_size,
                transform,
                actions,
                start_index,
            )
            for ckpt in checkpoints
            for game in games
        ]
        self.n_datasets = len(self.datasets)
        # Assumes every checkpoint file has the same length, so the total
        # length is simply n_datasets * single_dataset_size.
        self.single_dataset_size = len(self.datasets[0])
    def __len__(self) -> int:
        return self.n_datasets * self.single_dataset_size

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Interleave the datasets: consecutive indices cycle through the
        # datasets before advancing to the next time step in each of them.
        multidataset_index = index % self.n_datasets
        dataset_index = index // self.n_datasets
        res_obs, res_action = self.datasets[multidataset_index][dataset_index]
        return res_obs, res_action
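
# Minimal smoke test (a sketch, not part of the original file): assumes the
# replay files exist under a hypothetical "data/dqn_replay" root; the games
# and checkpoints below are placeholders.
if __name__ == "__main__":
    multi = MultiDQNReplayDataset(
        data_path=Path("data/dqn_replay"),  # hypothetical root
        games=["Pong", "Breakout"],
        checkpoints=[1, 2],
        frames=4,
        max_size=1_000,
        transform=lambda x: x.float() / 255.0,
    )
    # Indices interleave the 4 underlying (game, checkpoint) datasets:
    # index 0 -> dataset 0 at time 0, index 1 -> dataset 1 at time 0, ...,
    # index 4 -> dataset 0 at time 1, and so on.
    loader = DataLoader(multi, batch_size=32, shuffle=True)
    obs, act = next(iter(loader))
    print(obs.shape, act.shape)  # e.g. torch.Size([32, 4, 84, 84]) and [32]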