diff --git a/composer/callbacks/generate.py b/composer/callbacks/generate.py index ef854e6c0d..b8ce1a1bf0 100644 --- a/composer/callbacks/generate.py +++ b/composer/callbacks/generate.py @@ -41,13 +41,17 @@ def __init__(self, self.prompts = prompts self.generate_kwargs = kwargs self.batch_size = batch_size if batch_size is not None else len(prompts) - self.check_interval = create_interval_scheduler(interval, include_end_of_training=False) + self.check_interval = create_interval_scheduler(interval, include_end_of_training=True) + self.last_generate_batch: Optional[Time] = None def run_event(self, event: Event, state: State, logger: Logger) -> None: - if state.get_elapsed_duration() is not None and self.check_interval(state, event): + if state.get_elapsed_duration() is not None and self.check_interval( + state, event) and self.last_generate_batch != state.timestamp.batch: self.generate(state, logger) def generate(self, state: State, logger: Logger): + self.last_generate_batch = state.timestamp.batch + model = state.model.module if state.is_model_ddp else state.model if not isinstance(model, HuggingFaceModel): # TODO: Extend to support any models that have a generate method. raise ValueError(f'Expected HuggingFaceModel, but got {model.__class__.__name__}') diff --git a/tests/callbacks/test_generate.py b/tests/callbacks/test_generate.py index 925d21bcc1..a848071dff 100644 --- a/tests/callbacks/test_generate.py +++ b/tests/callbacks/test_generate.py @@ -10,6 +10,7 @@ from packaging import version from composer.callbacks import Generate +from composer.core import Event from composer.trainer import Trainer from composer.utils import dist from tests.common.datasets import dummy_gpt_lm_dataloader @@ -111,3 +112,33 @@ def test_calls(self, device, world_size, use_fsdp): assert trainer.logger.log_table.call_count == expected_cb_call_count else: trainer.logger.log_table.assert_not_called() + + def test_calls_end_of_training(self, device, world_size, use_fsdp): + self._check_test_params(device, world_size, use_fsdp) + + prompts = ['a', 'bc', 'defg'] + prompt_batch_size = 2 + gen_interval = 2 + generate_cb = Generate(prompts, interval=f'{gen_interval}ba', batch_size=prompt_batch_size, max_new_tokens=5) + + # Create trainer with gen_interval > max_duration + train_batches = 1 + trainer = self._create_trainer(device, f'{train_batches}ba', use_fsdp, generate_cb) + + # Mock methods + state = trainer.state + model = state.model.module if state.is_model_ddp else state.model + model.generate = Mock(wraps=model.generate) # type: ignore + generate_cb.generate = Mock(wraps=generate_cb.generate) + trainer.logger.log_table = Mock() + + trainer.fit() + + expected_cb_call_count = 1 + + # Assert that the generate callback has been called ONLY once + assert generate_cb.generate.call_count == expected_cb_call_count + + # An additional fit call should not trigger additional calls to generate + trainer.engine.run_event(Event.FIT_END) + assert generate_cb.generate.call_count == expected_cb_call_count