-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasets.py
161 lines (146 loc) · 7.87 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import multiprocessing
import torch
from torch.utils import data
from functools import partial
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# Per-channel normalization statistics for CIFAR-10, consumed by `load_data`
# (the std list is swapped for a uniform 0.2 there when Lipschitz-related
# options are enabled).
cifar10_mean = [0.4914, 0.4822, 0.4465]
cifar10_std = [0.2023, 0.1994, 0.2010]

# The MNIST and CIFAR10 wrappers below behave like their torchvision
# counterparts, but `__getitem__` can additionally return the sample `index`.
class MNIST(datasets.MNIST):
    """``torchvision.datasets.MNIST`` that can also report each sample's index.

    Args:
        root, train, transform, target_transform, download: forwarded
            unchanged to ``torchvision.datasets.MNIST``.
        use_index: when True, ``__getitem__`` yields ``(img, target, index)``
            instead of ``(img, target)``.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, use_index=False):
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)
        self.use_index = use_index

    def __getitem__(self, index):
        """Return ``(img, target)``, with ``index`` appended if ``use_index``."""
        img, target = super().__getitem__(index)
        if not self.use_index:
            return img, target
        return img, target, index
class CIFAR10(datasets.CIFAR10):
    """``torchvision.datasets.CIFAR10`` that can also report each sample's index.

    Args:
        root, train, transform, target_transform, download: forwarded
            unchanged to ``torchvision.datasets.CIFAR10``.
        use_index: when True, ``__getitem__`` yields ``(img, target, index)``
            instead of ``(img, target)``.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, use_index=False):
        super().__init__(root, train=train, transform=transform,
                         target_transform=target_transform, download=download)
        self.use_index = use_index

    def __getitem__(self, index):
        """Return ``(img, target)``, with ``index`` appended if ``use_index``."""
        img, target = super().__getitem__(index)
        if not self.use_index:
            return img, target
        return img, target, index
def _install_mnist_download_workaround():
    """Install a process-wide urllib opener that sends a browser User-agent.

    Works around the 403 Forbidden error when downloading MNIST, see
    https://github.com/pytorch/vision/issues/1938. This replaces the default
    opener for every later ``urllib.request.urlopen`` call in the process.
    """
    # Python 3 stdlib module; previously reached through `six.moves.urllib`.
    import urllib.request

    opener = urllib.request.build_opener()
    opener.addheaders = [('User-agent', 'Mozilla/5.0')]
    urllib.request.install_opener(opener)


def _imagefolder_datasets(data_dir, test_split, mean, std, crop_size, valid_share):
    """Build the ``(train, test)`` ImageFolder datasets rooted at ``data_dir``.

    The training set reads ``<data_dir>/train`` with random horizontal flips
    and edge-padded random crops. The test set reads ``<data_dir>/<test_split>``
    with only tensor conversion and normalization. When a validation split is
    requested (``valid_share`` is not None), it reads ``<data_dir>/train``
    instead, because the held-out part is carved out of the training images
    later.
    """
    normalize = transforms.Normalize(mean=mean, std=std)
    train_set = datasets.ImageFolder(
        f'{data_dir}/train',
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(crop_size, 4, padding_mode='edge'),
            transforms.ToTensor(),
            normalize,
        ]))
    eval_split = 'train' if valid_share is not None else test_split
    test_set = datasets.ImageFolder(
        f'{data_dir}/{eval_split}',
        transform=transforms.Compose([transforms.ToTensor(), normalize]))
    return train_set, test_set


def _split_train_valid(train_set, eval_set, train_share, shuffle):
    """Partition the training images into a training and a held-out subset.

    ``train_set`` and ``eval_set`` index the same training images, with
    train-time and test-time transforms respectively. The first
    ``int(train_share * N)`` indices are kept for training, or a random
    subset of that size when ``shuffle`` is set. The remaining indices are
    returned as a subset of ``eval_set``.

    Returns:
        ``(train_subset, held_out_subset)``.
    """
    total = len(train_set)
    n_train = int(train_share * total)
    if shuffle:
        # Used in the paper's validation experiments (not on CIFAR-10).
        # Drawn from torch's global RNG; seed it for a reproducible split.
        # `.tolist()` hands the wrapped datasets plain ints, matching the
        # unshuffled `range` path below.
        order = torch.randperm(total).tolist()
        train_idx, held_out_idx = order[:n_train], order[n_train:]
    else:
        train_idx, held_out_idx = range(n_train), range(n_train, total)
    return (torch.utils.data.Subset(train_set, train_idx),
            torch.utils.data.Subset(eval_set, held_out_idx))


def _make_loader(dataset, batch_size, shuffle, num_workers, mean, std):
    """Wrap ``dataset`` in a DataLoader annotated with normalization metadata.

    The loader gets ``mean``/``std`` plus ``data_min``/``data_max``. These
    are the pixel values 0 and 1 mapped through the normalization, shaped
    ``(1, C, 1, 1)`` to broadcast over NCHW batches.
    """
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True,
        num_workers=num_workers)
    loader.mean, loader.std = mean, std
    loader.data_max = torch.reshape((1. - mean) / std, (1, -1, 1, 1))
    loader.data_min = torch.reshape((0. - mean) / std, (1, -1, 1, 1))
    return loader


