fix many more test cases
milocress committed Apr 17, 2024
1 parent 0b5721e commit f770f60
Showing 4 changed files with 11 additions and 6 deletions.
6 changes: 4 additions & 2 deletions llmfoundry/models/hf/hf_causal_lm.py
@@ -10,6 +10,7 @@
 
 from composer.models.huggingface import peft_installed
 from composer.utils import dist
+from omegaconf import OmegaConf as om
 from transformers import (AutoConfig, AutoModelForCausalLM, PretrainedConfig,
                           PreTrainedModel, PreTrainedTokenizerBase)
 
@@ -61,7 +62,7 @@ def __init__(
         self,
         tokenizer: PreTrainedTokenizerBase,
         pretrained_model_name_or_path: str,
-        pretrained: bool,
+        pretrained: Optional[bool] = True,
         pretrained_lora_id_or_path: Optional[str] = None,
         trust_remote_code: bool = True,
         use_auth_token: bool = False,
@@ -77,7 +78,8 @@ def __init__(
 
         from llmfoundry.utils.builders import build_metric
 
-        config_overrides = config_overrides or {}
+        config_overrides = om.to_container(
+            config_overrides, resolve=True) if config_overrides else {}
         additional_train_metrics = additional_train_metrics or []
 
         pretrained_model_name_or_path = pretrained_model_name_or_path
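The key change above swaps the plain `config_overrides or {}` fallback for `om.to_container(config_overrides, resolve=True)`, which turns an OmegaConf `DictConfig` into an ordinary Python dict and evaluates any `${...}` interpolations before the overrides are applied. A minimal standalone sketch of that conversion (illustrative values, not taken from the commit):

```python
from omegaconf import OmegaConf

# Overrides as they might arrive from a YAML config, including an interpolation.
overrides = OmegaConf.create({
    'max_seq_len': 2048,
    'n_positions': '${max_seq_len}',
})

# resolve=True evaluates the interpolation and returns a plain dict,
# so downstream code can iterate it like any other mapping.
plain = OmegaConf.to_container(overrides, resolve=True)
print(type(plain))  # <class 'dict'>
print(plain)        # {'max_seq_len': 2048, 'n_positions': 2048}
```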
7 changes: 5 additions & 2 deletions llmfoundry/models/hf/hf_t5.py
@@ -8,6 +8,7 @@
 from typing import List, Mapping, Optional
 
 from composer.utils import dist
+from omegaconf import OmegaConf as om
 from transformers import (AutoConfig, PreTrainedTokenizerBase,
                           T5ForConditionalGeneration)
 
@@ -47,7 +48,7 @@ def __init__(
         self,
         tokenizer: PreTrainedTokenizerBase,
         pretrained_model_name_or_path: str,
-        pretrained: bool,
+        pretrained: Optional[bool] = True,
         trust_remote_code: bool = True,
         use_auth_token: bool = False,
         config_overrides: Optional[Mapping] = None,
@@ -57,14 +58,16 @@
     ):
         from llmfoundry.utils.builders import build_metric
 
+        config_overrides = om.to_container(config_overrides or {}, resolve=True)
+
         config = AutoConfig.from_pretrained(
             pretrained_model_name_or_path,
             trust_remote_code=trust_remote_code,
             use_auth_token=use_auth_token,
         )
 
         # set config overrides
-        for k, v in (config_overrides or {}):
+        for k, v in (config_overrides or {}).items():
             if not hasattr(config, k):
                 raise ValueError(
                     f'config does not have attribute "{k}" to override ({k}: {v}).'
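Besides the same `om.to_container` conversion, this hunk fixes the override loop itself: iterating a dict directly yields only its keys, so `for k, v in (config_overrides or {}):` would try to unpack each key string and fail, while `.items()` yields the `(key, value)` pairs the loop expects. A short illustration of the difference:

```python
overrides = {'max_seq_len': 2048}

# Iterating the dict yields keys, so tuple-unpacking each key string fails.
try:
    for k, v in overrides:
        pass
except ValueError as err:
    print(err)  # too many values to unpack (expected 2)

# .items() yields (key, value) pairs, which is what the override loop needs.
for k, v in overrides.items():
    print(k, v)  # max_seq_len 2048
```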
2 changes: 1 addition & 1 deletion tests/models/hf/test_hf_fsdp.py
@@ -21,7 +21,7 @@ def test_olmo_wraps():
 
     config = DictConfig(conf)
 
-    model = ComposerHFCausalLM(config.model, None)
+    model = ComposerHFCausalLM(**config.model, tokenizer=None)
 
     # check that all the modules we except are blocked from FSDP wrapping
     underlying_model = maybe_get_underlying_model(model.model)
2 changes: 1 addition & 1 deletion tests/models/hf/test_hf_t5.py
@@ -23,4 +23,4 @@ def test_experimental_hf_t5():
     tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base')
 
     with pytest.warns(ExperimentalWarning):
-        _ = ComposerHFT5(cfg, tokenizer)
+        _ = ComposerHFT5(**cfg, tokenizer=tokenizer)
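Both test updates follow from the constructors now taking explicit keyword arguments rather than a single config object: `**cfg` / `**config.model` unpacks the config mapping (OmegaConf's `DictConfig` supports the mapping protocol) into those keyword arguments, and the tokenizer is passed by name. A sketch of the calling pattern, using a made-up constructor rather than the real signatures:

```python
# Hypothetical stand-in for the updated constructors, for illustration only.
def build_model(pretrained_model_name_or_path: str,
                pretrained: bool = True,
                tokenizer=None):
    return pretrained_model_name_or_path, pretrained, tokenizer

cfg = {'pretrained_model_name_or_path': 'gpt2', 'pretrained': False}

# Old style passed the whole dict positionally as one config object;
# new style unpacks it so each key must match a keyword parameter.
model = build_model(**cfg, tokenizer=None)
```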
