Skip to content

Commit

Permalink
min/max flagging added to system_metrics_monitor with only non-redund…
Browse files Browse the repository at this point in the history
…ant, necessary gpu metrics logged (#3373)

* implemented min_max flag

* fixed string parsing

* refactoring compute_system_metrics for all_reduce

* keep track of rank within dict

* added compute_min_max

* added flag for both min_max and all_logging

* corrected min_max call with model_device

* removing total bytes (always going ot be constant)

* handled no gpu case in min_max flag

* removed unnecessary imports, patched unit tests

* fixed assert statement for with gpu case, world size 1

* case min_rank and max_rank as int to guarantee them working as indices

* fixed indent issue from fixing font

* made docs more concise and readable

* fixing unexpected unindent

* fixing unit test device

* modifying device to equal model_device.type

* reverting to device=model_device

* setting device in unit test = 'gpu'

* setting device = 'cuda' in unit testing

* reverting to next(state.model.parameters()).device

* removed torch as a dependecy for unit_testing

* cleaned up UI to be consistent + removed calling next to obtain device

---------

Co-authored-by: Mihir Patel <[email protected]>
Co-authored-by: Charles Tang <[email protected]>
  • Loading branch information
3 people committed Jun 17, 2024
1 parent 8119c14 commit cca51e2
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 16 deletions.
93 changes: 79 additions & 14 deletions composer/callbacks/system_metrics_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os

import psutil
import torch

from composer.core import Callback, Event, State
from composer.loggers import Logger
Expand All @@ -19,13 +20,52 @@

__all__ = ['SystemMetricsMonitor']

_GPU_METRICS = [
'gpu_percentage',
'memory_percentage',
'gpu_temperature_C',
'gpu_power_usage_W',
]


class SystemMetricsMonitor(Callback):
"""Track system metrics."""
"""Logs GPU/CPU metrics.
GPU Metrics:
gpu_percentage: Occupancy rate, percent of time over sampling period during which one or more kernels was executing on the GPU.
memory_percentage: Percent of time over sampling period during which global memory was being read or written.
gpu_temperature_C: Temperature of device, in Celcius.
gpu_power_usage_W: Power usage of device, in Watts.
By default, only the maximum and minimum values for these metrics, alongside their respective ranks in the key names,
are logged on the :attr:`.Event.BATCH_START`, :attr:`.Event.EVAL_BATCH_START`, :attr:`.Event.PREDICT_BATCH_START`
events for every batch. If log_all_data is set to True, all values for these metrics across all ranks are logged on the
above events for every batch.
Example:
.. doctest::
def __init__(self, gpu_available: bool = False) -> None:
>>> from composer import Trainer
>>> from composer.callbacks import SystemMetricsMonitor
>>> # constructing trainer object with this callback
>>> trainer = Trainer(
... model=model,
... train_dataloader=train_dataloader,
... eval_dataloader=eval_dataloader,
... optimizers=optimizer,
... max_duration='1ep',
... callbacks=[SystemMetricsMonitor()],
... )
Args:
log_all_data (bool, optional): True if user wants to log data for all ranks, not just the min/max.
Defaults to False.
"""

def __init__(self, log_all_data: bool = False) -> None:
super().__init__()
self.gpu_available = gpu_available
self.gpu_available = torch.cuda.is_available()
self.log_all_data = log_all_data
if self.gpu_available:
try:
import pynvml
Expand All @@ -46,9 +86,23 @@ def run_event(self, event: Event, state: State, logger: Logger):
]:
local_node_system_metrics = self.compute_system_metrics()
all_system_metrics = dist.all_gather_object(local_node_system_metrics)
system_metrics = {
key: value for local_metrics in all_system_metrics for key, value in local_metrics.items()
}
system_metrics = {}

if self.log_all_data:
for rank, metrics in enumerate(all_system_metrics):
for key, value in metrics.items():
if key in _GPU_METRICS:
system_metrics[f'{key}_rank_{rank}'] = value
else:
system_metrics[key] = value

else:
system_metrics = self.compute_gpu_min_max_metrics(all_system_metrics, state)
for rank, metrics in enumerate(all_system_metrics):
for key, value in metrics.items():
if key not in _GPU_METRICS:
system_metrics[key] = value

logger.log_metrics(system_metrics)

def compute_system_metrics(self):
Expand All @@ -58,17 +112,14 @@ def compute_system_metrics(self):
if self.gpu_available:
import pynvml
local_rank = dist.get_local_rank()
global_rank = dist.get_global_rank()
handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank)
memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
system_metrics[f'device{global_rank}_memory_total'] = memory.total
system_metrics[f'device{global_rank}_memory_free'] = memory.free
system_metrics[f'device{global_rank}_memory_used'] = memory.used
device_utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
system_metrics[f'device{global_rank}_gpu_percentage'] = device_utilization.gpu
system_metrics[f'device{global_rank}_memory_percentage'] = device_utilization.memory
system_metrics['gpu_percentage'] = device_utilization.gpu
system_metrics['memory_percentage'] = device_utilization.memory
temperature = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
system_metrics[f'device{global_rank}_gpu_temperature'] = temperature
system_metrics['gpu_temperature_C'] = temperature
power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0 # convert from mW to W
system_metrics['gpu_power_usage_W'] = power

# Get metrics for the system
cpu_percent = psutil.cpu_percent()
Expand All @@ -83,3 +134,17 @@ def compute_system_metrics(self):
for k, v in network_usage.items():
system_metrics[f'network_{k}'] = v
return system_metrics

def compute_gpu_min_max_metrics(self, all_metrics, state):
min_max_metrics = {}

if self.gpu_available:
for key in _GPU_METRICS:
values = torch.tensor([metrics_for_cur_rank[key] for metrics_for_cur_rank in all_metrics])
values = state.device.tensor_to_device(values)
min_rank = int(torch.argmin(values).item())
max_rank = int(torch.argmax(values).item())
min_max_metrics[f'min_{key}_rank_{min_rank}'] = values[min_rank].item()
min_max_metrics[f'max_{key}_rank_{max_rank}'] = values[max_rank].item()

return min_max_metrics
4 changes: 2 additions & 2 deletions tests/callbacks/test_system_metrics_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@pytest.mark.gpu
def test_system_metrics_monitor_gpu():
# Construct the trainer
system_metrics_monitor = SystemMetricsMonitor(gpu_available=True)
system_metrics_monitor = SystemMetricsMonitor()
in_memory_logger = InMemoryLogger()
trainer = Trainer(
model=SimpleModel(),
Expand All @@ -24,7 +24,7 @@ def test_system_metrics_monitor_gpu():
)
trainer.fit()

assert 'device0_gpu_percentage' in in_memory_logger.data
assert 'min_gpu_percentage_rank_0' in in_memory_logger.data
assert 'cpu_percentage' in in_memory_logger.data


Expand Down

0 comments on commit cca51e2

Please sign in to comment.