Skip to content

Commit

Permalink
improve error when architecture deps are missing
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Oct 3, 2024
1 parent 7b68e6c commit 8917a25
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 9 deletions.
6 changes: 2 additions & 4 deletions src/metatrain/cli/export.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import argparse
import importlib
import logging
from pathlib import Path
from typing import Any, Union

import torch

from ..utils.architectures import check_architecture_name, find_all_architectures
from ..utils.architectures import find_all_architectures, import_architecture
from ..utils.export import is_exported
from ..utils.io import check_file_extension
from .formatter import CustomHelpFormatter
Expand Down Expand Up @@ -57,8 +56,7 @@ def _add_export_model_parser(subparser: argparse._SubParsersAction) -> None:
def _prepare_export_model_args(args: argparse.Namespace) -> None:
"""Prepare arguments for export_model."""
architecture_name = args.__dict__.pop("architecture_name")
check_architecture_name(architecture_name)
architecture = importlib.import_module(f"metatrain.{architecture_name}")
architecture = import_architecture(architecture_name)

args.model = architecture.__model__.load_checkpoint(args.__dict__.pop("path"))

Expand Down
9 changes: 6 additions & 3 deletions src/metatrain/cli/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import importlib
import itertools
import json
import logging
Expand All @@ -14,7 +13,11 @@
from omegaconf import DictConfig, OmegaConf

from .. import PACKAGE_ROOT
from ..utils.architectures import check_architecture_options, get_default_hypers
from ..utils.architectures import (
check_architecture_options,
get_default_hypers,
import_architecture,
)
from ..utils.data import (
DatasetInfo,
TargetInfoDict,
Expand Down Expand Up @@ -135,7 +138,7 @@ def train_model(
check_architecture_options(
name=architecture_name, options=OmegaConf.to_container(options["architecture"])
)
architecture = importlib.import_module(f"metatrain.{architecture_name}")
architecture = import_architecture(architecture_name)

logger.info(f"Running training for {architecture_name!r} architecture")

Expand Down
26 changes: 26 additions & 0 deletions src/metatrain/utils/architectures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import difflib
import importlib
import json
import logging
from ast import Import
from importlib.util import find_spec
from pathlib import Path
from typing import Dict, List, Union
Expand Down Expand Up @@ -110,6 +112,30 @@ def get_architecture_name(path: Union[str, Path]) -> str:
return name


def import_architecture(name: str):
"""Import an architecture.
:param name: name of the architecture
:raises ImportError: if the architecture dependencies are not met
"""
check_architecture_name(name)
try:
return importlib.import_module(f"metatrain.{name}")
except ImportError as err:
# consistent name with pyproject.toml's `optional-dependencies` section
name_for_deps = name
if "experimental." in name or "deprecated." in name:
name_for_deps = ".".join(name.split(".")[1:])

name_for_deps = name_for_deps.replace("_", "-")

raise ImportError(
f"Trying to import '{name}' but architecture dependencies "
f"seem not be installed. \n"
f"Try to install them with `pip install .[{name_for_deps}]`"
) from err


def get_architecture_path(name: str) -> Path:
"""Return the relative path to the architeture directory.
Expand Down
4 changes: 2 additions & 2 deletions src/metatrain/utils/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
from omegaconf.basecontainer import BaseContainer

from .. import PACKAGE_ROOT, RANDOM_SEED
from .architectures import import_architecture
from .devices import pick_devices
from .jsonschema import validate


def _get_architecture_model(conf: BaseContainer) -> Any:
architecture_name = conf["architecture"]["name"]
architecture = importlib.import_module(f"metatrain.{architecture_name}")
architecture = import_architecture(conf["architecture"]["name"])
return architecture.__model__


Expand Down
30 changes: 30 additions & 0 deletions tests/utils/test_architectures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
from pathlib import Path

import pytest
Expand All @@ -11,9 +12,14 @@
get_architecture_name,
get_architecture_path,
get_default_hypers,
import_architecture,
)


def is_None(*args, **kwargs) -> None:
return None


def test_find_all_architectures():
all_arches = find_all_architectures()
assert len(all_arches) == 4
Expand Down Expand Up @@ -116,3 +122,27 @@ def test_check_architecture_options_error_raise():
match = r"Unrecognized options \('num_epochxxx' was unexpected\)"
with pytest.raises(ValidationError, match=match):
check_architecture_options(name=name, options=options)


def test_import_architecture():
name = "experimental.soap_bpnn"
architecture_ref = importlib.import_module(f"metatrain.{name}")
assert import_architecture(name) == architecture_ref


def test_import_architecture_erro(monkeypatch):
# `check_architecture_name` is called inside `import_architecture` and we have to
# disble the check to allow passing our "unknown" fancy-model below.
monkeypatch.setattr(
"metatrain.utils.architectures.check_architecture_name", is_None
)

name = "experimental.fancy_model"
name_for_deps = "fancy-model"

match = (
rf"Trying to import '{name}' but architecture dependencies seem not be "
rf"installed. \nTry to install them with `pip install .\[{name_for_deps}\]`"
)
with pytest.raises(ImportError, match=match):
import_architecture(name)

0 comments on commit 8917a25

Please sign in to comment.