pre-commit
es94129 committed Aug 19, 2023
1 parent ed4fe1f commit 7ed0459
Showing 4 changed files with 10 additions and 6 deletions.
3 changes: 2 additions & 1 deletion llmfoundry/models/hf/model_wrapper.py
@@ -60,7 +60,8 @@ def __init__(self,
            self.model.forward).args
        # inspect.getfullargspec could not return args correctly for HuggingFace quantized models
        if not self.model_forward_args:
-            self.model_forward_args = inspect.signature(self.model.forward).parameters.keys()
+            self.model_forward_args = inspect.signature(
+                self.model.forward).parameters.keys()

        # Note: We need to add the FSDP related attributes to the model AFTER the super init,
        # so that the (possible) embedding resizing doesn't destroy them
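
For context on this hunk (not part of the commit): a minimal sketch of why the inspect.signature fallback helps. If a quantized Hugging Face model's forward has been replaced by a generic (*args, **kwargs) wrapper, inspect.getfullargspec reports an empty args list, while inspect.signature follows __wrapped__ and recovers the real parameter names. The wrapper below is purely illustrative.

import functools
import inspect


def forward(input_ids, attention_mask=None):
    # Stand-in for a HuggingFace model's forward method.
    return input_ids


@functools.wraps(forward)
def wrapped_forward(*args, **kwargs):
    # Stand-in for the kind of wrapper that quantization hooks may install around forward.
    return forward(*args, **kwargs)


# getfullargspec does not follow __wrapped__, so it sees no named args here.
print(inspect.getfullargspec(wrapped_forward).args)  # []
# signature follows __wrapped__ and recovers the underlying parameter names.
print(list(inspect.signature(wrapped_forward).parameters))  # ['input_ids', 'attention_mask']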
5 changes: 3 additions & 2 deletions scripts/eval/eval.py
@@ -122,8 +122,9 @@ def evaluate_model(model_cfg: DictConfig, dist_timeout: Union[float, int],
        model_gauntlet)  # type: ignore

    if fsdp_config and model_cfg.model.load_in_8bit:
-        raise ValueError("The FSDP config block is not supported when loading " +
-                         "Hugging Face models in 8bit.")
+        raise ValueError(
+            'The FSDP config block is not supported when loading ' +
+            'Hugging Face models in 8bit.')

    if hasattr(model_cfg.model, 'pretrained_lora_id_or_path'):
        composer_model = load_peft_model(model_cfg.model, tokenizer,
2 changes: 1 addition & 1 deletion scripts/eval/yamls/hf_8bit_eval.yaml
@@ -22,4 +22,4 @@ device_eval_batch_size: 4

# With load_in_8bit, do not specify fsdp_config

-icl_tasks: 'eval/yamls/tasks_light.yaml'
+icl_tasks: 'eval/yamls/tasks_light.yaml'
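
A minimal sketch (assumed config shape, not from the repo) of the constraint the two files above express: when load_in_8bit is set for a Hugging Face eval, the config should simply omit the fsdp_config block; otherwise evaluate_model raises the ValueError shown earlier.

from omegaconf import OmegaConf

# Hypothetical minimal eval config: 8-bit loading requested, no fsdp_config block.
cfg = OmegaConf.create({
    'model': {
        'name': 'hf_causal_lm',
        'load_in_8bit': True,
    },
})

fsdp_config = cfg.get('fsdp_config', None)
if fsdp_config and cfg.model.get('load_in_8bit', False):
    raise ValueError('The FSDP config block is not supported when loading '
                     'Hugging Face models in 8bit.')
print('ok: no fsdp_config present, so the 8-bit eval config passes the check')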
6 changes: 4 additions & 2 deletions scripts/train/train.py
@@ -90,9 +90,11 @@ def validate_config(cfg: DictConfig):
            'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.'
        )
        torch._dynamo.config.suppress_errors = True  # type: ignore

    if cfg.model.get('load_in_8bit', False):
-        raise ValueError("`load_in_8bit` is only supported for evaluation rather than training.")
+        raise ValueError(
+            '`load_in_8bit` is only supported for evaluation rather than training.'
+        )


def build_composer_model(model_cfg: DictConfig,
