Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add code import to train/eval scripts #1002

Merged
merged 13 commits into from
Mar 11, 2024
37 changes: 37 additions & 0 deletions llmfoundry/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import importlib.util
import os
from pathlib import Path
from types import ModuleType
from typing import Union

__all__ = ['import_file']


def import_file(loc: Union[str, Path]) -> ModuleType:
    """Import a module from a file. Used to run arbitrary python code.

    The file is executed under the fixed module name ``python_code`` and the
    resulting module object is returned. This function does not register the
    module in ``sys.modules``.

    Args:
        loc (str / Path): Path to the file.

    Returns:
        ModuleType: The module object.

    Raises:
        FileNotFoundError: If ``loc`` does not exist.
        ImportError: If no import spec or loader can be created for ``loc``
            (for example, the file does not have a recognized Python suffix).
        RuntimeError: If executing the file raises an exception. The original
            exception is chained as the cause.
    """
    if not os.path.exists(loc):
        raise FileNotFoundError(f'File {loc} does not exist.')

    spec = importlib.util.spec_from_file_location('python_code', str(loc))

    # Raise explicitly instead of asserting: asserts are stripped under
    # ``python -O`` and would otherwise surface as an opaque AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f'Could not create an import spec for {loc}.')

    module = importlib.util.module_from_spec(spec)

    try:
        spec.loader.exec_module(module)
    except Exception as e:
        # Wrap any failure from the user code so the offending path is visible.
        raise RuntimeError(f'Error executing {loc}') from e
    return module
11 changes: 11 additions & 0 deletions scripts/eval/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

install()
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
from llmfoundry.registry import import_file
from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
build_evaluators, build_logger,
build_tokenizer)
Expand Down Expand Up @@ -188,6 +189,16 @@ def evaluate_model(


def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
# Run user provided code if specified
code_paths = pop_config(cfg,
'code_paths',
must_exist=False,
default_value=[],
convert=True)
# Import any user provided code
for code_path in code_paths:
import_file(code_path)

om.resolve(cfg)

# Create copy of config for logging
Expand Down
11 changes: 11 additions & 0 deletions scripts/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from llmfoundry import COMPOSER_MODEL_REGISTRY
from llmfoundry.callbacks import AsyncEval
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.registry import import_file
from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
build_algorithm, build_callback,
build_evaluators, build_logger,
Expand Down Expand Up @@ -158,6 +159,16 @@ def main(cfg: DictConfig) -> Trainer:
'torch.distributed.*_base is a private function and will be deprecated.*'
)

# Run user provided code if specified
code_paths = pop_config(cfg,
'code_paths',
must_exist=False,
default_value=[],
convert=True)
# Import any user provided code
for code_path in code_paths:
import_file(code_path)

# Check for incompatibilities between the model and data loaders
validate_config(cfg)

Expand Down
44 changes: 44 additions & 0 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import os
import pathlib

import pytest

from llmfoundry.registry import import_file


def test_registry_init_code(tmp_path: pathlib.Path):
    """Check that ``import_file`` executes the code in the given file.

    The imported snippet sets an environment variable; observing that variable
    afterwards shows the file was actually run.
    """
    env_key = 'TEST_ENVIRON_REGISTRY_KEY'
    register_code = """
import os
os.environ['TEST_ENVIRON_REGISTRY_KEY'] = 'test'
"""

    code_file = tmp_path / 'init_code.py'
    with open(code_file, 'w') as _f:
        _f.write(register_code)

    # Clear any stale value first so it cannot make the assertion pass
    # without the file having been executed.
    os.environ.pop(env_key, None)
    try:
        import_file(code_file)

        assert os.environ[env_key] == 'test'
    finally:
        # Always restore the environment, even when the assertion fails.
        os.environ.pop(env_key, None)


def test_registry_init_code_fails(tmp_path: pathlib.Path):
    """Check that errors raised by the imported file surface as RuntimeError.

    The snippet references an undefined name (``asdf``) after setting an
    environment variable, so executing it fails part-way through.
    """
    env_key = 'TEST_ENVIRON_REGISTRY_KEY'
    register_code = """
import os
os.environ['TEST_ENVIRON_REGISTRY_KEY'] = 'test'
asdf
"""

    code_file = tmp_path / 'init_code.py'
    with open(code_file, 'w') as _f:
        _f.write(register_code)

    try:
        with pytest.raises(RuntimeError,
                           match='Error executing .*init_code.py'):
            import_file(code_file)
    finally:
        # The snippet sets the variable before it fails; remove it so the
        # value does not leak into other tests.
        os.environ.pop(env_key, None)


def test_registry_init_code_dne(tmp_path: pathlib.Path):
    """Check that importing a path that was never created raises FileNotFoundError."""
    # Nothing is written to this path, so it does not exist.
    missing_file = tmp_path / 'init_code.py'
    expected_error = pytest.raises(FileNotFoundError,
                                   match='File .* does not exist')
    with expected_error:
        import_file(missing_file)
Loading