Enable export for inference when eval model is loaded from buffer #21422

Closed
wants to merge 12 commits
@@ -568,8 +568,10 @@ public void OptimizerStep(RunOptions options)
/// an inference model if it knows the inference graph outputs. The input inference graph outputs
/// are used to prune the eval model so that the inference model's outputs align with the provided outputs.
/// The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
/// Note that the function re-loads the eval model from the path provided to TrainingSession
/// and expects that this path still be valid.
///
/// This function modifies the eval graph in-place, so after this method is called, the TrainingSession can
/// no longer be used for training. In order to continue training from this point, save the checkpoint state
/// and create a new TrainingSession with the saved checkpoint state.
/// </summary>
/// <param name="inferenceModelPath">Path where the inference model should be serialized to.</param>
/// <param name="graphOutputNames">Names of the outputs that are needed in the inference model.</param>
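The resume-training workflow described in the updated comment looks roughly like the following sketch, written against the C++ training API exercised by the test added later in this PR. This is illustrative only: the file paths and the graph output name are placeholders, and it assumes a SaveCheckpoint helper (Ort::CheckpointState::SaveCheckpoint) is available for persisting the checkpoint state. The same flow applies when the training and eval models are loaded from in-memory buffers instead of paths.

// Sketch only: the paths, the output name, and the SaveCheckpoint usage are assumptions.
#include <string>
#include <vector>
#include "onnxruntime_training_cxx_api.h"

void ExportThenResumeTraining() {
  Ort::Env env;
  Ort::CheckpointState state = Ort::CheckpointState::LoadCheckpoint(ORT_TSTR("checkpoint.ckpt"));
  Ort::TrainingSession session(env, Ort::SessionOptions(), state,
                               ORT_TSTR("training_model.onnx"), ORT_TSTR("eval_model.onnx"));

  // ... TrainStep / EvalStep iterations go here ...

  // Exporting prunes the eval graph in-place; this session can no longer train or evaluate.
  std::vector<std::string> graph_output_names{"output-0"};
  session.ExportModelForInferencing(ORT_TSTR("inference_model.onnx"), graph_output_names);

  // To keep training, save the checkpoint state and build a fresh session from it.
  Ort::CheckpointState::SaveCheckpoint(state, ORT_TSTR("resume.ckpt"));
  Ort::CheckpointState resumed = Ort::CheckpointState::LoadCheckpoint(ORT_TSTR("resume.ckpt"));
  Ort::TrainingSession fresh_session(env, Ort::SessionOptions(), resumed,
                                     ORT_TSTR("training_model.onnx"), ORT_TSTR("eval_model.onnx"));
}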
5 changes: 3 additions & 2 deletions java/src/main/java/ai/onnxruntime/OrtTrainingSession.java
@@ -998,8 +998,9 @@ private native void schedulerStep(long apiHandle, long trainingApiHandle, long n
* Exports the evaluation model as a model suitable for inference, setting the desired nodes as
* output nodes.
*
* <p>Note that this method reloads the evaluation model from the path provided to the training
* session, and this path must still be valid.
* <p>Note that this method modifies the eval session in-place; after it is called, the
* OrtTrainingSession can no longer be used for training. To continue training from this point,
* save the checkpoint and then load it into a new OrtTrainingSession.
*
* @param outputPath The path to write out the inference model.
* @param outputNames The names of the output nodes.
7 changes: 4 additions & 3 deletions objectivec/include/ort_training_session.h
@@ -229,10 +229,11 @@ NS_ASSUME_NONNULL_BEGIN
*
* If the training session was provided with an eval model, the training session can generate an inference model if it
* knows the inference graph outputs. The input inference graph outputs are used to prune the eval model so that the
* inference model's outputs align with the provided outputs. The exported model is saved at the path provided and
* can be used for inferencing with `ORTSession`.
* inference model's outputs align with the provided outputs.
*
* @note The method reloads the eval model from the path provided to the initializer and expects this path to be valid.
* @note This method modifies the eval model graph in-place, so after this method is called, the ORTTrainingSession
* can no longer be used for training. To resume training from this point, save the checkpoint state and create a new
* ORTTrainingSession with the saved checkpoint state.
*
* @param inferenceModelPath The path to serialize the inference model to.
* @param graphOutputNames The names of the outputs that are needed in the inference model.
5 changes: 4 additions & 1 deletion orttraining/orttraining/python/training/api/module.py
@@ -194,9 +194,12 @@ def export_model_for_inferencing(
Once training is complete, this function can be used to drop the training specific nodes in the onnx model.
In particular, this function does the following:

- Parse over the training graph and identify nodes that generate the given output names.
- Parse over the eval graph and identify nodes that generate the given output names.
- Drop all subsequent nodes in the graph since they are not relevant to the inference graph.

Once this method is called, training is considered complete and the module can no longer be used for training.
To resume training from this point, save the checkpoint and create a new module from the checkpoint.

Args:
inference_model_uri: The path to the inference model.
graph_output_names: The list of output names that are required for inferencing.
@@ -16,7 +16,7 @@
from orttraining_test_ort_apis_onnxblock import _get_models

import onnxruntime.training.onnxblock as onnxblock
from onnxruntime import OrtValue, SessionOptions
from onnxruntime import InferenceSession, OrtValue, SessionOptions
from onnxruntime.training import artifacts
from onnxruntime.training.api import CheckpointState, LinearLRScheduler, Module, Optimizer

@@ -283,6 +283,7 @@ def test_export_model_for_inferencing():
inference_model_file_path = os.path.join(temp_dir, "inference_model.onnx")
model.export_model_for_inferencing(inference_model_file_path, ["output-0"])
assert os.path.exists(inference_model_file_path)
InferenceSession(inference_model_file_path)


def test_cuda_execution_provider():
@@ -265,6 +265,75 @@ TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) {
train_model_data);
}

TEST(TrainingCApiTest, LoadONNXModelsFromBufferThenExport) {
auto model_path = MODEL_FOLDER "training_model.onnx";
size_t model_data_len = 0;
ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len));
std::vector<uint8_t> train_model_data(model_data_len);
std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary);
bytes_stream.read(reinterpret_cast<char*>(train_model_data.data()), model_data_len);
ASSERT_TRUE(train_model_data.size() == model_data_len);

auto eval_model_path = MODEL_FOLDER "eval_model.onnx";
size_t eval_model_data_len = 0;
ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, eval_model_data_len));
std::vector<uint8_t> eval_model_data(eval_model_data_len);
std::ifstream eval_bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary);
eval_bytes_stream.read(reinterpret_cast<char*>(eval_model_data.data()), eval_model_data_len);
ASSERT_TRUE(eval_model_data.size() == eval_model_data_len);

Ort::Env env;
Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt");
Ort::TrainingSession training_session = Ort::TrainingSession(env,
Ort::SessionOptions(),
checkpoint_state,
train_model_data,
eval_model_data);

// randomly selected output name
std::vector<std::string> graph_output_names({"onnx::loss::21273"});
training_session.ExportModelForInferencing(MODEL_FOLDER "inference_model.onnx", graph_output_names);

// Check that the model is a valid inference model by loading into an InferenceSession
std::unique_ptr<Environment> environment;
ASSERT_STATUS_OK(Environment::Create(nullptr, environment));
InferenceSession inference_session = InferenceSession(SessionOptions(), *environment, MODEL_FOLDER "inference_model.onnx");

// Check that you can no longer train or evaluate after exporting. Since passing incorrect inputs will also cause
// TrainStep and EvalStep to throw errors, we check for the error message.
ORT_TRY {
training_session.TrainStep({});
FAIL() << "TrainStep after exporting for inference should have thrown an error.";
}
ORT_CATCH(const Ort::Exception& e) {
ORT_HANDLE_EXCEPTION([&e]() {
ASSERT_THAT(e.what(),
testing::HasSubstr("Cannot train after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."));
});
}
ORT_CATCH(...) {
FAIL() << "TrainStep after exporting for inference should have thrown an Ort::Exception.";
}

ORT_TRY {
training_session.EvalStep({});
FAIL() << "EvalStep after exporting for inference should have thrown an Ort::Exception.";
}
ORT_CATCH(const Ort::Exception& e) {
ORT_HANDLE_EXCEPTION([&e]() {
ASSERT_THAT(e.what(),
testing::HasSubstr("Cannot evaluate after exporting for inferencing. To continue training from this point, please save the checkpoint and create a new TrainingSession."));
});
}
ORT_CATCH(...) {
FAIL() << "EvalStep after exporting for inference should have thrown an Ort::Exception.";
}

// attempt to retrieve the input & output names of the eval model
ASSERT_THROW(training_session.InputNames(false), Ort::Exception);
ASSERT_THROW(training_session.OutputNames(false), Ort::Exception);
}

TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) {
auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort";
auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort";
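The test above validates the exported file by constructing an internal onnxruntime::InferenceSession. From user code, the same sanity check can be done with the public Ort::Session API that the header comments below refer to. A minimal sketch, assuming the exported model was written to inference_model.onnx:

// Minimal sketch: load the exported model with the public inference API and list
// its inputs and outputs. The model path is an assumption.
#include <iostream>
#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env;
  Ort::Session session(env, ORT_TSTR("inference_model.onnx"), Ort::SessionOptions());

  Ort::AllocatorWithDefaultOptions allocator;
  for (size_t i = 0; i < session.GetInputCount(); ++i) {
    std::cout << "input:  " << session.GetInputNameAllocated(i, allocator).get() << "\n";
  }
  for (size_t i = 0; i < session.GetOutputCount(); ++i) {
    std::cout << "output: " << session.GetOutputNameAllocated(i, allocator).get() << "\n";
  }
  return 0;
}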
@@ -513,8 +513,9 @@ struct OrtTrainingApi {
* an inference model if it knows the inference graph outputs. The input inference graph outputs
* are used to prune the eval model so that the inference model's outputs align with the provided outputs.
* The exported model is saved at the path provided and can be used for inferencing with InferenceSession.
* \note Note that the function re-loads the eval model from the path provided to OrtTrainingApi::CreateTrainingSession
* and expects that this path still be valid.
* \note Note that the function modifies the eval model graph in-place, so after this method is called, the
* OrtTrainingSession can no longer be used for training. To resume training from this point, save the checkpoint
* state and create a new OrtTrainingSession with the saved checkpoint state.
*
* \param[in] sess The `this` pointer to the training session.
* \param[in] inference_model_path Path where the inference model should be serialized to.
@@ -336,8 +336,9 @@ class TrainingSession : public detail::Base<OrtTrainingSession> {
* an inference model if it knows the inference graph outputs. The input inference graph outputs
* are used to prune the eval model so that the inference model's outputs align with the provided outputs.
* The exported model is saved at the path provided and can be used for inferencing with Ort::Session.
* \note Note that the function re-loads the eval model from the path provided to Ort::TrainingSession
* and expects that this path still be valid.
* \note Note that the function modifies the eval model graph in-place, so after this method is called, the
* Ort::TrainingSession can no longer be used for training. To resume training from this point, save the checkpoint
* state and create a new Ort::TrainingSession with the saved checkpoint state.
*
* \param[in] inference_model_path Path where the inference model should be serialized to.
* \param[in] graph_output_names Names of the outputs that are needed in the inference model.
64 changes: 42 additions & 22 deletions orttraining/orttraining/training_api/module.cc
@@ -3,6 +3,8 @@

#include "orttraining/training_api/module.h"

#include <memory>

#include "core/common/safeint.h"
#include "core/common/string_utils.h"
#include "core/framework/execution_provider.h"
@@ -54,9 +56,9 @@ Status RemoveUnusedNodes(Graph& inference_graph, InlinedVector<const NodeArg*>&
GraphViewer graph_viewer(inference_graph);
const auto node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (size_t idx = node_indices.size(); idx > 0; --idx) {
const NodeIndex node_index = idx - 1;
const NodeIndex node_index = node_indices[idx - 1];
auto* node = inference_graph.GetNode(node_index);
if (!reachable_nodes.count(node)) {
if (node && !reachable_nodes.count(node)) {
graph_utils::RemoveNodeOutputEdges(inference_graph, *node);
inference_graph.RemoveNode(node_index);
}
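The two fixes above make the loop index into the topological order (rather than treating the loop counter itself as a NodeIndex) and skip entries whose node has already been removed. The underlying technique, keeping only the nodes that transitively produce the requested outputs and dropping the rest while walking the topological order in reverse, is sketched below on a plain integer graph. This is an illustration of the algorithm, not ONNX Runtime code.

#include <iostream>
#include <unordered_set>
#include <vector>

// Sketch of the pruning technique: mark every node that (transitively) produces
// a requested output, then drop the rest by walking the topological order in
// reverse. Plain integer node ids stand in for NodeIndex values.
int main() {
  // producers[i] lists the nodes whose outputs node i consumes.
  std::vector<std::vector<int>> producers = {{}, {0}, {0}, {1}};
  std::vector<int> topo_order = {0, 1, 2, 3};  // a valid topological order
  std::vector<int> requested_outputs = {3};    // prune everything node 3 does not need

  // Backward reachability from the requested outputs.
  std::unordered_set<int> reachable(requested_outputs.begin(), requested_outputs.end());
  std::vector<int> stack(requested_outputs);
  while (!stack.empty()) {
    int node = stack.back();
    stack.pop_back();
    for (int producer : producers[node]) {
      if (reachable.insert(producer).second) stack.push_back(producer);
    }
  }

  // Remove unreachable nodes in reverse topological order, mirroring the loop in
  // RemoveUnusedNodes: index into topo_order, never use the loop counter as a node id.
  for (size_t idx = topo_order.size(); idx > 0; --idx) {
    int node = topo_order[idx - 1];
    if (!reachable.count(node)) {
      std::cout << "pruning node " << node << "\n";  // node 2 in this example
    }
  }
  return 0;
}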
@@ -414,10 +416,6 @@ Module::Module(const ModelIdentifiers& model_identifiers,

// Keep a copy of the eval model path to be able to later export the model for inferencing.
// The inference model will be reconstructed from the eval model.
// TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer.
if (std::holds_alternative<std::optional<std::string>>(model_identifiers.eval_model)) {
eval_model_path_ = std::get<std::optional<std::string>>(model_identifiers.eval_model);
}
}

Module::~Module() {
@@ -428,7 +426,9 @@ size_t Module::GetTrainingModelOutputCount() const noexcept {
return train_output_names_.size();
}

size_t Module::GetEvalModelOutputCount() const noexcept {
size_t Module::GetEvalModelOutputCount() const {
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel output count. ");
return eval_output_names_.size();
}

@@ -438,6 +438,8 @@ std::string Module::GetTrainingModelOutputName(size_t index) const {
}

std::string Module::GetEvalModelOutputName(size_t index) const {
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel output name. ");
ORT_ENFORCE(index < eval_output_names_.size(), "Eval output name index out of range. Expected in range [0-",
eval_output_names_.size(), "). Actual: ", index);
return eval_output_names_.at(index);
@@ -613,13 +615,19 @@ Status Module::CopyBufferToParameters(OrtValue& parameters_buffer, const bool tr
}

Status Module::LazyResetGrad() {
ORT_RETURN_IF(finished_training_,
"Cannot train after exporting for inferencing. ",
"To continue training from this point, please save the checkpoint and create a new TrainingSession.");
accumulate_gradient_ = false;
return Status::OK();
}

Status Module::TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs) {
ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
"Cannot perform TrainStep with a nominal state. Please load the model parameters first.");
ORT_RETURN_IF(finished_training_,
"Cannot train after exporting for inferencing. ",
"To continue training from this point, please save the checkpoint and create a new TrainingSession.");
std::vector<std::shared_ptr<Parameter>> params;
std::vector<OrtValue> feeds{inputs};
feeds.insert(feeds.end(), weights_.begin(), weights_.end());
@@ -642,6 +650,9 @@ Status Module::TrainStep(const std::vector<OrtValue>& inputs, std::vector<OrtVal
Status Module::EvalStep(const std::vector<OrtValue>& inputs, std::vector<OrtValue>& outputs) {
ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
"Cannot perform EvalStep with a nominal state. Please load the model parameters first.");
ORT_RETURN_IF(finished_training_,
"Cannot evaluate after exporting for inferencing. ",
"To continue training from this point, please save the checkpoint and create a new TrainingSession.");
ORT_ENFORCE(nullptr != eval_sess_, "Evaluation session not initialized.");
std::vector<OrtValue> feeds{inputs};
feeds.insert(feeds.end(), weights_.begin(), weights_.end());
@@ -655,26 +666,27 @@ Status Module::EvalStep(const std::vector<OrtValue>& inputs, std::vector<OrtValu
// the build is minimal or not. This will require to read the ort_format eval model,
// transform it to an inference model and save it in ort_format.
Status Module::ExportModelForInferencing(const std::string& inference_model_path,
gsl::span<const std::string> graph_output_names) const {
gsl::span<const std::string> graph_output_names) {
ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state,
"Cannot export the model with a nominal state. Please load the model parameters first.");
ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(),
"Eval model was not provided. Cannot export a model for inferencing.");
ORT_RETURN_IF(!eval_sess_, "Eval model was not provided. Cannot export a model for inferencing.");

// Once finished_training_ is set to true, this module can no longer be used to train or evaluate,
// since the eval session's graph will have been modified in-place.
finished_training_ = true;

ONNX_NAMESPACE::ModelProto eval_model;
ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model));
EvalSessionWrapper& eval_sess_wrapper = static_cast<EvalSessionWrapper&>(*eval_sess_);

// Clone the eval model into an inference onnxruntime::Model.
std::shared_ptr<Model> inference_model;
ORT_RETURN_IF_ERROR(Model::Load(eval_model, inference_model, nullptr, logging::LoggingManager::DefaultLogger()));
Model& inference_model = eval_sess_wrapper.GetMutableModel();
Graph& inference_graph = eval_sess_wrapper.GetMutableGraph();

// The cloned model's outputs are transformed such that the model has outputs as defined by graph_output_names
// Any nodes not contributing to the inference outputs will be pruned.
ORT_THROW_IF_ERROR(TransformModelOutputsForInference(inference_model->MainGraph(), graph_output_names));
ORT_THROW_IF_ERROR(TransformModelOutputsForInference(inference_graph, graph_output_names));

// The cloned model's inputs are transformed such that the model has only user defined inputs. All parameters
// are moved to be constant initializers for the model.
ORT_RETURN_IF_ERROR(TransformModelInputsForInference(inference_model->MainGraph(),
ORT_RETURN_IF_ERROR(TransformModelInputsForInference(inference_graph,
state_->module_checkpoint_state.named_parameters,
eval_sess_->GetDataTransferManager()));

@@ -683,9 +695,9 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path
ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(ExternalCheckpointDataPath(ToPathString(inference_model_path)));
PathString inference_model_pathstring = ToPathString(inference_model_path);
ORT_THROW_IF_ERROR(
Model::SaveWithExternalInitializers(*inference_model, inference_model_pathstring, external_data_name, 64));
Model::SaveWithExternalInitializers(inference_model, inference_model_pathstring, external_data_name, 64));
} else {
ORT_THROW_IF_ERROR(Model::Save(*inference_model, ToPathString(inference_model_path)));
ORT_THROW_IF_ERROR(Model::Save(inference_model, ToPathString(inference_model_path)));
}
// Save the model at the desired location.
return Status::OK();
@@ -696,18 +708,24 @@ size_t Module::GetTrainingModelInputCount() const noexcept {
return train_input_names_.UserInputNames().size();
}

size_t Module::GetEvalModelInputCount() const noexcept {
size_t Module::GetEvalModelInputCount() const {
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel input count. ");
return eval_user_input_count_;
}

std::string Module::GetTrainingModelInputName(size_t index) const {
ORT_ENFORCE(index < train_input_names_.UserInputNames().size(),
"Train input name index out of range. Expected in range [0-", train_input_names_.UserInputNames().size(), "). Actual: ",
"Train input name index out of range. Expected in range [0-",
train_input_names_.UserInputNames().size(),
"). Actual: ",
index);
return train_input_names_.UserInputNames()[index];
}

std::string Module::GetEvalModelInputName(size_t index) const {
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel input name. ");
ORT_ENFORCE(index < eval_user_input_count_,
"Eval input name index out of range. Expected in range [0-", eval_user_input_count_, "). Actual: ",
index);
@@ -718,7 +736,9 @@ std::pair<common::Status, const InputDefList*> Module::GetTrainingModelInputs()
return train_sess_->GetModelInputs();
}

std::pair<common::Status, const InputDefList*> Module::GetEvalModelInputs() const noexcept {
std::pair<common::Status, const InputDefList*> Module::GetEvalModelInputs() const {
ORT_ENFORCE(!finished_training_,
"Exporting for inference has modified the eval model. Cannot retrieve EvalModel inputs. ");
return eval_sess_->GetModelInputs();
}
