-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_inference_model
executable file
·34 lines (23 loc) · 1.1 KB
/
generate_inference_model
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#!/usr/bin/python
import argparse
import os
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
from keras import backend as K
import model
# Export script: build the network, freeze its variables into constants and
# write the resulting graph to ./tiramisu.pb as a binary protobuf.

parser = argparse.ArgumentParser()
parser.add_argument('--input_width', type=int, default=640, help='the input image width')
parser.add_argument('--input_height', type=int, default=360, help='the input image height')
# NOTE(review): the help text says "output path", but this script hands the
# value to model.build(weights_path=...), so it is presumably the weights file
# to load -- confirm against model.build before rewording the message.
parser.add_argument('--weights', type=str, default='weights_1.h5', help='output path for the weights file')
opt = parser.parse_args()
print(opt)

# Fix the Keras learning phase to 0 (test mode) *before* the model is built so
# the graph is constructed for inference rather than training.
K.set_learning_phase(0)

# The third positional argument (66) is passed through to model.build as-is;
# presumably the number of output classes -- TODO confirm against model.build.
tiramisu = model.build(opt.input_width, opt.input_height, 66, weights_path=opt.weights)

sess = K.get_session()

# Dump every node name so the output node passed below ('output') can be
# checked against what the graph actually contains.
for n in tf.get_default_graph().as_graph_def().node:
    print(n.name)

# Replace all variables reachable from the 'output' node with constants
# holding their current values in the session.
inference_graph = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), ['output'])
inference_graph = graph_util.remove_training_nodes(inference_graph)  # Just in case

# Write the frozen graph. The original passed the undefined name
# `constant_graph` here, which raised NameError before anything was saved.
graph_io.write_graph(inference_graph, './', 'tiramisu.pb', as_text=False)
print('Saved the constant graph (ready for inference) at: ./tiramisu.pb')