Skip to content

Commit

Permalink
enable load local model in subfolder
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Oct 9, 2024
1 parent 82f2699 commit 9b6c221
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
6 changes: 2 additions & 4 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,13 +510,12 @@ def _from_pretrained(

if file_name is None:
if model_path.is_dir():
onnx_files = list(model_path.glob("*.onnx"))
onnx_files = list((model_path / subfolder).glob("*.onnx"))
else:
repo_files, _ = TasksManager.get_model_files(
model_id, revision=revision, cache_dir=cache_dir, token=token
)
repo_files = map(Path, repo_files)

pattern = "*.onnx" if subfolder == "" else f"{subfolder}/*.onnx"
onnx_files = [p for p in repo_files if p.match(pattern)]

Expand Down Expand Up @@ -983,10 +982,9 @@ def _cached_file(
token = use_auth_token

model_path = Path(model_path)

# locates a file in a local folder and repo, downloads and cache it if necessary.
if model_path.is_dir():
model_cache_path = model_path / file_name
model_cache_path = model_path / subfolder / file_name
preprocessors = maybe_load_preprocessors(model_path.as_posix())
else:
model_cache_path = hf_hub_download(
Expand Down
14 changes: 14 additions & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import requests
import timm
import torch
from huggingface_hub import HfApi
from huggingface_hub.constants import default_cache_path
from parameterized import parameterized
from PIL import Image
Expand Down Expand Up @@ -1263,6 +1264,19 @@ def test_trust_remote_code(self):
torch.allclose(pt_logits, ort_logits, atol=1e-4), f" Maxdiff: {torch.abs(pt_logits - ort_logits).max()}"
)

@parameterized.expand(("", "onnx"))
def test_loading_with_config_in_root(self, subfolder):
# config.json file in the root directory and not in the subfolder
model_id = "sentence-transformers-testing/stsb-bert-tiny-onnx"
# hub model
ORTModelForFeatureExtraction.from_pretrained(model_id, subfolder=subfolder, export=subfolder == "")
# local model
api = HfApi()
with tempfile.TemporaryDirectory() as tmpdirname:
local_dir = Path(tmpdirname) / "model"
api.snapshot_download(repo_id=model_id, local_dir=local_dir)
ORTModelForFeatureExtraction.from_pretrained(local_dir, subfolder=subfolder, export=subfolder == "")


class ORTModelForQuestionAnsweringIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES = [
Expand Down

0 comments on commit 9b6c221

Please sign in to comment.