-
Notifications
You must be signed in to change notification settings - Fork 4
/
dataloader.py
100 lines (85 loc) · 3.59 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
import torch
from torch.utils.data import Dataset
from util import list2tuple, tuple2list, flatten
class TestDataset(Dataset):
    """Evaluation dataset that pairs each query with every entity id as a candidate.

    Each element of ``queries`` is indexed as ``(query, query_structure)``.
    """

    def __init__(self, queries, nentity, nrelation):
        """Store the query list and the vocabulary sizes.

        Args:
            queries: Indexable, sized collection of ``(query, query_structure)``
                pairs.
            nentity: Number of entities; candidates are the ids
                ``0 .. nentity - 1``.
            nrelation: Number of relations (stored, not used by this class).
        """
        self.len = len(queries)
        self.queries = queries
        self.nentity = nentity
        self.nrelation = nrelation

    def __len__(self):
        """Return the number of queries."""
        return self.len

    def __getitem__(self, idx):
        """Return the candidate ids and the query at position ``idx``.

        Returns:
            A tuple ``(negative_sample, flat_query, query, query_structure)``
            where ``negative_sample`` is a 1-D long tensor holding every entity
            id and ``flat_query`` is ``flatten(query)``.
        """
        query = self.queries[idx][0]
        query_structure = self.queries[idx][1]
        # torch.arange produces the id tensor directly instead of walking a
        # Python range through the legacy LongTensor constructor on each call.
        negative_sample = torch.arange(self.nentity, dtype=torch.long)
        return negative_sample, flatten(query), query, query_structure

    @staticmethod
    def collate_fn(data):
        """Merge a list of ``__getitem__`` results into one batch.

        The candidate tensors are stacked along a new leading dimension; the
        flattened queries, original queries and query structures are returned
        as plain Python lists in batch order.
        """
        negative_sample = torch.stack([item[0] for item in data], dim=0)
        query = [item[1] for item in data]
        query_unflatten = [item[2] for item in data]
        query_structure = [item[3] for item in data]
        return negative_sample, query, query_unflatten, query_structure
