data.py
#! /usr/bin/env python
# coding=utf-8
# /************************************************************************************
# ***
# *** File Author: Dell, Friday, September 21, 2018, 10:25:44 CST
# ***
# ************************************************************************************/
import pickle

from torchtext import data

def english_token(x):
    # Simple whitespace tokenizer; drops empty tokens.
    return [w for w in x.split(" ") if len(w) > 0]

EnglishText = data.Field(sequential=True, tokenize=english_token, lower=True, include_lengths=True)

def chinese_token(x):
    # The Chinese side is expected to be pre-segmented with spaces.
    return [w for w in x.split(" ") if len(w) > 0]

ChineseText = data.Field(sequential=True, tokenize=chinese_token, include_lengths=True)

class TranslateDataset(data.Dataset):
    @staticmethod
    def sort_key(ex):
        # Sort by source-sentence length so batches group similar lengths.
        return len(ex.src)

    def __init__(self, path, src_field, trg_field, sep='\t', **kwargs):
        """Create a dataset instance given a path and fields.

        Arguments:
            path: Path to the data file.
            src_field: The field that will be used for source data.
            trg_field: The field that will be used for target data.
            kwargs: Passed to the constructor of data.Dataset.
        """
        fields = [('src', src_field), ('trg', trg_field)]
        examples = []
        with open(path, errors='ignore') as f:
            for line in f:
                s = line.strip().split(sep)
                if len(s) != 2:
                    # Skip malformed lines that do not have exactly two columns.
                    continue
                src, trg = s[0], s[1]
                e = data.Example()
                setattr(e, "src", src_field.preprocess(src))
                setattr(e, "trg", trg_field.preprocess(trg))
                examples.append(e)
        super(TranslateDataset, self).__init__(examples, fields, **kwargs)

def translate_dataloader(datafile, batchsize, shuffle=False):
    src_field = EnglishText
    trg_field = ChineseText
    dataset = TranslateDataset(datafile, src_field, trg_field)
    src_field.build_vocab(dataset)
    trg_field.build_vocab(dataset)
    # Pass shuffle by keyword: the third positional argument of data.Iterator is sort_key.
    dataiter = data.Iterator(dataset, batchsize, shuffle=shuffle, repeat=False)
    # dataiter.init_epoch()
    return dataiter, src_field, trg_field

def save_vocab(vocab, filename):
    with open(filename, 'wb') as f:
        pickle.dump(vocab, f)

def load_vocab(filename):
    with open(filename, 'rb') as f:
        vocab = pickle.load(f)
    return vocab

def test():
    dataiter, _, _ = translate_dataloader("data/en-zh.txt", 32, shuffle=False)
    batch = next(iter(dataiter))
    # With include_lengths=True, batch.src and batch.trg are (tensor, lengths) tuples.
    src, src_lengths = batch.src
    trg, trg_lengths = batch.trg
    print("src: ", type(src), src.size(), src)
    print("trg: ", type(trg), trg.size(), trg)