Skip to content

Commit

Permalink
Deprecate get_state and remove deprecations
Browse files Browse the repository at this point in the history
commit-id:5d87db4b
  • Loading branch information
b-chu committed Feb 15, 2024
1 parent aff9f48 commit 4109fa9
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 78 deletions.
18 changes: 2 additions & 16 deletions composer/callbacks/mlperf.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,22 +263,8 @@ def _get_dataloader_stats(self, dataloader: Iterable):
if isinstance(dataloader.dataset, IterableDataset):
num_samples *= dist.get_world_size()
return (dataloader.batch_size, num_samples)
try:
# attempt to import ffcv and test if its an ffcv loader.
import ffcv # type: ignore

warnings.warn(DeprecationWarning('ffcv is deprecated and will be removed in v0.18'))

if isinstance(dataloader, ffcv.loader.Loader):
# Use the cached attribute ffcv.init_traversal_order to compute number of samples
return (
dataloader.batch_size, # type: ignore
len(dataloader.next_traversal_order()) * dist.get_world_size() # type: ignore
)
except ImportError:
pass

raise TypeError(f'torch dataloader or ffcv dataloader required (and ffcv installed)')

raise TypeError(f'torch dataloader required')

def fit_start(self, state: State, logger: Logger) -> None:
if _global_rank_zero():
Expand Down
43 changes: 0 additions & 43 deletions composer/callbacks/utils.py

This file was deleted.

3 changes: 2 additions & 1 deletion composer/core/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,8 @@ def fsdp_sharded_state_dict_enabled(self):

@property
def fsdp_elastic_sharded_enabled(self):
warnings.warn('state.fsdp_elastic_sharded_enabled is deprecated and will be removed v0.21.0')
warnings.warn('state.fsdp_elastic_sharded_enabled is deprecated and will be removed v0.21.0',
DeprecationWarning)
return self.fsdp_sharded_state_dict_enabled

@property
Expand Down
18 changes: 3 additions & 15 deletions composer/core/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import datetime
import re
import warnings
from typing import Any, Dict, Generic, Optional, TypeVar, Union, cast

from composer.core.serializable import Serializable
Expand Down Expand Up @@ -532,21 +533,8 @@ def get_state(self) -> Dict[str, Union[Time[int], datetime.timedelta]]:
Returns:
Dict[str, Union[Time[int], datetime.timedelta]]: All values of the timestamp object.
"""
return {
'iteration': self.iteration,
'epoch': self.epoch,
'batch': self.batch,
'sample': self.sample,
'token': self.token,
'epoch_in_iteration': self.epoch_in_iteration,
'batch_in_epoch': self.batch_in_epoch,
'sample_in_epoch': self.sample_in_epoch,
'token_in_epoch': self.token_in_epoch,
'total_wct': self.total_wct,
'iteration_wct': self.iteration_wct,
'epoch_wct': self.epoch_wct,
'batch_wct': self.batch_wct,
}
warnings.warn('core.time.Timestamp.get_state is deprecated and will be removed v0.21.0', DeprecationWarning)
return self.state_dict()

def load_state_dict(self, state: Dict[str, Any]) -> None:
self._epoch = Time(state['epoch'], TimeUnit.EPOCH)
Expand Down
5 changes: 2 additions & 3 deletions composer/loggers/in_memory_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import numpy as np
from torch import Tensor

from composer.core.time import Time
from composer.loggers.logger import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils.import_helpers import MissingConditionalImportError
Expand Down Expand Up @@ -157,8 +156,8 @@ def get_timeseries(self, metric: str) -> Dict[str, Any]:
timestamp, metric_value = datapoint
timeseries.setdefault(metric, []).append(metric_value)
# Iterate through time units and add them all!
for field, time in timestamp.get_state().items():
time_value = time.value if isinstance(time, Time) else time.total_seconds()
for field, time in timestamp.state_dict().items():
time_value = time if isinstance(time, int) else time.total_seconds()
timeseries.setdefault(field, []).append(time_value)
# Convert to numpy arrays
for k, v in timeseries.items():
Expand Down

0 comments on commit 4109fa9

Please sign in to comment.