diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 9f799cb547..b4d4143006 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -12,8 +12,9 @@ import torch from composer import algorithms from composer.callbacks import (EarlyStopper, Generate, LRMonitor, - MemoryMonitor, MemorySnapshot, OptimizerMonitor, - RuntimeEstimator, SpeedMonitor) + MemoryMonitor, MemorySnapshot, OOMObserver, + OptimizerMonitor, RuntimeEstimator, + SpeedMonitor) from composer.core import Algorithm, Callback, Evaluator from composer.datasets.in_context_learning_evaluation import \ get_icl_task_dataloader @@ -165,6 +166,8 @@ def build_callback( return LRMonitor() elif name == 'memory_monitor': return MemoryMonitor() + elif name == 'oom_observer': + return OOMObserver(**kwargs) elif name == 'memory_snapshot': return MemorySnapshot(**kwargs) elif name == 'speed_monitor':