Commit

Merge branch 'dev' into sv-sd

eracah authored Jun 17, 2024
2 parents aa05ce5 + 6023fe5 commit c4ef047
Showing 15 changed files with 338 additions and 45 deletions.
93 changes: 79 additions & 14 deletions composer/callbacks/system_metrics_monitor.py
@@ -9,6 +9,7 @@
import os

import psutil
import torch

from composer.core import Callback, Event, State
from composer.loggers import Logger
@@ -19,13 +20,52 @@

__all__ = ['SystemMetricsMonitor']

_GPU_METRICS = [
'gpu_percentage',
'memory_percentage',
'gpu_temperature_C',
'gpu_power_usage_W',
]


class SystemMetricsMonitor(Callback):
"""Track system metrics."""
"""Logs GPU/CPU metrics.
GPU Metrics:
gpu_percentage: Occupancy rate, percent of time over sampling period during which one or more kernels was executing on the GPU.
memory_percentage: Percent of time over sampling period during which global memory was being read or written.
gpu_temperature_C: Temperature of device, in Celsius.
gpu_power_usage_W: Power usage of device, in Watts.
By default, only the maximum and minimum values for these metrics, alongside their respective ranks in the key names,
are logged on the :attr:`.Event.BATCH_START`, :attr:`.Event.EVAL_BATCH_START`, :attr:`.Event.PREDICT_BATCH_START`
events for every batch. If log_all_data is set to True, all values for these metrics across all ranks are logged on the
above events for every batch.
Example:
.. doctest::
def __init__(self, gpu_available: bool = False) -> None:
>>> from composer import Trainer
>>> from composer.callbacks import SystemMetricsMonitor
>>> # constructing trainer object with this callback
>>> trainer = Trainer(
... model=model,
... train_dataloader=train_dataloader,
... eval_dataloader=eval_dataloader,
... optimizers=optimizer,
... max_duration='1ep',
... callbacks=[SystemMetricsMonitor()],
... )
Args:
log_all_data (bool, optional): If True, log these metrics for every rank rather than only the min/max across ranks.
Defaults to False.
"""

def __init__(self, log_all_data: bool = False) -> None:
super().__init__()
self.gpu_available = gpu_available
self.gpu_available = torch.cuda.is_available()
self.log_all_data = log_all_data
if self.gpu_available:
try:
import pynvml
@@ -46,9 +86,23 @@ def run_event(self, event: Event, state: State, logger: Logger):
]:
local_node_system_metrics = self.compute_system_metrics()
all_system_metrics = dist.all_gather_object(local_node_system_metrics)
system_metrics = {
key: value for local_metrics in all_system_metrics for key, value in local_metrics.items()
}
system_metrics = {}

if self.log_all_data:
for rank, metrics in enumerate(all_system_metrics):
for key, value in metrics.items():
if key in _GPU_METRICS:
system_metrics[f'{key}_rank_{rank}'] = value
else:
system_metrics[key] = value

else:
system_metrics = self.compute_gpu_min_max_metrics(all_system_metrics, state)
for rank, metrics in enumerate(all_system_metrics):
for key, value in metrics.items():
if key not in _GPU_METRICS:
system_metrics[key] = value

logger.log_metrics(system_metrics)

def compute_system_metrics(self):
@@ -58,17 +112,14 @@ def compute_system_metrics(self):
if self.gpu_available:
import pynvml
local_rank = dist.get_local_rank()
global_rank = dist.get_global_rank()
handle = pynvml.nvmlDeviceGetHandleByIndex(local_rank)
memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
system_metrics[f'device{global_rank}_memory_total'] = memory.total
system_metrics[f'device{global_rank}_memory_free'] = memory.free
system_metrics[f'device{global_rank}_memory_used'] = memory.used
device_utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
system_metrics[f'device{global_rank}_gpu_percentage'] = device_utilization.gpu
system_metrics[f'device{global_rank}_memory_percentage'] = device_utilization.memory
system_metrics['gpu_percentage'] = device_utilization.gpu
system_metrics['memory_percentage'] = device_utilization.memory
temperature = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
system_metrics[f'device{global_rank}_gpu_temperature'] = temperature
system_metrics['gpu_temperature_C'] = temperature
power = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0 # convert from mW to W
system_metrics['gpu_power_usage_W'] = power

# Get metrics for the system
cpu_percent = psutil.cpu_percent()
Expand All @@ -83,3 +134,17 @@ def compute_system_metrics(self):
for k, v in network_usage.items():
system_metrics[f'network_{k}'] = v
return system_metrics

def compute_gpu_min_max_metrics(self, all_metrics, state):
min_max_metrics = {}

if self.gpu_available:
for key in _GPU_METRICS:
values = torch.tensor([metrics_for_cur_rank[key] for metrics_for_cur_rank in all_metrics])
values = state.device.tensor_to_device(values)
min_rank = int(torch.argmin(values).item())
max_rank = int(torch.argmax(values).item())
min_max_metrics[f'min_{key}_rank_{min_rank}'] = values[min_rank].item()
min_max_metrics[f'max_{key}_rank_{max_rank}'] = values[max_rank].item()

return min_max_metrics
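For reference, a minimal usage sketch of the reworked callback. The model and dataloader objects are assumed to already exist, and the logged values are illustrative; the key formats follow the code above.

from composer import Trainer
from composer.callbacks import SystemMetricsMonitor

# Default: only the min/max of each GPU metric across ranks is logged, with the
# owning rank embedded in the key name, e.g.
#   min_gpu_power_usage_W_rank_3, max_gpu_temperature_C_rank_0
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='1ep',
    callbacks=[SystemMetricsMonitor()],
)

# Per-rank logging instead: gpu_percentage_rank_0, gpu_percentage_rank_1, ...
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='1ep',
    callbacks=[SystemMetricsMonitor(log_all_data=True)],
)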
12 changes: 7 additions & 5 deletions composer/cli/launcher.py
@@ -197,8 +197,13 @@ def _parse_args():
if args.nproc < 1:
raise ValueError('The nproc must be 1 or greater')

if args.world_size is None and 'WORLD_SIZE' in os.environ:
args.world_size = int(os.environ['WORLD_SIZE'])
if args.world_size is None:
if 'WORLD_SIZE' in os.environ and os.environ.get('LOCAL_WORLD_SIZE') != os.environ['WORLD_SIZE']:
# Use WORLD_SIZE env var if set and running multinode. Otherwise, default to nproc
# to enable easy overriding of number of processes when on a single node.
args.world_size = int(os.environ['WORLD_SIZE'])
else:
args.world_size = args.nproc

if args.base_rank is None and 'BASE_RANK' in os.environ:
args.base_rank = int(os.environ['BASE_RANK'])
@@ -212,9 +217,6 @@ def _parse_args():
if args.master_port is None and 'MASTER_PORT' in os.environ:
args.master_port = int(os.environ['MASTER_PORT'])

if args.world_size is None:
args.world_size = args.nproc

if args.world_size < args.nproc:
raise ValueError(f'world_size({args.world_size}) cannot be less than nproc({args.nproc})')

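The new branch only trusts the WORLD_SIZE environment variable when the job looks multinode (LOCAL_WORLD_SIZE differs from WORLD_SIZE); otherwise world_size falls back to nproc. A standalone restatement of that resolution order, where resolve_world_size is a hypothetical helper and not part of the launcher:

import os
from typing import Optional

def resolve_world_size(cli_world_size: Optional[int], nproc: int) -> int:
    # An explicit --world_size always wins.
    if cli_world_size is not None:
        return cli_world_size
    # Trust WORLD_SIZE only when the job appears multinode, i.e. the local
    # world size differs from the global one.
    if 'WORLD_SIZE' in os.environ and os.environ.get('LOCAL_WORLD_SIZE') != os.environ['WORLD_SIZE']:
        return int(os.environ['WORLD_SIZE'])
    # Otherwise default to nproc so a single-node launch can override the
    # process count without touching environment variables.
    return nproc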
28 changes: 15 additions & 13 deletions composer/core/callback.py
@@ -273,19 +273,21 @@ def batch_end(self, state: State, logger: Logger) -> None:
The following :attr:`.State.timestamp` member variables are
incremented immediately before the :attr:`.Event.BATCH_END` event.
+------------------------------------+
| :attr:`.Timestamp.batch` |
+------------------------------------+
| :attr:`.Timestamp.batch_in_epoch` |
+------------------------------------+
| :attr:`.Timestamp.sample` |
+------------------------------------+
| :attr:`.Timestamp.sample_in_epoch` |
+------------------------------------+
| :attr:`.Timestamp.token` |
+------------------------------------+
| :attr:`.Timestamp.token_in_epoch` |
+------------------------------------+
+--------------------------------------+
| :attr:`.Timestamp.batch` |
+--------------------------------------+
| :attr:`.Timestamp.batch_in_epoch` |
+--------------------------------------+
| :attr:`.Timestamp.sample` |
+--------------------------------------+
| :attr:`.Timestamp.sample_in_epoch` |
+--------------------------------------+
| :attr:`.Timestamp.token` |
+--------------------------------------+
| :attr:`.Timestamp.token_in_epoch` |
+--------------------------------------+
| :attr:`.Timestamp.token_in_iteration`|
+--------------------------------------+
Args:
state (State): The training state.
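Because the counters in the table above are incremented immediately before BATCH_END, a callback can read them directly in batch_end. A hypothetical sketch using the new token_in_iteration counter:

from composer.core import Callback, State
from composer.loggers import Logger

class TokenProgress(Callback):
    # Hypothetical callback, not part of this commit: logs how many tokens have
    # been seen in the current iteration, which is already up to date here.
    def batch_end(self, state: State, logger: Logger) -> None:
        logger.log_metrics({'tokens_in_iteration': int(state.timestamp.token_in_iteration)})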
2 changes: 1 addition & 1 deletion composer/core/state.py
@@ -766,7 +766,7 @@ def _iteration_length(self, iteration_length: Optional[Union[str, Time[int]]]):
return
if isinstance(iteration_length, str):
iteration_length = ensure_time(iteration_length, TimeUnit.EPOCH)
if iteration_length.unit != TimeUnit.EPOCH:
if iteration_length.unit != TimeUnit.EPOCH and iteration_length.unit != TimeUnit.TOKEN:
raise NotImplementedError(f'{iteration_length.unit} is not allowed as a unit for iteration_length.')
self.__iteration_length = iteration_length

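The unit check above now accepts token-based iteration lengths in addition to epoch-based ones. A small illustration, assuming state is an existing composer.core.State instance and that the private _iteration_length property setter shown above is the entry point:

state._iteration_length = '10ep'       # allowed before and after this change
state._iteration_length = '100000tok'  # newly allowed; any other unit still raises NotImplementedError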
32 changes: 32 additions & 0 deletions composer/core/time.py
@@ -473,6 +473,7 @@ class Timestamp(Serializable):
sample (int | Time[int], optional): The sample.
token (int | Time[int], optional): The token.
epoch_in_iteration (int | Time[int], optional): The epoch in the iteration.
token_in_iteration (int | Time[int], optional): The token in the iteration.
batch_in_epoch (int | Time[int], optional): The batch in the epoch.
sample_in_epoch (int | Time[int], optional): The sample in the epoch.
token_in_epoch (int | Time[int], optional): The token in the epoch.
@@ -490,6 +491,7 @@ def __init__(
sample: Union[int, Time[int]] = 0,
token: Union[int, Time[int]] = 0,
epoch_in_iteration: Union[int, Time[int]] = 0,
token_in_iteration: Union[int, Time[int]] = 0,
batch_in_epoch: Union[int, Time[int]] = 0,
sample_in_epoch: Union[int, Time[int]] = 0,
token_in_epoch: Union[int, Time[int]] = 0,
@@ -531,6 +533,14 @@ def __init__(
))
self._epoch_in_iteration = epoch_in_iteration

token_in_iteration = Time.from_input(token_in_iteration, TimeUnit.TOKEN)
if token_in_iteration.unit != TimeUnit.TOKEN:
raise ValueError((
f'The `token_in_iteration` argument has units of {token_in_iteration.unit}; '
f'not {TimeUnit.TOKEN}.'
))
self._token_in_iteration = token_in_iteration

batch_in_epoch = Time.from_input(batch_in_epoch, TimeUnit.BATCH)
if batch_in_epoch.unit != TimeUnit.BATCH:
raise ValueError(
@@ -579,6 +589,7 @@ def state_dict(self) -> dict[str, Any]:
'sample': self.sample.value,
'token': self.token.value,
'epoch_in_iteration': self.epoch_in_iteration.value,
'token_in_iteration': self.token_in_iteration.value,
'batch_in_epoch': self.batch_in_epoch.value,
'sample_in_epoch': self.sample_in_epoch.value,
'token_in_epoch': self.token_in_epoch.value,
@@ -609,6 +620,8 @@ def load_state_dict(self, state: dict[str, Any]) -> None:
self._iteration = Time(state['iteration'], TimeUnit.ITERATION)
if 'epoch_in_iteration' in state:
self._epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH)
if 'token_in_iteration' in state:
self._token_in_iteration = Time(state['token_in_iteration'], TimeUnit.TOKEN)
if 'iteration_wct' in state:
self._iteration_wct = state['iteration_wct']

@@ -642,6 +655,11 @@ def epoch_in_iteration(self) -> Time[int]:
"""The epoch count in the current iteration (resets at 0 at the beginning of every iteration)."""
return self._epoch_in_iteration

@property
def token_in_iteration(self) -> Time[int]:
"""The token count in the current iteration (resets at 0 at the beginning of every iteration)."""
return self._token_in_iteration

@property
def batch_in_epoch(self) -> Time[int]:
"""The batch count in the current epoch (resets at 0 at the beginning of every epoch)."""
@@ -814,6 +832,7 @@ def to_next_batch(
sample_in_epoch=self.sample_in_epoch + samples,
token=self.token + tokens,
token_in_epoch=self.token_in_epoch + tokens,
token_in_iteration=self.token_in_iteration + tokens,
total_wct=self.total_wct + duration,
iteration_wct=self.iteration_wct + duration,
epoch_wct=self.epoch_wct + duration,
@@ -822,6 +841,7 @@

def to_next_epoch(
self,
tokens: Union[int, Time] = 0,
duration: Optional[datetime.timedelta] = None,
):
"""Create a new :class:`.Timestamp`, advanced to the next epoch.
@@ -841,6 +861,7 @@ def to_next_epoch(
>>> timestamp.copy(
... epoch=timestamp.epoch + 1,
... epoch_in_iteration=timestamp.epoch_in_iteration + 1,
... token_in_iteration=timestamp.token_in_iteration + tokens,
... batch_in_epoch=0,
... sample_in_epoch=0,
... token_in_epoch=0,
@@ -851,12 +872,17 @@
... )
Timestamp(...)
Args:
tokens (int | Time, optional): The number of tokens trained in the batch. Defaults to 0.
duration (datetime.timedelta, optional): The duration to train the batch.
"""
if duration is None:
duration = datetime.timedelta(seconds=0)
return self.copy(
epoch=self.epoch + 1,
epoch_in_iteration=self.epoch_in_iteration + 1,
token_in_iteration=self.token_in_iteration + tokens,
batch_in_epoch=0,
sample_in_epoch=0,
token_in_epoch=0,
@@ -886,6 +912,7 @@ def to_next_iteration(
>>> timestamp.copy(
... iteration=timestamp.iteration + 1,
... epoch_in_iteration=0,
... token_in_iteration=0,
... batch_in_epoch=0,
... sample_in_epoch=0,
... token_in_epoch=0,
@@ -902,6 +929,7 @@
return self.copy(
iteration=self.iteration + 1,
epoch_in_iteration=0,
token_in_iteration=0,
batch_in_epoch=0,
sample_in_epoch=0,
token_in_epoch=0,
@@ -919,6 +947,7 @@
sample: Optional[Union[int, Time[int]]] = None,
token: Optional[Union[int, Time[int]]] = None,
epoch_in_iteration: Optional[Union[int, Time[int]]] = None,
token_in_iteration: Optional[Union[int, Time[int]]] = None,
batch_in_epoch: Optional[Union[int, Time[int]]] = None,
sample_in_epoch: Optional[Union[int, Time[int]]] = None,
token_in_epoch: Optional[Union[int, Time[int]]] = None,
@@ -938,6 +967,7 @@
sample (int | Time[int], optional): The sample.
token (int | Time[int], optional): The token.
epoch_in_iteration (int | Time[int], optional): The epoch in the iteration.
token_in_iteration (int | Time[int], optional): The token in the iteration.
batch_in_epoch (int | Time[int], optional): The batch in the epoch.
sample_in_epoch (int | Time[int], optional): The sample in the epoch.
token_in_epoch (int | Time[int], optional): The token in the epoch.
@@ -957,6 +987,7 @@
sample=sample if sample is not None else self.sample,
token=token if token is not None else self.token,
epoch_in_iteration=epoch_in_iteration if epoch_in_iteration is not None else self.epoch_in_iteration,
token_in_iteration=token_in_iteration if token_in_iteration is not None else self.token_in_iteration,
batch_in_epoch=batch_in_epoch if batch_in_epoch is not None else self.batch_in_epoch,
sample_in_epoch=sample_in_epoch if sample_in_epoch is not None else self.sample_in_epoch,
token_in_epoch=token_in_epoch if token_in_epoch is not None else self.token_in_epoch,
@@ -975,6 +1006,7 @@ def __repr__(self) -> str:
f'sample={int(self.sample)}, '
f'token={int(self.token)}, '
f'epoch_in_iteration={int(self.epoch_in_iteration)}, '
f'token_in_iteration={int(self.token_in_iteration)}, '
f'batch_in_epoch={int(self.batch_in_epoch)}, '
f'sample_in_epoch={int(self.sample_in_epoch)}, '
f'token_in_epoch={int(self.token_in_epoch)}, '
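The new token_in_iteration counter advances with to_next_batch and to_next_epoch and resets in to_next_iteration. A small walkthrough with illustrative numbers, assuming the argument defaults shown above:

from composer.core import Timestamp

ts = Timestamp()
ts = ts.to_next_batch(samples=8, tokens=1024)  # token_in_iteration == 1024
ts = ts.to_next_batch(samples=8, tokens=1024)  # token_in_iteration == 2048
ts = ts.to_next_epoch()                        # token_in_epoch resets to 0, token_in_iteration stays at 2048
ts = ts.to_next_iteration()                    # token_in_iteration resets to 0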
7 changes: 7 additions & 0 deletions composer/loggers/mlflow_logger.py
@@ -66,6 +66,9 @@ class MLFlowLogger(LoggerDestination):
resume (bool, optional): If ``True``, Composer will search for an existing run tagged with
the `run_name` and resume it. If no existing run is found, a new run will be created.
If ``False``, Composer will create a new run. (default: ``False``)
logging_buffer_seconds (int, optional): The amount of time, in seconds, that MLflow
waits before sending logs to the MLflow tracking server. Metrics/params/tags logged
within this buffer time will be grouped in batches before being sent to the backend.
"""

def __init__(
@@ -85,6 +88,7 @@ def __init__(
ignore_hyperparameters: Optional[list[str]] = None,
run_group: Optional[str] = None,
resume: bool = False,
logging_buffer_seconds: Optional[int] = 10,
) -> None:
try:
import mlflow
@@ -116,6 +120,9 @@ def __init__(
)
self.resume = resume

if logging_buffer_seconds:
os.environ['MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS'] = str(logging_buffer_seconds)

self._rank_zero_only = rank_zero_only
self._last_flush_time = time.time()
self._flush_interval = flush_interval
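The new logging_buffer_seconds argument sets MLFLOW_ASYNC_LOGGING_BUFFERING_SECONDS before MLflow is used, so metrics, params, and tags are batched for that many seconds before being sent to the tracking server. A minimal sketch with a hypothetical experiment name:

from composer.loggers import MLFlowLogger

mlflow_logger = MLFlowLogger(
    experiment_name='my-experiment',  # hypothetical
    logging_buffer_seconds=30,        # default in the signature above is 10
)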
2 changes: 2 additions & 0 deletions composer/trainer/_patch_pytorch.py
@@ -933,6 +933,8 @@ def device_mesh__getitem__(self, mesh_dim_names: Union[str, tuple[str]]) -> 'DeviceMesh':
return submesh

else:
from torch.distributed.device_mesh import _mesh_resources

def create_child_mesh(
self, parent_mesh: 'DeviceMesh', submesh_dim_names: Tuple[str, ...],
) -> 'DeviceMesh':