diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py index f26eaad38..5ab7863db 100644 --- a/torchrec/metrics/metric_module.py +++ b/torchrec/metrics/metric_module.py @@ -463,6 +463,7 @@ def generate_metric_module( batch_size=batch_size, world_size=world_size, window_seconds=metrics_config.throughput_metric.window_size, + warmup_steps=metrics_config.throughput_metric.warmup_steps, ) else: throughput_metric = None diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py index f77565646..192db371b 100644 --- a/torchrec/metrics/metrics_config.py +++ b/torchrec/metrics/metrics_config.py @@ -80,6 +80,7 @@ class RecComputeMode(Enum): _DEFAULT_WINDOW_SIZE = 10_000_000 _DEFAULT_THROUGHPUT_WINDOW_SECONDS = 100 +_DEFAULT_THROUGHPUT_WARMUP_STEPS = 100 @dataclass @@ -113,6 +114,7 @@ class StateMetricEnum(StrValueMixin, Enum): @dataclass class ThroughputDef: window_size: int = _DEFAULT_THROUGHPUT_WINDOW_SECONDS + warmup_steps: int = _DEFAULT_THROUGHPUT_WARMUP_STEPS @dataclass