Add OpenVINO support
helena-intel committed Aug 24, 2024
1 parent b37f470 commit 4b31bfa
Showing 3 changed files with 99 additions and 10 deletions.
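In short: SentenceTransformer gains a backend argument that is threaded through to the Transformer module, which loads the model with Optimum Intel's OVModelForFeatureExtraction when backend="openvino". A minimal usage sketch; the model id and ov_config value mirror the test added in this commit, everything else is standard sentence-transformers usage:

from sentence_transformers import SentenceTransformer

# Load with the new OpenVINO backend (requires `pip install optimum-intel openvino`).
# Leaving backend=None (the default) keeps the usual PyTorch loading path.
model = SentenceTransformer(
    "sentence-transformers-testing/stsb-bert-tiny-safetensors",
    backend="openvino",
    model_kwargs={"ov_config": {"INFERENCE_PRECISION_HINT": "f32"}},
)
embeddings = model.encode(["Hello there!"])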
7 changes: 5 additions & 2 deletions sentence_transformers/SentenceTransformer.py
@@ -116,6 +116,7 @@ class SentenceTransformer(nn.Sequential, FitMixin):
model_card_data (:class:`~sentence_transformers.model_card.SentenceTransformerModelCardData`, optional): A model
card data object that contains information about the model. This is used to generate a model card when saving
the model. If not set, a default model card data object is created.
backend (str, optional): If set to "openvino", the model is loaded with the OpenVINO backend (via Optimum Intel) instead of PyTorch. Defaults to None, which uses PyTorch.
Example:
::
@@ -162,6 +163,7 @@ def __init__(
tokenizer_kwargs: dict[str, Any] | None = None,
config_kwargs: dict[str, Any] | None = None,
model_card_data: SentenceTransformerModelCardData | None = None,
backend: str | None = None,
) -> None:
# Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
self.prompts = prompts or {}
@@ -172,6 +174,7 @@ def __init__(
self._model_card_vars = {}
self._model_card_text = None
self._model_config = {}
self._backend = backend
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v3 of SentenceTransformers.",
@@ -1408,6 +1411,7 @@ def _load_auto_model(
model_args=model_kwargs,
tokenizer_args=tokenizer_kwargs,
config_args=config_kwargs,
backend=self._backend,
)
pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), "mean")
self.model_card_data.set_base_model(model_name_or_path, revision=revision)
@@ -1564,8 +1568,7 @@ def _load_sbert_model(
kwargs["tokenizer_args"].update(tokenizer_kwargs)
if config_kwargs:
kwargs["config_args"].update(config_kwargs)

module = Transformer(model_name_or_path, cache_dir=cache_folder, backend=self._backend, **kwargs)
else:
# Normalize does not require any files to be loaded
if module_class == Normalize:
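Because _load_auto_model and _load_sbert_model now forward self._backend to the Transformer module, the backend can also be chosen when building the module pipeline by hand. A sketch under that assumption, reusing the test model id from this commit:

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Pooling, Transformer

# Transformer(...) accepts the new backend argument and forwards it to _load_model.
transformer = Transformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", backend="openvino")
pooling = Pooling(transformer.get_word_embedding_dimension(), "mean")  # mean pooling, as in _load_auto_model
model = SentenceTransformer(modules=[transformer, pooling])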
56 changes: 48 additions & 8 deletions sentence_transformers/models/Transformer.py
@@ -2,6 +2,7 @@

import json
import os
from pathlib import Path
from typing import Any

import torch
@@ -41,6 +42,7 @@ def __init__(
cache_dir: str | None = None,
do_lower_case: bool = False,
tokenizer_name_or_path: str = None,
backend: str | None = None,
) -> None:
super().__init__()
self.config_keys = ["max_seq_length", "do_lower_case"]
@@ -53,7 +55,7 @@ def __init__(
config_args = {}

config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
self._load_model(model_name_or_path, config, cache_dir, backend, **model_args)

if max_seq_length is not None and "model_max_length" not in tokenizer_args:
tokenizer_args["model_max_length"] = max_seq_length
@@ -77,17 +79,55 @@ def __init__(
if tokenizer_name_or_path is not None:
self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__

def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
"""Loads the transformer model"""
if backend is None:
if isinstance(config, T5Config):
self._load_t5_model(model_name_or_path, config, cache_dir, **model_args)
elif isinstance(config, MT5Config):
self._load_mt5_model(model_name_or_path, config, cache_dir, **model_args)
else:
self.auto_model = AutoModel.from_pretrained(
model_name_or_path, config=config, cache_dir=cache_dir, **model_args
)
elif backend == "openvino":
if isinstance(config, (T5Config, MT5Config)):
raise ValueError("T5 models are not yet supported by the OpenVINO backend.")
else:
self._load_openvino_model(model_name_or_path, cache_dir, **model_args)
else:
raise ValueError(f"Unsupported backend '{backend}'. `backend` should be `None` or `openvino`.")

def _load_openvino_model(self, model_name_or_path, cache_dir, **model_args) -> None:
config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
if isinstance(config, (T5Config, MT5Config)):
raise ValueError("T5 models are not yet supported by the OpenVINO backend.")

try:
from optimum.intel import OVModelForFeatureExtraction
except ModuleNotFoundError:
raise ImportError(
"Using the OpenVINO backend requires installing optimum-intel and OpenVINO. You can install them with pip: `pip install optimum-intel openvino`."
)

# Export the model to OpenVINO IR on the fly unless model_name_or_path is a local directory that already contains a converted openvino_model.xml
export = not (Path(model_name_or_path) / "openvino_model.xml").is_file()

if "ov_config" in model_args:
ov_config = model_args["ov_config"]
# ov_config can be either a dictionary or a path to a JSON file with an OpenVINO config
if not isinstance(ov_config, dict):
if not Path(ov_config).exists():
raise ValueError(
"ov_config should be a dictionary or point to a .json file containing an OpenVINO config"
)
with open(ov_config) as f:
model_args["ov_config"] = json.load(f)
else:
model_args["ov_config"] = {}
self.auto_model = OVModelForFeatureExtraction.from_pretrained(
model_name_or_path, export=export, cache_dir=cache_dir, **model_args
)

def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
"""Loads the encoder model from T5"""
from transformers import T5EncoderModel
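Per the ov_config handling in _load_openvino_model above, the OpenVINO config can be supplied either inline as a dict or as a path to a JSON file. A sketch; the NUM_STREAMS value comes from the test below, and the ov_config.json file name is illustrative:

import json

from sentence_transformers import SentenceTransformer

model_id = "sentence-transformers-testing/stsb-bert-tiny-safetensors"

# Option 1: pass the OpenVINO config inline as a dictionary.
model = SentenceTransformer(model_id, backend="openvino", model_kwargs={"ov_config": {"NUM_STREAMS": "2"}})

# Option 2: point ov_config at a JSON file; _load_openvino_model reads it with json.load.
with open("ov_config.json", "w") as f:
    json.dump({"NUM_STREAMS": "2"}, f)
model = SentenceTransformer(model_id, backend="openvino", model_kwargs={"ov_config": "ov_config.json"})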
46 changes: 46 additions & 0 deletions tests/test_sentence_transformer.py
@@ -10,6 +10,7 @@
import re
import tempfile
from functools import partial
from importlib.util import find_spec
from pathlib import Path
from typing import Dict, List, Literal, cast

@@ -653,6 +654,7 @@ def test_override_config_versions(stsb_bert_tiny_model: SentenceTransformer) ->
SentenceTransformer("sentence-transformers/average_word_embeddings_levy_dependency"),
],
)
def test_safetensors(modules: list[nn.Module] | SentenceTransformer) -> None:
if isinstance(modules, SentenceTransformer):
model = modules
@@ -679,3 +681,47 @@ def test_empty_encode(stsb_bert_tiny_model: SentenceTransformer) -> None:
model = stsb_bert_tiny_model
embeddings = model.encode([])
assert embeddings.shape == (0,)


@pytest.mark.skipif(
find_spec("openvino") is None or find_spec("optimum.intel") is None,
reason="optimum-intel and openvino must be installed for OpenVINO test",
)
def test_openvino_backend() -> None:
model_id = "sentence-transformers-testing/stsb-bert-tiny-safetensors"
# Test that OpenVINO output is close to PyTorch output
pytorch_model = SentenceTransformer(model_id)
openvino_model = SentenceTransformer(
model_id,
backend="openvino",
model_kwargs={"ov_config": {"INFERENCE_PRECISION_HINT": "f32"}},
)
pytorch_result = pytorch_model.encode(["Hello there!"])
openvino_result = openvino_model.encode(["Hello there!"])
assert np.allclose(openvino_result, pytorch_result, atol=1e-6), "OpenVINO and PyTorch outputs are not close"

with tempfile.TemporaryDirectory() as tmpdirname:
# Test that loading with ov_config file works as expected
config_file = str(Path(tmpdirname) / "ov_config.json")
with open(config_file, "w") as f:
f.write('{"NUM_STREAMS" : "2"}')
openvino_model_with_config = SentenceTransformer(
model_id,
backend="openvino",
model_kwargs={"ov_config": config_file},
)
# The Transformer module's auto_model is now an Optimum OpenVINO model whose inference request exposes the configured properties
transformers_model = next(
module for module in openvino_model_with_config.modules() if isinstance(module, Transformer)
)
assert transformers_model.auto_model.request.get_property("NUM_STREAMS") == 2

# Test that saving and loading local OpenVINO models works as expected
openvino_model_with_config.save(tmpdirname)
local_openvino_model = SentenceTransformer(
tmpdirname, backend="openvino", model_kwargs={"ov_config": {"INFERENCE_PRECISION_HINT": "f32"}}
)
local_openvino_result = local_openvino_model.encode(["Hello there!"])
assert np.allclose(
local_openvino_result, openvino_result
), "OpenVINO saved model output differs from in-memory converted model"
