-
Notifications
You must be signed in to change notification settings - Fork 14
/
data_loader.py
executable file
·106 lines (88 loc) · 4.29 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Borrowed from https://github.com/EvelynFan/FaceFormer/blob/main/data_loader.py
import os
import torch
from collections import defaultdict
from torch.utils import data
import numpy as np
import pickle
from tqdm import tqdm
from transformers import Wav2Vec2Processor
import librosa
class Dataset(data.Dataset):
    """Map-style dataset over a list of already-loaded sample dicts.

    Every element of ``data`` is read with the keys ``"name"``, ``"audio"``,
    ``"vertice"`` and ``"template"``. ``subjects_dict`` and ``data_type`` are
    only stored on the instance; this class does not read them itself.
    """

    def __init__(self, data, subjects_dict, data_type="train"):
        # NOTE: inside this method the ``data`` parameter shadows the
        # ``torch.utils.data`` module imported under the same name.
        self.data, self.len = data, len(data)
        self.subjects_dict = subjects_dict
        self.data_type = data_type

    def __getitem__(self, index):
        """Return ``(audio, vertice, template, file_name)`` for one sample.

        The first three values are converted with ``torch.FloatTensor``; the
        file name is passed through unchanged. The original author's note on
        the tensors was "seq_len, fea_dim" -- TODO confirm which of the
        tensors that layout refers to.
        """
        sample = self.data[index]
        file_name = sample["name"]
        # Fetch every raw field first, then convert, so a missing key is
        # reported before any tensor conversion is attempted.
        raw_fields = [sample[field] for field in ("audio", "vertice", "template")]
        audio, vertice, template = map(torch.FloatTensor, raw_fields)
        return audio, vertice, template, file_name

    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.len
def read_data(args):
    """Load audio features, vertex sequences and templates, then split them.

    Walks ``<args.dataset>/<args.wav_path>`` for ``*.wav`` files. Each file
    that has a same-named ``.npy`` file directly under
    ``<args.dataset>/<args.vertices_path>`` becomes one sample dict with the
    keys ``"audio"``, ``"name"``, ``"template"`` and ``"vertice"``. Wav files
    without a vertex file are skipped *before* any audio is decoded.

    File names are interpreted as ``<subject_id>_<...>NN.wav``: the subject id
    is everything before the last underscore and the sentence id is the
    integer formed by the last two characters of the stem.

    Args:
        args: Object exposing ``dataset``, ``wav_path``, ``vertices_path``,
            ``template_file``, ``train_subjects``, ``val_subjects`` and
            ``test_subjects``. ``dataset`` is used both as the root directory
            and as the dataset name (compared with ``"vocaset"`` / ``"BIWI"``).
            The ``*_subjects`` values are space-separated subject ids.

    Returns:
        ``(train_data, valid_data, test_data, subjects_dict)`` where the first
        three are lists of sample dicts and ``subjects_dict`` maps
        ``"train"`` / ``"val"`` / ``"test"`` to lists of subject ids.

    Raises:
        KeyError: If the template file has no entry for a sample's subject,
            or if ``args.dataset`` is not a known dataset name while a sample
            belongs to one of the configured subjects.
        ValueError: If the last two characters of a file stem are not digits.
    """
    print("Loading data...")
    sample_rate = 16000
    samples = {}
    train_data = []
    valid_data = []
    test_data = []

    audio_path = os.path.join(args.dataset, args.wav_path)
    vertices_path = os.path.join(args.dataset, args.vertices_path)
    processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")

    template_file = os.path.join(args.dataset, args.template_file)
    with open(template_file, 'rb') as fin:
        # SECURITY: unpickling can execute arbitrary code; only point
        # ``template_file`` at files from a trusted source.
        templates = pickle.load(fin, encoding='latin1')

    for r, _dirs, fs in os.walk(audio_path):
        for f in tqdm(fs):
            if not f.endswith(".wav"):
                continue
            # Swap only the extension; a plain str.replace("wav", "npy")
            # would also rewrite "wav" occurring elsewhere in the name.
            stem = os.path.splitext(f)[0]
            key = stem + ".npy"
            vertice_path = os.path.join(vertices_path, key)
            if not os.path.exists(vertice_path):
                # No target sequence: skip before the costly audio decoding.
                continue

            wav_path = os.path.join(r, f)
            speech_array, _rate = librosa.load(wav_path, sr=sample_rate)
            input_values = np.squeeze(processor(speech_array, sampling_rate=sample_rate).input_values)

            subject_id = "_".join(stem.split("_")[:-1])
            sample = {
                "audio": input_values,
                "name": f,
                "template": templates[subject_id].reshape((-1)),
            }
            if args.dataset == "vocaset":
                # vocaset keeps every second row along the first axis.
                sample["vertice"] = np.load(vertice_path, allow_pickle=True)[::2, :]
            elif args.dataset == "BIWI":
                sample["vertice"] = np.load(vertice_path, allow_pickle=True)
            # NOTE(review): any other dataset name leaves the sample without
            # a "vertice" entry (unchanged legacy behavior) -- confirm whether
            # this should be rejected up front instead.
            samples[key] = sample

    subjects_dict = {
        "train": args.train_subjects.split(" "),
        "val": args.val_subjects.split(" "),
        "test": args.test_subjects.split(" "),
    }

    # Sentence-id ranges per dataset and split. Note that for vocaset the
    # ranges overlap, so the subject lists decide where a sample ends up.
    splits = {'vocaset': {'train': range(1, 41), 'val': range(21, 41), 'test': range(21, 41)},
              'BIWI': {'train': range(1, 33), 'val': range(33, 37), 'test': range(37, 41)}}

    buckets = {"train": train_data, "val": valid_data, "test": test_data}
    for key, sample in samples.items():
        stem = os.path.splitext(key)[0]
        subject_id = "_".join(stem.split("_")[:-1])
        sentence_id = int(stem[-2:])
        # A sample may land in more than one split if the configuration
        # allows it; the three checks are independent.
        for split_name, bucket in buckets.items():
            if subject_id in subjects_dict[split_name] and sentence_id in splits[args.dataset][split_name]:
                bucket.append(sample)

    print('Loaded data: Train-{}, Val-{}, Test-{}'.format(len(train_data), len(valid_data), len(test_data)))
    return train_data, valid_data, test_data, subjects_dict
