Skip to content

Commit

Permalink
Call generate callback at end of training (mosaicml#2607)
Browse files Browse the repository at this point in the history
* Call generate callback at end of training

* update, add tests

* foo

* test update

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
aspfohl and dakinggg committed Oct 5, 2023
1 parent de760ed commit 4934aa5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
8 changes: 6 additions & 2 deletions composer/callbacks/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,17 @@ def __init__(self,
self.prompts = prompts
self.generate_kwargs = kwargs
self.batch_size = batch_size if batch_size is not None else len(prompts)
self.check_interval = create_interval_scheduler(interval, include_end_of_training=False)
self.check_interval = create_interval_scheduler(interval, include_end_of_training=True)
self.last_generate_batch: Optional[Time] = None

def run_event(self, event: Event, state: State, logger: Logger) -> None:
    """Run generation when the interval scheduler fires, at most once per batch.

    Three conditions gate the call:
      1. ``state.get_elapsed_duration() is not None`` — training has started,
         so the scheduler has a duration to compare against.
      2. ``self.check_interval(state, event)`` — the configured interval
         (which now includes end of training) matches this event.
      3. ``self.last_generate_batch != state.timestamp.batch`` — we have not
         already generated for this batch. This guard prevents a duplicate
         generation when a scheduled interval and the end-of-training event
         both land on the same batch.
    """
    # NOTE(review): the scraped diff showed both the pre- and post-commit
    # condition; this is the committed (post-change) version only.
    if state.get_elapsed_duration() is not None and self.check_interval(
            state, event) and self.last_generate_batch != state.timestamp.batch:
        self.generate(state, logger)

def generate(self, state: State, logger: Logger):
self.last_generate_batch = state.timestamp.batch

model = state.model.module if state.is_model_ddp else state.model
if not isinstance(model, HuggingFaceModel): # TODO: Extend to support any models that have a generate method.
raise ValueError(f'Expected HuggingFaceModel, but got {model.__class__.__name__}')
Expand Down
31 changes: 31 additions & 0 deletions tests/callbacks/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from packaging import version

from composer.callbacks import Generate
from composer.core import Event
from composer.trainer import Trainer
from composer.utils import dist
from tests.common.datasets import dummy_gpt_lm_dataloader
Expand Down Expand Up @@ -111,3 +112,33 @@ def test_calls(self, device, world_size, use_fsdp):
assert trainer.logger.log_table.call_count == expected_cb_call_count
else:
trainer.logger.log_table.assert_not_called()

def test_calls_end_of_training(self, device, world_size, use_fsdp):
    """End-of-training should trigger exactly one generate call, and a
    repeated FIT_END for the same batch must not trigger another."""
    self._check_test_params(device, world_size, use_fsdp)

    # The interval (2 batches) deliberately exceeds max_duration (1 batch),
    # so only the end-of-training trigger can possibly fire.
    gen_interval = 2
    train_batches = 1
    prompts = ['a', 'bc', 'defg']
    prompt_batch_size = 2
    generate_cb = Generate(prompts, interval=f'{gen_interval}ba', batch_size=prompt_batch_size, max_new_tokens=5)

    trainer = self._create_trainer(device, f'{train_batches}ba', use_fsdp, generate_cb)

    # Wrap the calls we want to observe; stub out table logging entirely.
    state = trainer.state
    model = state.model if not state.is_model_ddp else state.model.module
    model.generate = Mock(wraps=model.generate)  # type: ignore
    generate_cb.generate = Mock(wraps=generate_cb.generate)
    trainer.logger.log_table = Mock()

    trainer.fit()

    # Generation fired exactly once — at the end of training.
    assert generate_cb.generate.call_count == 1

    # Replaying FIT_END on the same batch must be a no-op for generation.
    trainer.engine.run_event(Event.FIT_END)
    assert generate_cb.generate.call_count == 1

0 comments on commit 4934aa5

Please sign in to comment.