-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
139 lines (106 loc) · 3.93 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import numpy as np
import torch
import torchvision.transforms as transforms
from flask import Flask, request, send_file, render_template
from PIL import Image
import io
import torch.nn as nn
import pickle
# Flask application object; the routes below register on it.
app = Flask(import_name=__name__)
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions wrapped in an identity skip.

    Each convolution is preceded by a 1-pixel reflection pad, so height and
    width are preserved and the block output can be added to its input.

    Args:
        in_features: number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()
        layers = []
        # Two pad -> conv -> norm stages; only the first one gets a ReLU.
        for with_relu in (True, False):
            # Pad by reflecting the input boundary so the 3x3 conv keeps H, W.
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(in_features, in_features, 3))
            layers.append(nn.InstanceNorm2d(in_features))
            if with_relu:
                layers.append(nn.ReLU(inplace=True))
        # Attribute name and layer order define the state_dict keys (block.1.*, block.5.*).
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual
class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 conv stem -> two stride-2 downsampling convs ->
    ``num_residual_block`` ResidualBlocks -> two 2x upsample + 3x3 conv
    stages -> 7x7 output conv -> tanh. The output has the same channel count
    as the input, with values in [-1, 1]. Its height and width equal the
    input's when both are divisible by 4 (two halvings, then two doublings).

    Args:
        input_shape: sequence whose first element is the channel count; only
            ``input_shape[0]`` is read here.
        num_residual_block: number of ResidualBlocks between the down- and
            upsampling stages.
    """

    def __init__(self, input_shape, num_residual_block):
        super(GeneratorResNet, self).__init__()
        channels = input_shape[0]
        # A 7x7 conv needs 3 pixels of padding per side to preserve H and W.
        # (Padding by `channels` only happens to equal 3 for RGB input and
        # shrinks the output for any other channel count.)
        edge_pad = 7 // 2

        # Initial convolution block
        out_features = 64
        model = [
            nn.ReflectionPad2d(edge_pad),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling: each stride-2 conv halves H, W and doubles the features.
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks (shape-preserving)
        for _ in range(num_residual_block):
            model += [ResidualBlock(out_features)]

        # Upsampling: each stage doubles H, W and halves the features.
        # NOTE(review): unlike downsampling there is no InstanceNorm2d here.
        # Adding one would shift the nn.Sequential indices (and therefore the
        # state_dict keys), so don't add it without retraining.
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),  # --> width*2, height*2
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer: back to `channels` maps, squashed into [-1, 1].
        model += [
            nn.ReflectionPad2d(edge_pad),
            nn.Conv2d(out_features, channels, 7),
            nn.Tanh(),
        ]

        # Unpacking
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
# Initialize the model.
# Architecture settings; they must match the trained checkpoint. They are used
# to rebuild the generator when 'backup.pth' holds a bare state_dict.
input_shape = (3, 256, 256)  # (channels, height, width) expected by the model
num_residual_blocks = 9  # ResidualBlock count of the trained generator

# NOTE(review): torch.load unpickles the file, and unpickling can execute
# arbitrary code, so only deploy checkpoints from a trusted source.
# weights_only=False (available since torch 1.13) is needed on torch >= 2.6,
# where the default became True, to unpickle a whole nn.Module. Unpickling a
# whole module also needs its classes importable under the path they were
# saved with; presumably that is why GeneratorResNet/ResidualBlock are defined
# in this file. TODO confirm.
_checkpoint = torch.load(
    'backup.pth', map_location=torch.device('cpu'), weights_only=False
)
if isinstance(_checkpoint, nn.Module):
    model = _checkpoint
else:
    # Assumed to be a plain state_dict for GeneratorResNet. A mismatched
    # checkpoint fails here at startup rather than on the first request.
    model = GeneratorResNet(input_shape, num_residual_blocks)
    model.load_state_dict(_checkpoint)
# Put every submodule in inference mode before serving requests.
model.eval()
# Define image transformations applied to every uploaded image:
# - resize to the model's (height, width), taken from input_shape so the two
#   settings cannot drift apart;
# - convert the PIL image to a float tensor in [0, 1];
# - map each of the 3 channels to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize(tuple(input_shape[1:])),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
@app.route('/', methods=["GET"])
def home():
    """Render and return the ``index.html`` template for GET /."""
    template = 'index.html'
    page = render_template(template)
    return page
def _output_to_uint8(output):
    """Convert a batched generator output into an (H, W, C) uint8 array.

    Assumes ``output`` is (1, C, H, W) (one image per request; TODO confirm
    for other models). Drops the batch dimension, moves channels last, and
    min-max rescales the values to [0, 255]. A constant image (max == min)
    would otherwise divide 0/0 and produce NaN, so it is rendered as zeros.
    """
    # squeeze(0) removes only the batch dim; a bare squeeze() would also drop
    # a size-1 channel dim and break the permute.
    array = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
    array = array - array.min()
    peak = array.max()
    if peak > 0:
        array = array / peak  # Normalize to [0, 1]
    return (array * 255).astype(np.uint8)


@app.route('/', methods=['POST'])
def predict():
    """Run the uploaded image through the model and return the result as a PNG.

    Expects a multipart upload in the form field ``file``. Responds with
    ``{"error": ...}`` and status 400 if the field is missing or the upload
    cannot be decoded as an image.
    """
    if 'file' not in request.files:
        return {"error": "No file provided"}, 400
    data = request.files['file'].read()
    try:
        img = Image.open(io.BytesIO(data)).convert('RGB')  # Ensure image is in RGB mode
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for empty
        # or non-image uploads, and OSError for truncated ones. That is a
        # client error, not a server failure.
        return {"error": "Uploaded file is not a readable image"}, 400

    batch = transform(img).unsqueeze(0)  # Transform and add batch dimension
    with torch.no_grad():
        output = model(batch)

    response_image = Image.fromarray(_output_to_uint8(output))
    buffered = io.BytesIO()
    response_image.save(buffered, format="PNG")
    buffered.seek(0)
    return send_file(buffered, mimetype='image/png')
if __name__ == '__main__':
    # Development server only. debug=True turns on Werkzeug's interactive
    # debugger and auto-reloader; Flask's docs warn never to enable it on a
    # publicly reachable deployment.
    dev_server_options = {'debug': True}
    app.run(**dev_server_options)