-
Notifications
You must be signed in to change notification settings - Fork 0
/
torchscript_model.py
72 lines (56 loc) · 1.89 KB
/
torchscript_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright (c) 2020 Maka Autonomous Robotic Systems, Inc.
#
# This file is part of Makannotations.
#
# Makannotations is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Makannotations is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# Torch is an optional dependency: record whether it could be imported so the
# functions below can turn into no-ops when it is unavailable.
try:
    import torch
    import torch.nn.functional as F
except Exception:
    has_torch = False
else:
    has_torch = True

# The loaded TorchScript module; remains None until load() assigns it.
MODEL = None
def load(path):
    """Load the TorchScript model stored at *path* into the module-level MODEL.

    The model is mapped onto the CPU and switched to evaluation mode.
    Does nothing when torch could not be imported.
    """
    global MODEL
    if has_torch:
        MODEL = torch.jit.load(path, map_location="cpu")
        MODEL.eval()
def is_loaded():
    """Return True when torch is available and a model has been loaded."""
    return has_torch and MODEL is not None
def image_auto_mask(rgb_img, channel, threshold=0.5, input_size=(720, 1280)):
    """Run the loaded model on an image and return a boolean mask.

    Args:
        rgb_img: Channel-last image array of shape (height, width, channels)
            with at least three channels; values are divided by 255 before
            normalisation, so 0-255 pixel values are presumably expected.
        channel: Index of the model output channel to threshold.
        threshold: Output values strictly greater than this become True.
        input_size: (height, width) the image is resized to before it is
            passed to the model. Defaults to 720x1280.

    Returns:
        A boolean NumPy array with the same height and width as *rgb_img*,
        or None when torch could not be imported.

    Note:
        load() must have been called first; MODEL is otherwise still None
        and calling it raises TypeError.
    """
    if not has_torch:
        return None

    # Local import: only needed (and only guaranteed importable) with torch.
    import numpy as np

    # Convert the whole image in one go. Wrapping the array in a Python list
    # (torch.tensor([rgb_img])) takes a very slow element-by-element path;
    # ascontiguousarray also copes with negatively-strided views.
    pixels = torch.tensor(np.ascontiguousarray(rgb_img))
    # HWC -> NCHW with a batch dimension of one, as float32.
    t = pixels.unsqueeze(0).permute(0, 3, 1, 2).float()
    t = F.interpolate(t, tuple(input_size), mode="bilinear", align_corners=False)

    # Standard RGB adjustments: scale to [0, 1], then subtract the per-channel
    # mean and divide by the per-channel std (the values commonly used for
    # ImageNet-pretrained networks). Only the first three channels are
    # normalised; any extra channel is just scaled.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    t /= 255.0
    t[:, :3] -= mean
    t[:, :3] /= std

    # Run inference without tracking gradients.
    with torch.no_grad():
        mask = MODEL(t)

    # Resize the prediction back to the original image size.
    original_size = (rgb_img.shape[0], rgb_img.shape[1])
    mask = F.interpolate(mask, original_size, mode="bilinear", align_corners=False)

    # Threshold the requested output channel of the single batch item.
    return (mask[0][channel] > threshold).numpy()