
Commit

updates
aspfohl committed Dec 11, 2023
1 parent 818f4ac commit 0e4f085
Showing 4 changed files with 23 additions and 28 deletions.
17 changes: 6 additions & 11 deletions llmfoundry/callbacks/async_eval_callback.py
@@ -103,15 +103,13 @@ def get_eval_parameters(
     parameters: Dict[str, Any],
     checkpoint: str,
     training_run_name: str,
-    interval: Time,
 ) -> Dict[str, Any]:
     """Get the parameters needed for the eval run.

     Args:
         parameters: The parameters from the training run
         checkpoint: The path to the latest checkpoint
         training_run_name: The name of the training run
-        interval: The current Time interval

     Returns:
         The parameters needed for the eval run as a dict
@@ -136,13 +134,6 @@ def get_eval_parameters(
         if logger == 'wandb':
             config['group'] = config.pop('name', training_run_name)

-            config['init_kwargs'] = config.pop('init_kwargs', {})
-            config['init_kwargs']['config'] = config['init_kwargs'].pop(
-                'config', {})
-            config['init_kwargs']['config']['eval_interval'] = interval.value
-            config['init_kwargs']['config'][
-                'eval_interval_units'] = interval.unit.value
-
     # mlflow currently does not support grouping, so this will just launch
     # a new mlflow run

@@ -240,7 +231,6 @@ def __init__(
             parameters=training_config,
             checkpoint='test',
             training_run_name=self.current_run.name,
-            interval=Time(0, self.interval.unit),
         )
         log.info(
             f'Initialized AsyncEval callback. Will generate runs at interval {interval}'
@@ -320,7 +310,6 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run:
             parameters=self.training_config,
             checkpoint=checkpoint,
             training_run_name=self.current_run.name,
-            interval=current_interval,
         )
         params['run_name'] = run_name

@@ -359,6 +348,12 @@ def launch_run(self, checkpoint: str, current_interval: Time) -> Run:
             'ssh_clone': False,
         })

+        # This will record the timestamp and make it available for grouping
+        # and plotting in wandb
+        metadata = cfg.metadata
+        metadata['eval_timestamp'] = current_interval.value
+        metadata['eval_timestamp_unit'] = current_interval.unit.value
+
         # TODO: This just runs an eval run, but we also want to attach the
         # deployment, which would require a hf conversion and parametrizing the
         # dependent_deployment in the run config
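
For context (not part of this commit): a minimal sketch of how an eval job might surface the new eval_timestamp metadata to wandb so eval runs can be grouped and plotted by training progress. The run_metadata dict and the tag_eval_run_with_timestamp helper are hypothetical stand-ins for however the eval run retrieves its platform run metadata; only wandb.config.update is an existing API, and it assumes wandb.init() has already been called.

import wandb  # real API; the helper and its input below are illustrative only


def tag_eval_run_with_timestamp(run_metadata: dict) -> None:
    """Copy the eval timestamp metadata into the active wandb run's config."""
    timestamp = run_metadata.get('eval_timestamp')  # e.g. 100
    unit = run_metadata.get('eval_timestamp_unit')  # e.g. 'ba'
    if timestamp is None or unit is None:
        return  # metadata not present (e.g. older runs): nothing to tag
    # Once these land in the config, eval runs can be grouped and plotted
    # against training progress in the wandb UI.
    wandb.config.update({
        'eval_timestamp': timestamp,
        'eval_timestamp_unit': unit,
    })
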
2 changes: 1 addition & 1 deletion llmfoundry/utils/builders.py
@@ -160,7 +160,7 @@ def build_icl_data_and_gauntlet(
 def build_callback(
     name: str,
     kwargs: Union[DictConfig, Dict[str, Any]],
-    config: Any = None,
+    config: Dict[str, Any] = None,
 ) -> Callback:
     if name == 'lr_monitor':
         return LRMonitor()
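
Illustration only (not part of the commit): with the annotation narrowed from Any to Dict[str, Any], the expectation is that callers hand build_callback the resolved training config as a plain dict, which callbacks such as AsyncEval can inspect. The callback name and kwargs below are placeholders rather than an exact reproduction of how train.py wires this up.

from typing import Any, Dict

from llmfoundry.utils.builders import build_callback

# Assume the OmegaConf training config has already been resolved to a dict.
training_config: Dict[str, Any] = {
    'run_name': 'foo_bar',
    'device_eval_batch_size': 2,
    'max_seq_len': 2048,
}

# Placeholder name/kwargs; the point is that `config` is now a typed dict.
callback = build_callback(
    name='async_eval',
    kwargs={'interval': '100ba'},
    config=training_config,
)
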
4 changes: 3 additions & 1 deletion scripts/train/train.py
@@ -531,8 +531,10 @@ def main(cfg: DictConfig) -> Trainer:
     ## Evaluation
     print('Building eval loader...')
     eval_icl_seq_len: int = icl_seq_len if icl_seq_len else max_seq_len
+    # TODO: evaluators should not be built at all if use_async_eval is True
+    # This will be fixed when eval_loader support is fully added to AsyncEval
     evaluators, _, eval_gauntlet_callback = build_evaluators(
-        eval_loader_config,  # TODO: async eval should not even call eval loader
+        eval_loader_config,
         icl_tasks_config if not use_async_eval else None,
         eval_gauntlet_config if not use_async_eval else None,
         tokenizer=tokenizer,
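
A sketch of where the new TODO points (this commit does not change this behavior): once eval_loader support lands in AsyncEval, the training script could skip building evaluators entirely when use_async_eval is set. Names are reused from the surrounding train.py context and the remaining build_evaluators arguments are elided.

if use_async_eval:
    # The async eval runs own evaluation; build nothing in the training job.
    evaluators, eval_gauntlet_callback = [], None
else:
    evaluators, _, eval_gauntlet_callback = build_evaluators(
        eval_loader_config,
        icl_tasks_config,
        eval_gauntlet_config,
        tokenizer=tokenizer,
        # ... remaining build_evaluators arguments unchanged
    )
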
28 changes: 13 additions & 15 deletions tests/callbacks/test_async_eval_callback.py
@@ -81,12 +81,10 @@ def test_get_eval_parameters():
     with pytest.raises(
             Exception,
             match='Missing the following required parameters for async eval:'):
-        get_eval_parameters({}, 'checkpoints/file', RUN_NAME,
-                            Time(0, TimeUnit.EPOCH))
+        get_eval_parameters({}, 'checkpoints/file', RUN_NAME)

     # minimal example
-    params = get_eval_parameters(BASIC_PARAMS, 'checkpoints/file', RUN_NAME,
-                                 Time(0, TimeUnit.EPOCH))
+    params = get_eval_parameters(BASIC_PARAMS, 'checkpoints/file', RUN_NAME)
     assert params == {
         'device_eval_batch_size':
             2,
@@ -126,9 +124,6 @@ def test_get_eval_parameters():
         'loggers': {
             'wandb': {
                 'init_kwargs': {
-                    'config': {
-                        'foo': 'bar'
-                    },
                     'fee': 'bee'
                 }
             }
@@ -141,7 +136,6 @@ def test_get_eval_parameters():
         },
         'checkpoints/file',
         RUN_NAME,
-        Time(0, TimeUnit.EPOCH),
     )
     assert params2 == {
         'device_eval_batch_size': 2,
@@ -172,11 +166,6 @@ def test_get_eval_parameters():
             'wandb': {
                 'group': 'foo_bar-1234',
                 'init_kwargs': {
-                    'config': {
-                        'eval_interval': 0,
-                        'eval_interval_units': 'ep',
-                        'foo': 'bar'
-                    },
                     'fee': 'bee'
                 },
             }
@@ -244,14 +233,23 @@ def test_async_eval_callback_minimal(mock_create_run: MagicMock,
     assert mock_get_run.call_count == 1
     assert mock_get_run.call_args[0][0] == RUN_NAME

-    callback.launch_run('checkpoint/path', Time(1, TimeUnit.BATCH))
+    launch_time = Time(1, TimeUnit.BATCH)
+    callback.launch_run('checkpoint/path', launch_time)
     assert mock_create_run.call_count == 1

     run_config_created = mock_create_run.call_args[0][0]
     assert run_config_created.name == 'eval-1ba-foo_bar'
     assert run_config_created.image == 'fake-image'

-    print(run_config_created)
+    metadata = run_config_created.metadata
+    assert 'eval_timestamp' in metadata
+    assert isinstance(metadata['eval_timestamp'], int)
+    assert metadata['eval_timestamp'] == launch_time.value
+
+    assert 'eval_timestamp_unit' in metadata
+    assert isinstance(metadata['eval_timestamp_unit'], str)
+    assert metadata['eval_timestamp_unit'] == launch_time.unit.value

     assert 'cd llm-foundry/scripts' in run_config_created.command

     integrations = run_config_created.integrations
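
For reference (not part of the test file): a quick illustration of the composer Time values the new assertions exercise.

from composer.core import Time, TimeUnit

launch_time = Time(1, TimeUnit.BATCH)
assert launch_time.value == 1          # integer count -> metadata['eval_timestamp']
assert launch_time.unit.value == 'ba'  # unit string   -> metadata['eval_timestamp_unit']

# The removed wandb init_kwargs path carried the same information as
# 'eval_interval' / 'eval_interval_units', e.g. 0 and 'ep' for Time(0, TimeUnit.EPOCH).
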
