Enable exporting for inference when loading from buffer without behavior changes (#21601)

### Description
Added the eval model buffer as an optional field in Module so that you can export a model for inferencing using an eval model stored as a buffer.

### Motivation and Context
- Resolves #21152 
- The previous solution (PR #21422) produced an eval model that was specific to the EPs used for training, because unavoidable runtime optimizations changed the graph stored with the eval session.
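
A minimal sketch of the usage this change enables through the C++ training API, condensed from the test added in this PR (ReadFileBytes is a hypothetical helper; MODEL_FOLDER, the model files, and the output name come from that test):

```cpp
// Load the training and eval models as in-memory buffers rather than paths.
std::vector<uint8_t> train_model_data = ReadFileBytes(MODEL_FOLDER "training_model.onnx");  // hypothetical helper
std::vector<uint8_t> eval_model_data = ReadFileBytes(MODEL_FOLDER "eval_model.onnx");

Ort::Env env;
Ort::CheckpointState checkpoint_state =
    Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
Ort::TrainingSession training_session(env, Ort::SessionOptions(), checkpoint_state,
                                      train_model_data, eval_model_data);

// With this change, exporting works even though the eval model came from a buffer.
std::vector<std::string> graph_output_names({"onnx::loss::21273"});
training_session.ExportModelForInferencing(MODEL_FOLDER "inference_model.onnx", graph_output_names);
```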
carzh authored Aug 9, 2024
1 parent 37be90c commit eeef0c8
Showing 5 changed files with 49 additions and 6 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/model.cc
@@ -646,7 +646,7 @@ Status Model::SaveWithExternalInitializers(Model& model, const std::filesystem::
   return SaveModelWithExternalInitializers(model, file_path, external_file_name, initializer_size_threshold);
 }

-Status Model::LoadFromBytes(int count, void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
+Status Model::LoadFromBytes(int count, const void* p_bytes, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
   const bool result = model_proto.ParseFromArray(p_bytes, count);
   if (!result) {
     return Status(ONNXRUNTIME, INVALID_PROTOBUF, "Protobuf parsing failed.");
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/model.h
@@ -234,7 +234,7 @@ class Model {
                            const ModelOptions& options = {});

   // 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
-  static common::Status LoadFromBytes(int count, void* pBytes,
+  static common::Status LoadFromBytes(int count, const void* pBytes,
                                       /*out*/ ONNX_NAMESPACE::ModelProto& model_proto);

   // 'int' rather than 'size_t' because of a protobuf design choice; let callers handle type checks
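
The comment leaves the count's type check to the caller; below is a hedged sketch of what such a guard could look like (LoadModelProtoFromBuffer is an illustrative helper, not part of this change, and it assumes the same onnxruntime namespace and error-code style as model.cc above):

```cpp
#include <limits>

// Hypothetical caller-side guard: protobuf's ParseFromArray takes an int count,
// so reject buffers that would overflow it before narrowing from size_t.
Status LoadModelProtoFromBuffer(gsl::span<const uint8_t> bytes,
                                /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) {
  if (bytes.size() > static_cast<size_t>(std::numeric_limits<int>::max())) {
    return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Model buffer is too large for protobuf parsing.");
  }
  return Model::LoadFromBytes(static_cast<int>(bytes.size()), bytes.data(), model_proto);
}
```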
@@ -265,6 +265,41 @@ TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) {
                                                                train_model_data);
 }

+TEST(TrainingCApiTest, LoadONNXModelsFromBufferThenExport) {
+  auto model_path = MODEL_FOLDER "training_model.onnx";
+  size_t model_data_len = 0;
+  ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len));
+  std::vector<uint8_t> train_model_data(model_data_len);
+  std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary);
+  bytes_stream.read(reinterpret_cast<char*>(train_model_data.data()), model_data_len);
+  ASSERT_TRUE(train_model_data.size() == model_data_len);
+
+  auto eval_model_path = MODEL_FOLDER "eval_model.onnx";
+  size_t eval_model_data_len = 0;
+  ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, eval_model_data_len));
+  std::vector<uint8_t> eval_model_data(eval_model_data_len);
+  std::ifstream eval_bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary);
+  eval_bytes_stream.read(reinterpret_cast<char*>(eval_model_data.data()), eval_model_data_len);
+  ASSERT_TRUE(eval_model_data.size() == eval_model_data_len);
+
+  Ort::Env env;
+  Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
+  Ort::TrainingSession training_session = Ort::TrainingSession(env,
+                                                               Ort::SessionOptions(),
+                                                               checkpoint_state,
+                                                               train_model_data,
+                                                               eval_model_data);
+
+  // randomly selected output name
+  std::vector<std::string> graph_output_names({"onnx::loss::21273"});
+  training_session.ExportModelForInferencing(MODEL_FOLDER "inference_model.onnx", graph_output_names);
+
+  // Check that the model is a valid inference model by loading into an InferenceSession
+  std::unique_ptr<Environment> environment;
+  ASSERT_STATUS_OK(Environment::Create(nullptr, environment));
+  InferenceSession inference_session = InferenceSession(SessionOptions(), *environment, MODEL_FOLDER "inference_model.onnx");
+}
+
 TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) {
   auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort";
   auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort";
15 changes: 11 additions & 4 deletions orttraining/orttraining/training_api/module.cc
@@ -412,11 +412,12 @@ Module::Module(const ModelIdentifiers& model_identifiers,
     eval_user_input_count_ = eval_user_input_names.size();
     eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end());

-    // Keep a copy of the eval model path to be able to later export the model for inferencing.
+    // Keep a copy of the eval model path or buffer to be able to later export the model for inferencing.
     // The inference model will be reconstructed from the eval model.
-    // TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer.
     if (std::holds_alternative<std::optional<std::string>>(model_identifiers.eval_model)) {
       eval_model_path_ = std::get<std::optional<std::string>>(model_identifiers.eval_model);
+    } else if (std::holds_alternative<gsl::span<const uint8_t>>(model_identifiers.eval_model)) {
+      eval_model_buffer_ = std::get<gsl::span<const uint8_t>>(model_identifiers.eval_model);
     }
   }

@@ -658,11 +659,17 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path
                                          gsl::span<const std::string> graph_output_names) const {
   ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
                 "Cannot export the model with a nominal state. Please load the model parameters first.");
-  ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(),
+  ORT_RETURN_IF(!eval_sess_ || (!eval_model_path_.has_value() && !eval_model_buffer_.has_value()),
                 "Eval model was not provided. Cannot export a model for inferencing.");

   ONNX_NAMESPACE::ModelProto eval_model;
-  ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model));
+  if (eval_model_path_.has_value()) {
+    ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model));
+  } else if (eval_model_buffer_.has_value()) {
+    int eval_model_buffer_size = static_cast<int>(eval_model_buffer_.value().size());
+    const void* eval_model_buffer_ptr = static_cast<const void*>(eval_model_buffer_.value().data());
+    ORT_THROW_IF_ERROR(Model::LoadFromBytes(eval_model_buffer_size, eval_model_buffer_ptr, eval_model));
+  }

   // Clone the eval model into an inference onnxruntime::Model.
   std::shared_ptr<Model> inference_model;
1 change: 1 addition & 0 deletions orttraining/orttraining/training_api/module.h
@@ -198,6 +198,7 @@ struct Module {

   bool accumulate_gradient_ = false;
   std::optional<std::string> eval_model_path_;
+  std::optional<gsl::span<const uint8_t>> eval_model_buffer_;
   size_t eval_user_input_count_{0U};
 };

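Worth noting about the new member: gsl::span is a non-owning view, so the buffer behind eval_model_buffer_ must stay valid until the later Model::LoadFromBytes call in ExportModelForInferencing. A standalone sketch of that general pattern (std::span and BufferHolder are stand-ins, not the real types):

```cpp
// C++20; std::span here is a stand-in for gsl::span, BufferHolder for Module.
#include <cstdint>
#include <iostream>
#include <optional>
#include <span>
#include <vector>

struct BufferHolder {
  std::optional<std::span<const std::uint8_t>> eval_model_buffer_;  // non-owning view
};

int main() {
  std::vector<std::uint8_t> eval_model_data{0x08, 0x01, 0x12, 0x00};  // stand-in model bytes
  BufferHolder holder{std::span<const std::uint8_t>{eval_model_data}};

  // Safe: the backing vector is still alive when the view is read.
  std::cout << "viewed bytes: " << holder.eval_model_buffer_->size() << "\n";

  // If eval_model_data were destroyed or resized before this point,
  // the stored span would dangle and reading it would be undefined behavior.
  return 0;
}
```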
