-
Notifications
You must be signed in to change notification settings - Fork 0
/
exporter_main_v2.py
159 lines (138 loc) · 7.05 KB
/
exporter_main_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Lint as: python2, python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Tool to export an object detection model for inference.
Prepares an object detection tensorflow graph for inference using model
configuration and a trained checkpoint. Outputs associated checkpoint files,
a SavedModel, and a copy of the model config.
The inference graph contains one of four input nodes, depending on the
user-specified `input_type` option.
* `image_tensor`: Accepts a uint8 4-D tensor of shape [1, None, None, 3]
* `float_image_tensor`: Accepts a float32 4-D tensor of shape
[1, None, None, 3]
* `encoded_image_string_tensor`: Accepts a 1-D string tensor of shape [None]
containing encoded PNG or JPEG images. Image resolutions are expected to be
the same if more than 1 image is provided.
* `tf_example`: Accepts a 1-D string tensor of shape [None] containing
serialized TFExample protos. Image resolutions are expected to be the same
if more than 1 image is provided.
and the following output nodes, returned by `model.postprocess(...)`:
* `num_detections`: Outputs float32 tensors of the form [batch]
that specifies the number of valid boxes per image in the batch.
* `detection_boxes`: Outputs float32 tensors of the form
[batch, num_boxes, 4] containing detected boxes.
* `detection_scores`: Outputs float32 tensors of the form
[batch, num_boxes] containing class scores for the detections.
* `detection_classes`: Outputs float32 tensors of the form
[batch, num_boxes] containing classes for the detections.
Example Usage:
--------------
python exporter_main_v2.py \
--input_type image_tensor \
--pipeline_config_path path/to/ssd_inception_v2.config \
--trained_checkpoint_dir path/to/checkpoint \
--output_directory path/to/exported_model_directory \
--use_side_inputs=True/False \
--side_input_shapes dim_0,dim_1,...,dim_a/.../dim_0,dim_1,...,dim_z \
--side_input_names name_a,name_b,...,name_c \
--side_input_types type_a,type_b,...,type_c
The expected output would be in the directory
path/to/exported_model_directory (which is created if it does not exist)
holding two subdirectories (corresponding to checkpoint and SavedModel,
respectively) and a copy of the pipeline config.
Config overrides (see the `config_override` flag) are text protobufs
(also of type pipeline_pb2.TrainEvalPipelineConfig) which are used to override
certain fields in the provided pipeline_config_path. These are useful for
making small changes to the inference graph that differ from the training or
eval config.
Example Usage (in which we change the second stage post-processing score
threshold to be 0.5):
python exporter_main_v2.py \
--input_type image_tensor \
--pipeline_config_path path/to/ssd_inception_v2.config \
--trained_checkpoint_dir path/to/checkpoint \
--output_directory path/to/exported_model_directory \
--config_override " \
model{ \
faster_rcnn { \
second_stage_post_processing { \
batch_non_max_suppression { \
score_threshold: 0.5 \
} \
} \
} \
}"
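If side inputs are desired, the following arguments can be appended
(the example below is for Context R-CNN):
--use_side_inputs=True \
--side_input_shapes 1,2000,2057/1 \
--side_input_names context_features,valid_context_size \
--side_input_types tf.float32,tf.int32
Note that `use_side_inputs` is a boolean flag: enable it with
`--use_side_inputs` or `--use_side_inputs=True`. A space-separated value is
not consumed by a boolean flag, so `--use_side_inputs False` would still turn
side inputs on.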
"""
from absl import app
from absl import flags
import tensorflow.compat.v2 as tf
from google.protobuf import text_format
from object_detection import exporter_lib_v2
from object_detection.protos import pipeline_pb2
# Switch TensorFlow to TF2 behavior for everything this script runs.
tf.enable_v2_behavior()

FLAGS = flags.FLAGS

flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can be '
                    'one of [`image_tensor`, `encoded_image_string_tensor`, '
                    '`tf_example`, `float_image_tensor`]')
flags.DEFINE_string('pipeline_config_path', None,
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file.')
flags.DEFINE_string('trained_checkpoint_dir', None,
                    'Path to trained checkpoint directory')
flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string('config_override', '',
                    'pipeline_pb2.TrainEvalPipelineConfig '
                    'text proto to override pipeline_config_path.')
# absl boolean flag: enable with `--use_side_inputs` or
# `--use_side_inputs=true`. A space-separated value is not consumed by a
# boolean flag, so `--use_side_inputs False` still sets it to True.
flags.DEFINE_boolean('use_side_inputs', False,
                     'If True, uses side inputs as well as image inputs.')
# The three side_input_* flags below describe the same tensors in parallel:
# shapes are separated by `/`, names (and types) by `,`.
flags.DEFINE_string('side_input_shapes', '',
                    'If use_side_inputs is True, this explicitly sets '
                    'the shape of the side input tensors to a fixed size. The '
                    'dimensions are to be provided as a comma-separated list '
                    'of integers. A value of -1 can be used for unknown '
                    'dimensions. A `/` denotes a break, starting the shape of '
                    'the next side input tensor. This flag is required if '
                    'using side inputs.')
# NOTE(review): this help lists `string`/`integer`/`float`, while the module
# docstring's Context R-CNN example passes `tf.float32,tf.int32`. Confirm the
# spelling exporter_lib_v2 actually accepts and align the two.
flags.DEFINE_string('side_input_types', '',
                    'If use_side_inputs is True, this explicitly sets '
                    'the type of the side input tensors. The '
                    'types are to be provided as a comma-separated list, '
                    'each of `string`, `integer`, or `float`. '
                    'This flag is required if using side inputs.')
flags.DEFINE_string('side_input_names', '',
                    'If use_side_inputs is True, this explicitly sets '
                    'the names of the side input tensors required by the model '
                    'assuming the names will be a comma-separated list of '
                    'strings. This flag is required if using side inputs.')

# These default to None; absl rejects the command line before main() runs if
# any of them is left unset.
flags.mark_flag_as_required('pipeline_config_path')
flags.mark_flag_as_required('trained_checkpoint_dir')
flags.mark_flag_as_required('output_directory')
def _missing_side_input_flags(use_side_inputs, side_input_shapes,
                              side_input_types, side_input_names):
  """Returns the names of side-input flags that are required but empty.

  Args:
    use_side_inputs: Value of the `use_side_inputs` flag.
    side_input_shapes: Value of the `side_input_shapes` flag.
    side_input_types: Value of the `side_input_types` flag.
    side_input_names: Value of the `side_input_names` flag.

  Returns:
    A list of flag names, in shapes/types/names order, whose values are empty
    even though side inputs were requested. Empty when side inputs are
    disabled or when every side-input flag has a value.
  """
  if not use_side_inputs:
    return []
  required = (('side_input_shapes', side_input_shapes),
              ('side_input_types', side_input_types),
              ('side_input_names', side_input_names))
  return [flag_name for flag_name, value in required if not value]


def main(_):
  """Loads the pipeline config, applies overrides and exports the model.

  Reads the text-format TrainEvalPipelineConfig at `--pipeline_config_path`,
  merges `--config_override` into it and passes the result, together with the
  remaining flags, to `exporter_lib_v2.export_inference_graph`.

  Raises:
    app.UsageError: if `--use_side_inputs` is set but one of the side-input
      flags that its help text marks as required is empty.
  """
  missing = _missing_side_input_flags(
      FLAGS.use_side_inputs, FLAGS.side_input_shapes,
      FLAGS.side_input_types, FLAGS.side_input_names)
  if missing:
    # Fail fast with a CLI-friendly message instead of handing an empty
    # side-input specification to the exporter.
    raise app.UsageError(
        '--use_side_inputs requires non-empty values for: %s'
        % ', '.join('--' + name for name in missing))

  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.io.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)
  # Merge (not Parse) into the already-populated message so the override is
  # layered on top of the config read from disk.
  text_format.Merge(FLAGS.config_override, pipeline_config)
  exporter_lib_v2.export_inference_graph(
      FLAGS.input_type, pipeline_config, FLAGS.trained_checkpoint_dir,
      FLAGS.output_directory, FLAGS.use_side_inputs, FLAGS.side_input_shapes,
      FLAGS.side_input_types, FLAGS.side_input_names)
if __name__ == '__main__':
  # Hand control to absl: it parses the command-line flags (including the
  # required-flag checks) and then calls main() with the leftover argv.
  app.run(main=main)