def load_data(args, data, batch_size, test_batch_size, use_index=False, aug=True):
    """Build train/test DataLoaders for one of the supported datasets.

    Args:
        args: run configuration. Attributes read: ``valid_share``,
            ``data_loader_workers`` and ``device``; ``valid_shuffle`` when
            ``valid_share`` is not None; ``lip``, ``global_lip`` and ``model``
            for CIFAR; ``lip`` for tinyimagenet.
        data: dataset name, one of ``'MNIST'``, ``'CIFAR'``,
            ``'tinyimagenet'`` or ``'imagenet64'``.
        batch_size: batch size of the (shuffled) training loader.
        test_batch_size: batch size of the (unshuffled) test loader.
        use_index: forwarded to the MNIST/CIFAR10 wrappers so samples also
            carry their index. It is ignored for the ImageFolder-based
            datasets (tinyimagenet, imagenet64).
        aug: CIFAR only. When False, random cropping is disabled, but
            horizontal flips are still applied.

    When ``args.valid_share`` is not None, the official test split is not
    used. ``valid_share`` is the fraction of the training set kept for
    training. The rest of the training images, loaded with test-time
    transforms, form the returned "test" loader.

    Returns:
        ``(dummy_input, train_loader, test_loader)``. ``dummy_input`` is a
        random batch of two inputs with the dataset's image shape, moved to
        ``args.device``. Both loaders carry ``mean``, ``std``, ``data_min``
        and ``data_max`` attributes.

    Raises:
        ValueError: if ``data`` is not a supported dataset name.
    """
    if data == 'MNIST':
        _install_mnist_download_workaround()
        dummy_input = torch.randn(2, 1, 28, 28)
        # mean 0 / std 1: MNIST inputs are left unnormalized.
        mean, std = torch.tensor([0.0]), torch.tensor([1.0])
        train_data = MNIST('./data', train=True, download=True,
                           transform=transforms.ToTensor(), use_index=use_index)
        # With a validation split, the held-out samples come from the training set.
        test_data = MNIST('./data', train=args.valid_share is not None,
                          download=True, transform=transforms.ToTensor(),
                          use_index=use_index)
    elif data == 'CIFAR':
        mean = torch.tensor(cifar10_mean)
        # Any Lipschitz-related option switches to a uniform per-channel std of 0.2.
        lip = args.lip or args.global_lip or 'lip' in args.model
        std = torch.tensor([0.2, 0.2, 0.2] if lip else cifar10_std)
        dummy_input = torch.randn(2, 3, 32, 32)
        normalize = transforms.Normalize(mean=mean, std=std)
        train_steps = [transforms.RandomHorizontalFlip()]
        if aug:
            train_steps.append(transforms.RandomCrop(32, 2, padding_mode='edge'))
        train_steps += [transforms.ToTensor(), normalize]
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])
        train_data = CIFAR10('./data', train=True, download=True,
                             transform=transforms.Compose(train_steps),
                             use_index=use_index)
        # With a validation split, the held-out samples come from the training set.
        test_data = CIFAR10('./data', train=args.valid_share is not None,
                            download=True, transform=transform_test,
                            use_index=use_index)
    elif data == 'tinyimagenet':
        mean = torch.tensor([0.4802, 0.4481, 0.3975])
        # NOTE(review): only `args.lip` is checked here, unlike the CIFAR branch,
        # which also honours `args.global_lip` and 'lip' in the model name --
        # confirm whether that difference is intended.
        std = torch.tensor([0.22, 0.22, 0.22] if args.lip else [0.2302, 0.2265, 0.2262])
        dummy_input = torch.randn(2, 3, 64, 64)
        train_data, test_data = _imagefolder_datasets(
            'data/tinyImageNet/tiny-imagenet-200', 'val', mean, std, 64,
            args.valid_share)
    elif data == 'imagenet64':
        # Adapted from auto_LiRPA's repository, without its 56-pixel cropping.
        mean = torch.tensor([0.4815, 0.4578, 0.4082])
        std = torch.tensor([0.2153, 0.2111, 0.2121])
        dummy_input = torch.randn(2, 3, 64, 64)
        train_data, test_data = _imagefolder_datasets(
            'data/ImageNet64', 'test', mean, std, 64, args.valid_share)
    else:
        raise ValueError(f"Unsupported value for dataset name: {data}")

    if args.valid_share is not None:
        # `test_data` already points at the training images (with test-time
        # transforms); keep only the part not used for training.
        train_data, test_data = _split_train_valid(
            train_data, test_data, args.valid_share, args.valid_shuffle)

    workers = args.data_loader_workers
    train_loader = _make_loader(train_data, batch_size, True, workers, mean, std)
    test_loader = _make_loader(test_data, test_batch_size, False, workers, mean, std)
    return dummy_input.to(args.device), train_loader, test_loader