-
Notifications
You must be signed in to change notification settings - Fork 3
/
inference.py
40 lines (33 loc) · 1.08 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import inspect
import os

import torch
from PIL import Image
from torchvision import transforms

from dnn_split.model_canyon import ModelCanyon
from dnn_split.model_util import get_alexnet
# Directory containing the serialized partial-model checkpoints
# (relative to the current working directory).
MODEL_PATH = "./models/"
# Directory containing the sample input images read by get_input()
# (relative to the current working directory).
IMAGE_PATH = "./data/images/"
def get_input(filename='dog.jpg'):
    """Load an image from IMAGE_PATH and preprocess it into a model-ready batch.

    The pipeline resizes the shorter side to 256, center-crops to 224x224,
    converts to a float tensor and normalizes each channel with the widely
    used ImageNet mean/std statistics.

    Args:
        filename: Name of the image file inside IMAGE_PATH. Defaults to
            'dog.jpg', so existing zero-argument calls behave as before.

    Returns:
        A tensor of shape (1, 3, 224, 224): a batch containing one image.

    Raises:
        FileNotFoundError: If the image file does not exist.
    """
    image_file = os.path.join(IMAGE_PATH, filename)
    # The context manager closes the underlying file handle. convert('RGB')
    # forces the pixel data to load (returning an independent image) and
    # guarantees exactly three channels, which the 3-element Normalize below
    # requires; grayscale, RGBA or CMYK inputs would otherwise fail.
    with Image.open(image_file) as img:
        rgb_image = img.convert('RGB')
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(rgb_image)
    # Add a leading batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
    return input_tensor.unsqueeze(0)
def _load_pickled_model(path):
    """Load a whole pickled nn.Module checkpoint onto the CPU.

    The checkpoints are used as complete modules (main() calls ``.eval()`` on
    the loaded object and then calls it directly), so they are presumably
    written with ``torch.save(model, path)`` rather than as state_dicts.

    Args:
        path: Filesystem path of the checkpoint.

    Returns:
        The unpickled module, with its tensors mapped to the CPU.
    """
    # Map to CPU so the model lives on the same device as get_input()'s
    # tensor, even if the checkpoint was saved from a GPU.
    load_kwargs = {'map_location': 'cpu'}
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # module objects. Opt out explicitly when the installed version knows the
    # flag, so older PyTorch releases (without the parameter) keep working.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    # SECURITY: this unpickles arbitrary Python objects -- only load trusted files.
    return torch.load(path, **load_kwargs)


def main():
    """Run two-stage split inference on the sample image and print the output.

    The first checkpoint processes the raw image batch; its output is fed to
    the second checkpoint, whose output is printed.
    """
    input_batch = get_input()

    # torch.load returns a freshly unpickled object, so building
    # ModelCanyon(model=get_alexnet(), ...) beforehand is unnecessary -- it
    # would be discarded immediately. The ModelCanyon class must still be
    # importable (it is imported at module level) for unpickling to succeed.
    # NOTE(review): these checkpoints presumably hold AlexNet slices
    # ModelCanyon(start=0, end=2) and ModelCanyon(start=3, end=20); if `end`
    # is exclusive, layer 2 is never executed -- verify when regenerating them.
    head = _load_pickled_model(os.path.join(MODEL_PATH, 'partialmodel.pth'))
    tail = _load_pickled_model(os.path.join(MODEL_PATH, 'partialmodel2.pth'))
    head.eval()
    tail.eval()

    # Pure inference: disable autograd so no graph or activations are retained.
    with torch.no_grad():
        intermediate = head(input_batch)
        output = tail(intermediate)
    print(output)


if __name__ == '__main__':
    main()