def get_dataloaders(args):
    """Build one ``DataLoader`` per split from the data ``read_data`` returns.

    Args:
        args: Passed unchanged to ``read_data``.

    Returns:
        A dict with the keys ``"train"``, ``"valid"`` and ``"test"``. Every
        loader uses ``batch_size=1``; only the train loader shuffles.
    """
    train_data, valid_data, test_data, subjects_dict = read_data(args)

    # (result key, samples, data_type handed to Dataset, shuffle flag).
    # The validation split is keyed "valid" in the result but tagged "val"
    # on its Dataset.
    plan = (
        ("train", train_data, "train", True),
        ("valid", valid_data, "val", False),
        ("test", test_data, "test", False),
    )

    loaders = {}
    for loader_key, samples, data_type, shuffle in plan:
        split_set = Dataset(samples, subjects_dict, data_type)
        loaders[loader_key] = data.DataLoader(dataset=split_set, batch_size=1, shuffle=shuffle)
    return loaders
if __name__ == "__main__":
    # get_dataloaders() requires an ``args`` object; calling it with no
    # argument always raised TypeError. Build one from the command line with
    # exactly the attributes that read_data() reads.
    import argparse

    parser = argparse.ArgumentParser(description="Build the data loaders as a smoke test.")
    parser.add_argument("--dataset", required=True,
                        help='dataset root directory; also compared with "vocaset" and "BIWI"')
    parser.add_argument("--wav_path", required=True,
                        help="audio directory, relative to the dataset root")
    parser.add_argument("--vertices_path", required=True,
                        help="vertex .npy directory, relative to the dataset root")
    parser.add_argument("--template_file", required=True,
                        help="pickled templates file, relative to the dataset root")
    parser.add_argument("--train_subjects", required=True,
                        help="space-separated training subject ids")
    parser.add_argument("--val_subjects", required=True,
                        help="space-separated validation subject ids")
    parser.add_argument("--test_subjects", required=True,
                        help="space-separated test subject ids")
    get_dataloaders(parser.parse_args())