-
Notifications
You must be signed in to change notification settings - Fork 0
/
deploy.py
79 lines (51 loc) · 1.44 KB
/
deploy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
from dotenv import load_dotenv
import yaml
import gzip
import numpy as np
from sagemaker.pytorch import PyTorchModel
from sagemaker import Session
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
import matplotlib.pyplot as plt
# Pull variables from a local .env file into os.environ (the deploy step
# below reads the IAM role from the "role" variable).
load_dotenv()


def _read_yaml(path):
    """Open *path* as UTF-8 text and return its contents parsed by ``yaml.safe_load``."""
    with open(path, "r", encoding="utf-8") as fh:
        return yaml.safe_load(fh)


# General settings; this script reads its "deploy-instance" key.
config = _read_yaml("config/config.yaml")
# Written by the training step; this script reads its "last-trained-model" key.
saveconfig = _read_yaml("config/save_last_model.yaml")
# Deploying
# Create a real-time endpoint that serves the most recently trained model.
sess = Session()
# IAM role for the model, taken from the environment / .env file.
# NOTE(review): os.getenv returns None when "role" is unset — confirm the
# SageMaker SDK reports that clearly, or validate it here.
role = os.getenv("role")
# Location of the trained model artifact recorded by the training step
# (presumably an S3 URI — confirm against the training script).
pt_mnist_model_data = saveconfig["last-trained-model"]
model = PyTorchModel(
    entry_point="inference.py",
    source_dir="code",
    role=role,
    model_data=pt_mnist_model_data,
    framework_version="1.5.0",
    py_version="py3",
    # Use the session created above rather than letting the SDK build
    # a second, implicit one.
    sagemaker_session=sess,
)
# JSON in, JSON out: requests are dicts, responses come back as Python objects.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type=config["deploy-instance"],
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)
# Calling
# Send the same payload to the endpoint ten times and keep every
# deserialized response for plotting below.
res = []
data = {"x": 56, "y": 56}
try:
    for _ in range(10):
        res.append(predictor.predict(data))
finally:
    # Shutting down
    # Always tear the endpoint down, even if a request raised, so a failed
    # run does not leave the endpoint running.
    predictor.delete_endpoint()
# Showing results
# Each response is written to outputs/<name>/<i>.png, where <name> is the
# element at index 5 of the model-artifact path split on '/'.
# NOTE(review): the fixed index assumes a particular artifact-path depth;
# a different layout raises IndexError or picks the wrong segment — confirm.
save_path = f"outputs/{pt_mnist_model_data.split('/')[5]}"
# exist_ok=True creates "outputs" and the run directory in one call and does
# not fail if they already exist (no check-then-create race).
os.makedirs(save_path, exist_ok=True)
for i, image in enumerate(res):
    fig, axes = plt.subplots(1, 1)
    # Presumably each response nests a 2-D image at [0][0] (as produced by
    # code/inference.py) — TODO confirm the response shape.
    axes.imshow(image[0][0])
    fig.savefig(f"{save_path}/{i}.png")
    fig.show()