diff --git a/llmfoundry/models/hf/model_wrapper.py b/llmfoundry/models/hf/model_wrapper.py index 5b07fea356..a6bc9722fc 100644 --- a/llmfoundry/models/hf/model_wrapper.py +++ b/llmfoundry/models/hf/model_wrapper.py @@ -58,9 +58,9 @@ def __init__(self, self.model_forward_args = inspect.getfullargspec( self.model.forward).args - # inspecting HuggingFace quantized model could not return args correctly + # inspect.getfullargspec may fail to return the forward args for quantized HuggingFace models, so fall back to inspect.signature if not self.model_forward_args: - self.model_forward_args = ['input_ids', 'attention_mask'] + self.model_forward_args = inspect.signature(model.forward).parameters.keys()