Skip to content

Commit

Permalink
Adding ORTPipelineForxxx entrypoints (#1960)
Browse files Browse the repository at this point in the history
* created auto task mappings

* added correct auto classes

* created auto task mappings

* added correct auto classes

* added ort/auto diffusion classes

* fix ORTPipeline detection

* start test refactoring

* dynamic dtype

* support torch random numbers generator

* compact diffusion testing suite

* fix

* test

* test

* test

* use latent-consistency architecture name instead of lcm

* fix

* add ort diffusion pipeline tests

* added dummy objects

* remove duplicate code

* support testing without diffusers

* remove unnecessary

* revert

* style

* remove model parts from optimum.onnxruntime
  • Loading branch information
IlyasMoutawwakil authored Sep 16, 2024
1 parent f1b708c commit ca36fc4
Show file tree
Hide file tree
Showing 18 changed files with 1,287 additions and 827 deletions.
2 changes: 1 addition & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@ class TasksManager:
"image-feature-extraction": "feature-extraction",
# for backward compatibility and testing (where
# model task and model type are still the same)
"lcm": "text-to-image",
"stable-diffusion": "text-to-image",
"stable-diffusion-xl": "text-to-image",
"latent-consistency": "text-to-image",
}

_CUSTOM_CLASSES = {
Expand Down
9 changes: 6 additions & 3 deletions optimum/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class PreTrainedModel(ABC): # noqa: F811

class OptimizedModel(PreTrainedModel):
config_class = AutoConfig
load_tf_weights = None
base_model_prefix = "optimized_model"
config_name = CONFIG_NAME

Expand Down Expand Up @@ -378,10 +377,14 @@ def from_pretrained(
)
model_id, revision = model_id.split("@")

library_name = TasksManager.infer_library_from_model(model_id, subfolder, revision, cache_dir, token=token)
library_name = TasksManager.infer_library_from_model(
model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
)

if library_name == "timm":
config = PretrainedConfig.from_pretrained(model_id, subfolder, revision)
config = PretrainedConfig.from_pretrained(
model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
)

if config is None:
if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME:
Expand Down
16 changes: 16 additions & 0 deletions optimum/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@
"ORTStableDiffusionXLPipeline",
"ORTStableDiffusionXLImg2ImgPipeline",
"ORTLatentConsistencyModelPipeline",
"ORTPipelineForImage2Image",
"ORTPipelineForInpainting",
"ORTPipelineForText2Image",
"ORTDiffusionPipeline",
]
else:
_import_structure["modeling_diffusion"] = [
Expand All @@ -88,6 +92,10 @@
"ORTStableDiffusionXLPipeline",
"ORTStableDiffusionXLImg2ImgPipeline",
"ORTLatentConsistencyModelPipeline",
"ORTPipelineForImage2Image",
"ORTPipelineForInpainting",
"ORTPipelineForText2Image",
"ORTDiffusionPipeline",
]


Expand Down Expand Up @@ -137,7 +145,11 @@
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_diffusers_objects import (
ORTDiffusionPipeline,
ORTLatentConsistencyModelPipeline,
ORTPipelineForImage2Image,
ORTPipelineForInpainting,
ORTPipelineForText2Image,
ORTStableDiffusionImg2ImgPipeline,
ORTStableDiffusionInpaintPipeline,
ORTStableDiffusionPipeline,
Expand All @@ -146,7 +158,11 @@
)
else:
from .modeling_diffusion import (
ORTDiffusionPipeline,
ORTLatentConsistencyModelPipeline,
ORTPipelineForImage2Image,
ORTPipelineForInpainting,
ORTPipelineForText2Image,
ORTStableDiffusionImg2ImgPipeline,
ORTStableDiffusionInpaintPipeline,
ORTStableDiffusionPipeline,
Expand Down
50 changes: 23 additions & 27 deletions optimum/onnxruntime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,11 @@ class ORTModelPart:
_prepare_onnx_inputs = ORTModel._prepare_onnx_inputs
_prepare_onnx_outputs = ORTModel._prepare_onnx_outputs

def __init__(
self,
session: InferenceSession,
parent_model: "ORTModel",
):
def __init__(self, session: InferenceSession, parent_model: "ORTModel"):
self.session = session
self.parent_model = parent_model
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(
self.parent_model.config.model_type
)(self.parent_model.config)
self.main_input_name = self.parent_model.main_input_name

self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()}
Expand Down Expand Up @@ -90,12 +84,18 @@ class ORTEncoder(ORTModelPart):
Encoder part of the encoder-decoder model for ONNX Runtime inference.
"""

def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
**kwargs,
) -> BaseModelOutput:
def __init__(self, session: InferenceSession, parent_model: "ORTModel"):
super().__init__(session, parent_model)

config = (
self.parent_model.config.encoder
if hasattr(self.parent_model.config, "encoder")
else self.parent_model.config
)

self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)

def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, **kwargs) -> BaseModelOutput:
use_torch = isinstance(input_ids, torch.Tensor)
self.parent_model.raise_on_numpy_input_io_binding(use_torch)

Expand Down Expand Up @@ -138,6 +138,14 @@ def __init__(
):
super().__init__(session, parent_model)

config = (
self.parent_model.config.decoder
if hasattr(self.parent_model.config, "decoder")
else self.parent_model.config
)

self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)

# TODO: make this less hacky.
self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)]
self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)]
Expand All @@ -153,11 +161,7 @@ def __init__(

self.use_past_in_outputs = len(self.key_value_output_names) > 0
self.use_past_in_inputs = len(self.key_value_input_names) > 0
self.use_fp16 = False
for inp in session.get_inputs():
if "past_key_values" in inp.name and inp.type == "tensor(float16)":
self.use_fp16 = True
break
self.use_fp16 = self.dtype == torch.float16

# We may use ORTDecoderForSeq2Seq for vision-encoder-decoder models, where models as gpt2
# can be used but do not support KV caching for the cross-attention key/values, see:
Expand Down Expand Up @@ -461,11 +465,3 @@ def prepare_inputs_for_merged(
cache_position = cache_position.to(self.device)

return use_cache_branch_tensor, past_key_values, cache_position


class ORTDecoder(ORTDecoderForSeq2Seq):
def __init__(self, *args, **kwargs):
logger.warning(
"The class `ORTDecoder` is deprecated and will be removed in optimum v1.15.0, please use `ORTDecoderForSeq2Seq` instead."
)
super().__init__(*args, **kwargs)
Loading

0 comments on commit ca36fc4

Please sign in to comment.