Skip to content

Commit

Permalink
Remove the need for the config to be in the subfolder (#2044)
Browse files Browse the repository at this point in the history
* remove the need for the config to be in the subfolder

* fix

* fix for offline mode

* add log

* fix

* enable load local model in subfolder

* fix windows
  • Loading branch information
echarlaix authored Oct 10, 2024
1 parent 2c0476e commit 1b5a63d
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 18 deletions.
36 changes: 22 additions & 14 deletions optimum/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,27 +380,35 @@ def from_pretrained(
)
model_id, revision = model_id.split("@")

all_files, _ = TasksManager.get_model_files(
model_id,
subfolder=subfolder,
cache_dir=cache_dir,
revision=revision,
token=token,
)

config_folder = subfolder
if cls.config_name not in all_files:
logger.info(
f"{cls.config_name} not found in the specified subfolder {subfolder}. Using the top level {cls.config_name}."
)
config_folder = ""

library_name = TasksManager.infer_library_from_model(
model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
model_id, subfolder=config_folder, revision=revision, cache_dir=cache_dir, token=token
)

if library_name == "timm":
config = PretrainedConfig.from_pretrained(
model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
model_id, subfolder=config_folder, revision=revision, cache_dir=cache_dir, token=token
)

if config is None:
if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME:
if CONFIG_NAME in os.listdir(os.path.join(model_id, subfolder)):
config = AutoConfig.from_pretrained(
os.path.join(model_id, subfolder), trust_remote_code=trust_remote_code
)
elif CONFIG_NAME in os.listdir(model_id):
if os.path.isdir(os.path.join(model_id, config_folder)) and cls.config_name == CONFIG_NAME:
if CONFIG_NAME in os.listdir(os.path.join(model_id, config_folder)):
config = AutoConfig.from_pretrained(
os.path.join(model_id, CONFIG_NAME), trust_remote_code=trust_remote_code
)
logger.info(
f"config.json not found in the specified subfolder {subfolder}. Using the top level config.json."
os.path.join(model_id, config_folder), trust_remote_code=trust_remote_code
)
else:
raise OSError(f"config.json not found in {model_id} local folder")
Expand All @@ -411,7 +419,7 @@ def from_pretrained(
cache_dir=cache_dir,
token=token,
force_download=force_download,
subfolder=subfolder,
subfolder=config_folder,
trust_remote_code=trust_remote_code,
)
elif isinstance(config, (str, os.PathLike)):
Expand All @@ -421,7 +429,7 @@ def from_pretrained(
cache_dir=cache_dir,
token=token,
force_download=force_download,
subfolder=subfolder,
subfolder=config_folder,
trust_remote_code=trust_remote_code,
)

Expand Down
6 changes: 2 additions & 4 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,13 +510,12 @@ def _from_pretrained(

if file_name is None:
if model_path.is_dir():
onnx_files = list(model_path.glob("*.onnx"))
onnx_files = list((model_path / subfolder).glob("*.onnx"))
else:
repo_files, _ = TasksManager.get_model_files(
model_id, revision=revision, cache_dir=cache_dir, token=token
)
repo_files = map(Path, repo_files)

pattern = "*.onnx" if subfolder == "" else f"{subfolder}/*.onnx"
onnx_files = [p for p in repo_files if p.match(pattern)]

Expand Down Expand Up @@ -983,10 +982,9 @@ def _cached_file(
token = use_auth_token

model_path = Path(model_path)

# locates a file in a local folder and repo, downloads and cache it if necessary.
if model_path.is_dir():
model_cache_path = model_path / file_name
model_cache_path = model_path / subfolder / file_name
preprocessors = maybe_load_preprocessors(model_path.as_posix())
else:
model_cache_path = hf_hub_download(
Expand Down
15 changes: 15 additions & 0 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import requests
import timm
import torch
from huggingface_hub import HfApi
from huggingface_hub.constants import default_cache_path
from parameterized import parameterized
from PIL import Image
Expand Down Expand Up @@ -1263,6 +1264,20 @@ def test_trust_remote_code(self):
torch.allclose(pt_logits, ort_logits, atol=1e-4), f" Maxdiff: {torch.abs(pt_logits - ort_logits).max()}"
)

@parameterized.expand(("", "onnx"))
def test_loading_with_config_not_from_subfolder(self, subfolder):
# config.json file in the root directory and not in the subfolder
model_id = "sentence-transformers-testing/stsb-bert-tiny-onnx"
# hub model
ORTModelForFeatureExtraction.from_pretrained(model_id, subfolder=subfolder, export=subfolder == "")
# local model
api = HfApi()
with tempfile.TemporaryDirectory() as tmpdirname:
local_dir = Path(tmpdirname) / "model"
api.snapshot_download(repo_id=model_id, local_dir=local_dir)
ORTModelForFeatureExtraction.from_pretrained(local_dir, subfolder=subfolder, export=subfolder == "")
remove_directory(tmpdirname)


class ORTModelForQuestionAnsweringIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES = [
Expand Down

0 comments on commit 1b5a63d

Please sign in to comment.