Extend saving models optimized by inference session (#16912)
### Description
This PR adds support for saving the optimized model after loading a model that
contains external data into an `InferenceSession`.

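For reference, a minimal usage sketch of the feature (mirroring the new Python test added in this PR; the model path, output directory, and size threshold are illustrative placeholders):

```python
import os

import onnxruntime as onnxrt

# Illustrative paths; any ONNX model whose initializers live in an external file works the same way.
input_model = "model_with_orig_ext_data.onnx"
output_dir = "./optimized"
os.makedirs(output_dir, exist_ok=True)

so = onnxrt.SessionOptions()
# Destination of the optimized model graph.
so.optimized_model_filepath = os.path.join(output_dir, "model_opt_with_ext_data.onnx")
# File (relative to the optimized model) that receives the external initializers.
so.add_session_config_entry(
    "session.optimized_model_external_initializers_file_name", "model_opt_with_ext_data.bin"
)
# Initializers of at least this many bytes are written to the external file.
so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100")

# Constructing the session runs graph optimizations and writes the optimized model plus its external data.
onnxrt.InferenceSession(input_model, sess_options=so, providers=["CPUExecutionProvider"])
```

The two `session.*` config entries are the ones exercised by `test_model_serialization_with_original_external_initializers_to_directory` below.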
### Motivation and Context
This PR is a follow-up to a [previous PR](#16716) for saving a model optimized by an `InferenceSession`.

kunal-vaishnavi authored Jul 31, 2023
1 parent 73ddba9 commit 3c72f43
Showing 7 changed files with 99 additions and 23 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/core/graph/graph.cc
@@ -3399,8 +3399,8 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std
int64_t external_offset = 0;

// Add the initializers to the result graph.
#if !defined(DISABLE_SPARSE_TENSORS)
const auto& model_path = ModelPath();
#if !defined(DISABLE_SPARSE_TENSORS)
const auto sparse_end = sparse_tensor_names_.end();
#endif

@@ -3417,7 +3417,7 @@ ONNX_NAMESPACE::GraphProto Graph::ToGraphProtoWithExternalInitializers(const std
TensorProto* output_proto = result.add_initializer();

std::vector<uint8_t> raw_data;
ORT_THROW_IF_ERROR(utils::UnpackInitializerData(initializer, Path(), raw_data));
ORT_THROW_IF_ERROR(utils::UnpackInitializerData(initializer, model_path, raw_data));
size_t tensor_bytes_size = raw_data.size();
if (tensor_bytes_size < initializer_size_threshold) {
*output_proto = initializer;
19 changes: 12 additions & 7 deletions onnxruntime/python/tools/transformers/optimizer.py
@@ -25,7 +25,7 @@

import coloredlogs
from fusion_options import FusionOptions
from onnx import ModelProto, load_model
from onnx import ModelProto, TensorProto, load_model
from onnx_model import OnnxModel
from onnx_model_bart import BartOnnxModel
from onnx_model_bert import BertOnnxModel
@@ -220,7 +220,6 @@ def optimize_model(
use_gpu: bool = False,
only_onnxruntime: bool = False,
verbose: bool = False,
use_external_data_format: bool = False,
):
"""Optimize Model by OnnxRuntime and/or python fusion logic.
@@ -258,8 +257,6 @@
use_gpu (bool, optional): use gpu or not for onnxruntime. Defaults to False.
only_onnxruntime (bool, optional): only use onnxruntime to optimize model, and no python fusion.
Defaults to False.
use_external_data_format (bool, optional): use external data format when saving optimized model.
Defaults to False.
Returns:
object of an optimizer class.
@@ -280,6 +277,15 @@
optimized_model_name = "model_o{}_{}.onnx".format(opt_level, "gpu" if use_gpu else "cpu")
optimized_model_path = os.path.join(temp_dir.name, optimized_model_name)

# Auto detect if input model has external data
has_external_data_file = False
original_model = load_model(input, load_external_data=False)
for initializer in original_model.graph.initializer:
if initializer.HasField("data_location") and initializer.data_location == TensorProto.EXTERNAL:
has_external_data_file = True
break
del original_model

if opt_level > 1:
# Disable some optimizers that might cause failure in symbolic shape inference or attention fusion.
disabled_optimizers += (
@@ -300,7 +306,7 @@
opt_level=opt_level,
disabled_optimizers=disabled_optimizers,
verbose=verbose,
save_as_external_data=use_external_data_format,
save_as_external_data=has_external_data_file,
)
elif opt_level == 1:
# basic optimizations (like constant folding and cast elimination) are not specified to execution provider.
@@ -314,7 +320,7 @@
opt_level=1,
disabled_optimizers=disabled_optimizers,
verbose=verbose,
save_as_external_data=use_external_data_format,
save_as_external_data=has_external_data_file,
)

if only_onnxruntime and not temp_model_path:
@@ -496,7 +502,6 @@ def main():
optimization_options=optimization_options,
use_gpu=args.use_gpu,
only_onnxruntime=args.only_onnxruntime,
use_external_data_format=args.use_external_data_format,
)

if args.float16:
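Note on the `optimizer.py` change above: `optimize_model` no longer takes a `use_external_data_format` argument; it now decides whether to save the ONNX Runtime-optimized model with external data by inspecting the input model's initializers. A standalone sketch of that detection logic (the helper name `model_has_external_data` is mine, not part of the PR):

```python
from onnx import TensorProto, load_model


def model_has_external_data(model_path: str) -> bool:
    """Return True if any initializer stores its raw data outside the ONNX file."""
    # Load only the graph structure; external tensor payloads stay on disk.
    model = load_model(model_path, load_external_data=False)
    return any(
        init.HasField("data_location") and init.data_location == TensorProto.EXTERNAL
        for init in model.graph.initializer
    )
```

The detected flag is then passed to the ONNX Runtime optimization step as `save_as_external_data`, replacing the old user-supplied option.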
@@ -18,14 +18,15 @@ namespace onnxruntime {
namespace test {

void LoadSaveAndCompareModel(const std::string& input_onnx,
const std::string& input_external_init_file,
const std::string& output_onnx,
const std::string& external_init_file,
const std::string& output_external_init_file,
size_t initializer_size_threshold) {
std::shared_ptr<Model> model;
ASSERT_STATUS_OK(Model::Load(ToPathString(input_onnx), model, nullptr, DefaultLoggingManager().DefaultLogger()));
std::remove(output_onnx.c_str());
std::remove(external_init_file.c_str());
ASSERT_STATUS_OK(Model::SaveWithExternalInitializers(*model, ToPathString(output_onnx), external_init_file, initializer_size_threshold));
std::remove(output_external_init_file.c_str());
ASSERT_STATUS_OK(Model::SaveWithExternalInitializers(*model, ToPathString(output_onnx), output_external_init_file, initializer_size_threshold));

std::shared_ptr<Model> model_from_external;
ASSERT_STATUS_OK(Model::Load(ToPathString(output_onnx), model_from_external, nullptr, DefaultLoggingManager().DefaultLogger()));
@@ -42,19 +43,22 @@ void LoadSaveAndCompareModel(const std::string& input_onnx,
ASSERT_EQ(initializers.size(), initializers_from_external.size());

// Compare the initializers of the two versions.
Path model_path{};
Path external_data_path{};
for (auto i : initializers) {
const std::string kInitName = i.first;
const ONNX_NAMESPACE::TensorProto* tensor_proto = i.second;
const ONNX_NAMESPACE::TensorProto* from_external_tensor_proto = initializers_from_external[kInitName];

std::vector<uint8_t> tensor_proto_data;
ORT_THROW_IF_ERROR(utils::UnpackInitializerData(*tensor_proto, Path(), tensor_proto_data));
model_path = Path::Parse(ToPathString(input_onnx));
external_data_path = (input_external_init_file.size()) ? model_path.ParentPath().Append(Path::Parse(ToPathString(input_external_init_file))) : Path();
ORT_THROW_IF_ERROR(utils::UnpackInitializerData(*tensor_proto, external_data_path, tensor_proto_data));
size_t tensor_proto_size = tensor_proto_data.size();

std::vector<uint8_t> from_external_tensor_proto_data;
Path model_path = Path::Parse(ToPathString(output_onnx));
external_data_path = model_path.ParentPath().Append(Path::Parse(ToPathString(external_init_file)));
model_path = Path::Parse(ToPathString(output_onnx));
external_data_path = model_path.ParentPath().Append(Path::Parse(ToPathString(output_external_init_file)));
ORT_THROW_IF_ERROR(utils::UnpackInitializerData(*from_external_tensor_proto, model_path, from_external_tensor_proto_data));
size_t from_external_tensor_proto_size = from_external_tensor_proto_data.size();

@@ -74,8 +78,14 @@ void LoadSaveAndCompareModel(const std::string& input_onnx,
ASSERT_EQ(std::remove(PathToUTF8String(external_data_path.ToPathString()).c_str()), 0);
}

// Original model does not have external initializers
TEST(SaveWithExternalInitializers, Mnist) {
LoadSaveAndCompareModel("testdata/mnist.onnx", "testdata/mnist_with_external_initializers.onnx", "mnist_external_initializers.bin", 100);
LoadSaveAndCompareModel("testdata/mnist.onnx", "", "testdata/mnist_with_external_initializers.onnx", "mnist_external_initializers.bin", 100);
}

// Original model has external initializers
TEST(SaveWithExternalInitializers, ModelWithOriginalExternalData) {
LoadSaveAndCompareModel("testdata/model_with_orig_ext_data.onnx", "model_with_orig_ext_data.onnx.data", "testdata/model_with_new_external_initializers.onnx", "model_with_new_external_initializers.bin", 0);
}

} // namespace test
41 changes: 40 additions & 1 deletion onnxruntime/test/python/onnxruntime_test_python.py
@@ -86,6 +86,7 @@ def test_model_serialization(self):
providers=["CPUExecutionProvider"],
)
self.assertTrue(os.path.isfile(so.optimized_model_filepath))
os.remove(so.optimized_model_filepath)
except Fail as onnxruntime_error:
if (
str(onnxruntime_error) == "[ONNXRuntimeError] : 1 : FAIL : Unable to serialize model as it contains"
@@ -113,6 +114,8 @@ def test_model_serialization_with_external_initializers(self):
)
self.assertTrue(os.path.isfile(so.optimized_model_filepath))
self.assertTrue(os.path.isfile(external_initializers_file))
os.remove(so.optimized_model_filepath)
os.remove(external_initializers_file)
except Fail as onnxruntime_error:
if (
str(onnxruntime_error) == "[ONNXRuntimeError] : 1 : FAIL : Unable to serialize model as it contains"
@@ -137,6 +140,36 @@ def test_model_serialization_with_external_initializers_to_directory(self):
onnxrt.InferenceSession(get_name("mnist.onnx"), sess_options=so, providers=["CPUExecutionProvider"])
self.assertTrue(os.path.isfile(so.optimized_model_filepath))
self.assertTrue(os.path.isfile(os.path.join(directory, external_initializers_file)))
os.remove(so.optimized_model_filepath)
os.remove(os.path.join(directory, external_initializers_file))
except Fail as onnxruntime_error:
if (
str(onnxruntime_error) == "[ONNXRuntimeError] : 1 : FAIL : Unable to serialize model as it contains"
" compiled nodes. Please disable any execution providers which generate compiled nodes."
):
pass
else:
raise onnxruntime_error

def test_model_serialization_with_original_external_initializers_to_directory(self):
try:
so = onnxrt.SessionOptions()
so.log_severity_level = 1
so.logid = "TestModelSerializationWithOriginalExternalInitializersToDirectory"
directory = "./testdata/"
so.optimized_model_filepath = os.path.join(directory, "model_opt_with_ext_data.onnx")
external_initializers_file = "model_opt_with_ext_data.bin"
so.add_session_config_entry(
"session.optimized_model_external_initializers_file_name", external_initializers_file
)
so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100")
onnxrt.InferenceSession(
get_name("model_with_orig_ext_data.onnx"), sess_options=so, providers=["CPUExecutionProvider"]
)
self.assertTrue(os.path.isfile(so.optimized_model_filepath))
self.assertTrue(os.path.isfile(os.path.join(directory, external_initializers_file)))
os.remove(so.optimized_model_filepath)
os.remove(os.path.join(directory, external_initializers_file))
except Fail as onnxruntime_error:
if (
str(onnxruntime_error) == "[ONNXRuntimeError] : 1 : FAIL : Unable to serialize model as it contains"
@@ -879,6 +912,8 @@ def test_profiler_with_session_options(self):
self.assertTrue(tag in lines[i])
self.assertTrue("]" in lines[-1])

os.remove(profile_file)

def test_profiler_get_start_time_ns(self):
def get_single_session_profiling_start_time():
so = onnxrt.SessionOptions()
@@ -888,7 +923,9 @@ def get_single_session_profiling_start_time():
sess_options=so,
providers=onnxrt.get_available_providers(),
)
return sess.get_profiling_start_time_ns()
start_time = sess.get_profiling_start_time_ns()
os.remove(sess.end_profiling())
return start_time

# Get 1st profiling's start time
start_time_1 = get_single_session_profiling_start_time()
@@ -1028,6 +1065,8 @@ def test_loading_session_options_from_model(self):

self.assertEqual(session_options.enable_profiling, True) # from the ORT config

os.remove(sess.end_profiling())

except Exception:
raise

15 changes: 9 additions & 6 deletions onnxruntime/test/testdata/model_with_external_initializers.py
@@ -1,6 +1,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os

import numpy as np
import onnx
from onnx import TensorProto, helper
@@ -14,27 +16,27 @@ def create_external_data_tensor(value, tensor_name):  # type: (List[Any], Text)
tensor_filename = f"{tensor_name}.bin"
set_external_data(tensor, location=tensor_filename)

with open(os.path.join(tensor_filename), "wb") as data_file: # noqa: F821
with open(os.path.join(tensor_filename), "wb") as data_file:
data_file.write(tensor.raw_data)
tensor.ClearField("raw_data")
tensor.data_location = onnx.TensorProto.EXTERNAL
return tensor


def GenerateModel(model_name): # noqa: N802
def GenerateModel(model_name, external_data_name): # noqa: N802
# Create one input (ValueInfoProto)
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2]) # noqa: N806

# Create second input (ValueInfoProto)
Pads = helper.make_tensor_value_info("Pads", TensorProto.INT64, [4]) # noqa: N806
Pads = helper.make_tensor_value_info(external_data_name, TensorProto.INT64, [4]) # noqa: N806

# Create one output (ValueInfoProto)
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4]) # noqa: N806

# Create a node (NodeProto)
node_def = helper.make_node(
"Pad", # node name
["X", "Pads"], # inputs
["X", external_data_name], # inputs
["Y"], # outputs
mode="constant", # Attributes
)
@@ -53,7 +55,7 @@ def GenerateModel(model_name):  # noqa: N802
1,
1,
],
"Pads",
external_data_name,
)
],
)
@@ -71,4 +73,5 @@ def GenerateModel(model_name):  # noqa: N802


if __name__ == "__main__":
GenerateModel("model_with_external_initializers.onnx")
GenerateModel("model_with_external_initializers.onnx", "Pads")
GenerateModel("model_with_orig_ext_data.onnx", "model_with_orig_ext_data")
Binary file not shown.
19 changes: 19 additions & 0 deletions onnxruntime/test/testdata/model_with_orig_ext_data.onnx
@@ -0,0 +1,19 @@
(binary ONNX protobuf, not human-readable: a small Pad test model with input `X`, output `Y`, and an initializer `model_with_orig_ext_data` whose raw data is stored in an external file)