Skip to content

Commit

Permalink
Add default signature to mlflow saved model (#952)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Feb 7, 2024
1 parent 105f766 commit 60ab97f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 7 deletions.
30 changes: 29 additions & 1 deletion llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ class HuggingFaceCheckpointer(Callback):
that will get passed along to the MLflow ``save_model`` call.
Expected to contain ``metadata`` and ``task`` keys. If either is
unspecified, the defaults are ``'text-generation'`` and
``{'task': 'llm/v1/completions'}`` respectively.
``{'task': 'llm/v1/completions'}`` respectively. A default input example
and signature intended for text generation is also included under the
keys ``input_example`` and ``signature``.
flatten_imports (Sequence[str]): A sequence of import prefixes that will
be flattened when editing MPT files.
"""
Expand Down Expand Up @@ -126,6 +128,10 @@ def __init__(
if mlflow_logging_config is None:
mlflow_logging_config = {}
if self.mlflow_registered_model_name is not None:
import numpy as np
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import ColSpec, Schema

# Both the metadata and the task are needed in order for mlflow
# and databricks optimized model serving to work
default_metadata = {'task': 'llm/v1/completions'}
Expand All @@ -135,6 +141,28 @@ def __init__(
**passed_metadata
}
mlflow_logging_config.setdefault('task', 'text-generation')

# Define a default input/output that is good for standard text generation LMs
input_schema = Schema([
ColSpec('string', 'prompt'),
ColSpec('double', 'temperature', optional=True),
ColSpec('integer', 'max_tokens', optional=True),
ColSpec('string', 'stop', optional=True),
ColSpec('integer', 'candidate_count', optional=True)
])

output_schema = Schema([ColSpec('string', 'predictions')])

default_signature = ModelSignature(inputs=input_schema,
outputs=output_schema)

default_input_example = {
'prompt': np.array(['What is Machine Learning?'])
}
mlflow_logging_config.setdefault('input_example',
default_input_example)
mlflow_logging_config.setdefault('signature', default_signature)

self.mlflow_logging_config = mlflow_logging_config

self.huggingface_folder_name_fstr = os.path.join(
Expand Down
37 changes: 31 additions & 6 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,12 +253,12 @@ def test_callback_inits():
save_folder='test',
save_interval='1ba',
mlflow_registered_model_name='test_model_name')
assert hf_checkpointer.mlflow_logging_config == {
'task': 'text-generation',
'metadata': {
'task': 'llm/v1/completions'
}
}

assert hf_checkpointer.mlflow_logging_config['task'] == 'text-generation'
assert hf_checkpointer.mlflow_logging_config['metadata'][
'task'] == 'llm/v1/completions'
assert 'input_example' in hf_checkpointer.mlflow_logging_config
assert 'signature' in hf_checkpointer.mlflow_logging_config


@pytest.mark.gpu
Expand Down Expand Up @@ -331,6 +331,8 @@ def test_huggingface_conversion_callback_interval(
transformers_model=ANY,
path=ANY,
task='text-generation',
input_example=ANY,
signature=ANY,
metadata={'task': 'llm/v1/completions'})
assert mlflow_logger_mock.register_model.call_count == 1
else:
Expand Down Expand Up @@ -593,11 +595,34 @@ def test_huggingface_conversion_callback(
}
}
else:
import numpy as np
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import ColSpec, Schema

input_schema = Schema([
ColSpec('string', 'prompt'),
ColSpec('double', 'temperature', optional=True),
ColSpec('integer', 'max_tokens', optional=True),
ColSpec('string', 'stop', optional=True),
ColSpec('integer', 'candidate_count', optional=True)
])

output_schema = Schema([ColSpec('string', 'predictions')])

default_signature = ModelSignature(inputs=input_schema,
outputs=output_schema)

default_input_example = {
'prompt': np.array(['What is Machine Learning?'])
}

expectation = {
'flavor': 'transformers',
'transformers_model': ANY,
'path': ANY,
'task': 'text-generation',
'signature': default_signature,
'input_example': default_input_example,
'metadata': {
'task': 'llm/v1/completions'
}
Expand Down

0 comments on commit 60ab97f

Please sign in to comment.