class TrainDataset(Dataset):
    """Training dataset yielding one positive answer and sampled negatives per query.

    Each element of ``queries`` is indexed as ``(query, query_structure)`` and
    ``answer[query]`` is the collection of correct entity ids for that query.
    """

    def __init__(self, queries, nentity, nrelation, negative_sample_size, answer):
        """Store the training data and precompute per-query frequency counts.

        Args:
            queries: Indexable, sized collection of ``(query, query_structure)``
                pairs.
            nentity: Number of entities; negatives are drawn from
                ``0 .. nentity - 1``.
            nrelation: Number of relations (stored, not used by this class).
            negative_sample_size: Number of negative entity ids per item.
            answer: Mapping from query to its collection of answer entity ids.
        """
        self.len = len(queries)
        self.queries = queries
        self.nentity = nentity
        self.nrelation = nrelation
        self.negative_sample_size = negative_sample_size
        self.count = self.count_frequency(queries, answer)
        self.answer = answer

    def __len__(self):
        """Return the number of training queries."""
        return self.len

    def __getitem__(self, idx):
        """Build one training example for the query at position ``idx``.

        Returns:
            A tuple ``(positive_sample, negative_sample, subsampling_weight,
            flat_query, query_structure)``: a one-element long tensor with a
            randomly chosen answer, a 1-D tensor of ``negative_sample_size``
            non-answer entity ids, a one-element float tensor equal to
            ``sqrt(1 / count[query])``, ``flatten(query)``, and the query
            structure.

        Raises:
            ValueError: If every entity is an answer to the query, so that no
                negative sample can exist.
        """
        query = self.queries[idx][0]
        query_structure = self.queries[idx][1]
        tail = np.random.choice(list(self.answer[query]))
        subsampling_weight = torch.sqrt(1 / torch.Tensor([self.count[query]]))
        negative_sample = torch.from_numpy(self._sample_negatives(query))
        positive_sample = torch.LongTensor([tail])
        return positive_sample, negative_sample, subsampling_weight, flatten(query), query_structure

    def _sample_negatives(self, query):
        """Draw ``negative_sample_size`` entity ids that are not answers to ``query``."""
        wanted = self.negative_sample_size
        if wanted <= 0:
            # Nothing requested; np.concatenate would reject an empty chunk list.
            return np.empty(0, dtype=np.int64)

        answers = self.answer[query]
        # NOTE(review): answer[query] is presumably a set of entity ids (it is
        # wrapped in list() before np.random.choice) -- confirm with the caller.
        # NumPy turns a raw set into a single object element rather than an
        # array of its members, so it must be materialised before the
        # membership test or no candidate is ever filtered out.
        answer_ids = np.fromiter(answers, dtype=np.int64, count=len(answers))

        # Rejection sampling cannot terminate if every entity id is an answer.
        in_range = answer_ids[(answer_ids >= 0) & (answer_ids < self.nentity)]
        if in_range.size >= self.nentity and np.unique(in_range).size >= self.nentity:
            raise ValueError(
                f"all {self.nentity} entities are answers to the query; "
                "no negative sample exists"
            )

        chunks = []
        collected = 0
        while collected < wanted:
            # Oversample by 2x so that one round usually suffices after filtering.
            candidates = np.random.randint(self.nentity, size=wanted * 2)
            # Candidates are drawn with replacement, so assume_unique must stay
            # at its default (False) for the mask to be reliable.
            keep = np.isin(candidates, answer_ids, invert=True)
            candidates = candidates[keep]
            chunks.append(candidates)
            collected += candidates.size
        return np.concatenate(chunks)[:wanted]

    @staticmethod
    def collate_fn(data):
        """Merge a list of ``__getitem__`` results into one batch.

        Positive samples and subsampling weights are concatenated along
        dimension 0, negative samples are stacked along a new leading
        dimension, and the flattened queries and query structures are returned
        as plain Python lists in batch order.
        """
        positive_sample = torch.cat([item[0] for item in data], dim=0)
        negative_sample = torch.stack([item[1] for item in data], dim=0)
        subsample_weight = torch.cat([item[2] for item in data], dim=0)
        query = [item[3] for item in data]
        query_structure = [item[4] for item in data]
        return positive_sample, negative_sample, subsample_weight, query, query_structure

    @staticmethod
    def count_frequency(queries, answer, start=4):
        """Map each query to ``start`` plus its number of answers.

        Args:
            queries: Iterable of ``(query, query_type)`` pairs.
            answer: Mapping from query to its collection of answers.
            start: Constant added to every count.

        Returns:
            A dict ``{query: start + len(answer[query])}``.
        """
        return {query: start + len(answer[query]) for query, _ in queries}
class SingledirectionalOneShotIterator(object):
    """Endless iterator that restarts a dataloader each time it is exhausted.

    ``step`` counts how many batches have been handed out so far.
    """

    def __init__(self, dataloader):
        """Wrap ``dataloader`` (any re-iterable object) in a cycling generator."""
        self.iterator = self.one_shot_iterator(dataloader)
        self.step = 0

    def __iter__(self):
        """Return ``self`` so the object can be used in ``for`` loops."""
        return self

    def __next__(self):
        """Return the next batch, starting a new pass over the data when needed.

        Raises:
            StopIteration: If the dataloader yields no batches at all.
        """
        data = next(self.iterator)
        # Count only batches that were actually produced.
        self.step += 1
        return data

    @staticmethod
    def one_shot_iterator(dataloader):
        """Yield batches from ``dataloader`` forever, re-iterating it after each pass.

        The generator ends if a full pass produces nothing, because cycling
        over an empty dataloader would otherwise spin forever without yielding.
        """
        while True:
            produced = False
            for data in dataloader:
                produced = True
                yield data
            if not produced:
                return