-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
81 lines (67 loc) · 2.76 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Copyright 2021 by YoungWoon Cho, Danny Hong
# The Cooper Union for the Advancement of Science and Art
# ECE471 Machine Learning Architecture
import argparse
import random
import time
import os
from PIL import Image
import torch.backends.cudnn as cudnn
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torchvision.utils import save_image
from shutil import copyfile
from models import Generator
def _build_arg_parser():
    """Return the command-line parser for this CycleGAN test script.

    Options: --data_root (input folder), --model (generator state-dict path),
    --cuda (flag) and --image_size (int passed to the resize transform).
    """
    cli = argparse.ArgumentParser(
        description="This is a pytorch implementation of CycleGAN. Please refer to the following arguments."
    )
    cli.add_argument(
        "--data_root",
        type=str,
        default="./test",
        help="Root directory to the test dataset. Default: ./test",
    )
    cli.add_argument(
        "--model",
        type=str,
        default="weights/fruit2rotten/G_A2B.pth",
        help="Generator model to use. Default: weights/fruit2rotten/G_A2B.pth",
    )
    cli.add_argument("--cuda", action="store_true", help="Turn on the cuda option.")
    cli.add_argument(
        "--image_size",
        type=int,
        default=256,
        help="Size of the image. Default: 256",
    )
    return cli


# Module-level parser and parsed options; the rest of the script reads `args`.
parser = _build_arg_parser()
args = parser.parse_args()
# Announce the run, then choose the compute device. CUDA is switched on
# automatically when a device is present. If --cuda was passed on a machine
# without CUDA, fall back to the CPU instead of failing later when tensors
# are moved to "cuda:0".
print('****Preparing testing with following options****')
time.sleep(0.2)
if torch.cuda.is_available() and not args.cuda:
    print("Cuda device found. Turning on cuda...")
    args.cuda = True
    time.sleep(0.2)
elif args.cuda and not torch.cuda.is_available():
    print("Cuda requested but no cuda device found. Falling back to cpu...")
    args.cuda = False
    time.sleep(0.2)
device = torch.device("cuda:0" if args.cuda else "cpu")
# Random seed for torch's RNG. It is drawn and applied exactly once and
# printed so the run's randomness can be reproduced.
seed = random.randint(1, 10000)
torch.manual_seed(seed)
print(f'Random Seed: {seed}')
# Echo the effective options (args.cuda reflects the detection above).
print(f'Cuda: {args.cuda}')
print(f'Image size: {args.image_size}')
print(f'Testing dataset: {args.data_root}')
print(f'Testing model: {args.model}')
time.sleep(0.2)
# Create the generator and move it to the selected device.
model = Generator().to(device)
# Load the trained weights. map_location=device remaps tensors saved on any
# device (CPU, or a GPU with a different index) onto the device chosen above,
# so one call covers both the cuda and the cpu case.
model.load_state_dict(torch.load(args.model, map_location=device))
# Evaluation mode: only affects train/eval-sensitive layers inside Generator,
# if it has any.
model.eval()
def translate(image_dir):
    """Run one image file through the loaded generator.

    The image is converted to 3-channel RGB and resized with
    ``transforms.Resize(args.image_size)``. It is then mapped from [0, 1] to
    [-1, 1] by the ToTensor/Normalize(0.5, 0.5) pair, batched, and passed
    through ``model`` on ``device``. ``args``, ``model`` and ``device`` are
    module-level globals of this script.

    Args:
        image_dir: Path to the input image file (a file, despite the name).

    Returns:
        The generator output tensor for a batch of one, computed without
        autograd history. Its exact shape depends on Generator
        (presumably (1, 3, H, W); confirm against models.py).
    """
    with Image.open(image_dir) as img:
        # Force 3 channels. The Normalize below uses 3-channel statistics, so
        # RGBA, grayscale or palette images would otherwise fail.
        rgb = img.convert('RGB')
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    batch = transform(rgb).unsqueeze(0).to(device)
    # Inference only: do not build an autograd graph (saves memory/time).
    with torch.no_grad():
        translated_image = model(batch)
    return translated_image
# Translate every image in args.data_root and write each result into the same
# folder as translated_<stem>.png. Files already carrying the translated_
# prefix are outputs of a previous run and are skipped.
print("Begin translation...", end='', flush=True)
for filename in sorted(os.listdir(args.data_root)):
    source_path = os.path.join(args.data_root, filename)
    if filename.startswith('translated_') or not os.path.isfile(source_path):
        continue
    # splitext keeps multi-dot names distinct and handles extension-less files.
    stem = os.path.splitext(filename)[0]
    try:
        translated_image = translate(source_path)
    except OSError as err:
        # Pillow signals unreadable/unidentifiable images with OSError
        # subclasses; report and keep going instead of aborting the batch.
        print(f"\nSkipping {filename}: {err}")
        continue
    output_path = os.path.join(args.data_root, f'translated_{stem}.png')
    save_image(translated_image.detach(), output_path, normalize=True)
print("Done.")