Skip to content

Commit

Permalink
misc fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
aspfohl committed Dec 6, 2023
1 parent bb040d1 commit 14f386f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 17 deletions.
60 changes: 44 additions & 16 deletions llmfoundry/callbacks/async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import logging
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

from composer.callbacks import CheckpointSaver
Expand Down Expand Up @@ -94,7 +95,8 @@ def get_latest_checkpoint(event: Event, state: State) -> Optional[str]:
log.warning('No saved checkpoints found on the checkpointer')
return None

return checkpointer.saved_checkpoints[-1]
latest = checkpointer.saved_checkpoints[-1]
return str(Path(latest).parts[-1])


def get_eval_parameters(
Expand Down Expand Up @@ -153,6 +155,35 @@ def get_eval_parameters(
return subset_keys


def validate_interval(interval: Union[str, int, Time],
save_interval: Union[str, int, Time]) -> Time:
if isinstance(save_interval, str):
new_save_interval: Time = Time.from_timestring(save_interval)
elif isinstance(save_interval, int):
new_save_interval: Time = Time(save_interval, TimeUnit.EPOCH)
else:
new_save_interval: Time = save_interval

if isinstance(interval, str):
result: Time = Time.from_timestring(interval)
elif isinstance(interval, int):
result: Time = Time(interval, TimeUnit.EPOCH)
else:
result: Time = interval

if new_save_interval.unit != result.unit:
raise ValueError(
'Save interval and async eval interval must be in the same unit')
if result < new_save_interval:
raise ValueError(
'Async eval interval must be equal or greater (less frequent) than save interval'
)
if result.value % new_save_interval.value != 0:
raise ValueError(
'Async eval interval must be a multiple of save interval')
return result


class AsyncEval(Callback):
"""Run the eval loop asynchronously as part of a MosaicML platform run.
Expand All @@ -176,15 +207,14 @@ def __init__(
compute: Optional[Union[ComputeConfig, Dict[str, Any]]] = None,
):

self.training_config = training_config

if isinstance(interval, str):
self.interval = Time.from_timestring(interval)
elif isinstance(interval, int):
self.interval = Time(interval, TimeUnit.EPOCH)
else:
self.interval = interval
for required in ('save_interval', 'save_folder'):
if required not in training_config:
raise ValueError(f'{required} required for async eval')

self.checkpoint_save_folder = training_config['save_folder']
self.training_config = training_config
self.interval = validate_interval(interval,
self.training_config['save_interval'])
self.check_interval = create_interval_scheduler(
interval,
# There is a custom close to ensure that the final checkpoint
Expand Down Expand Up @@ -220,34 +250,32 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
if not checkpoint:
return # warnings logged in get_latest_checkpoint

if checkpoint == self.last_checkpoint:
full_checkpoint = f'{self.checkpoint_save_folder}/{checkpoint}'
if full_checkpoint == self.last_checkpoint:
# Do not eval a checkpoint that has already been evaluated.
log.info(
'Skipping async eval because the checkpoint has not changed'
)
return

self.launch_run(checkpoint, current_interval)
self.last_checkpoint = checkpoint
self.launch_run(full_checkpoint, current_interval)
self.last_checkpoint = full_checkpoint

def close(self, state: State, logger: Logger) -> None:
del state
del logger

if dist.get_global_rank() != 0:
return
self.training_config

# TODO: enforce this exists before
save_folder = self.training_config['save_folder']
save_latest_filename = self.training_config.get('save_latest_filename',
None)

if not save_latest_filename:
rank = dist.get_global_rank()
save_latest_filename = f'latest-rank{rank}.pt'

checkpoint = f'{save_folder}/{save_latest_filename}'
checkpoint = f'{self.checkpoint_save_folder}/{save_latest_filename}'
self.launch_run(checkpoint, 'final')

def _get_current_run(self) -> Run:
Expand Down
19 changes: 18 additions & 1 deletion tests/callbacks/test_async_eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from unittest.mock import MagicMock, patch

import pytest
from composer.core import Time, TimeUnit

from llmfoundry.callbacks.async_eval_callback import (AsyncEval,
get_eval_parameters,
get_run_name)
get_run_name,
validate_interval)
from mcli import Run, RunConfig, RunStatus

# here
Expand Down Expand Up @@ -164,6 +166,21 @@ def test_get_eval_parameters():
}


def test_validate_interval():
with pytest.raises(ValueError):
validate_interval('1ba', '1ep') # different units
with pytest.raises(ValueError):
validate_interval('1ba', '2ba') # checkpointing happens less often
with pytest.raises(ValueError):
validate_interval('3ba', '2ba') # not a multiple

assert validate_interval('2ba', '1ba') == Time(2, TimeUnit.BATCH)
two_epochs = Time(2, TimeUnit.EPOCH)
assert validate_interval(2, 2) == two_epochs
assert validate_interval(two_epochs, two_epochs) == two_epochs
assert validate_interval('2ep', two_epochs) == two_epochs


FAKE_RUN = Run(
run_uid='123',
name=RUN_NAME,
Expand Down

0 comments on commit 14f386f

Please sign in to comment.