Skip to content

Commit

Permalink
Support for storage options for both DCP OSS and Model Store paths (#900
Browse files Browse the repository at this point in the history
)

Summary:
Pull Request resolved: #900

Support for storage options for both DCP OSS and Model Store paths

Differential Revision: D62599192
  • Loading branch information
saumishr authored and facebook-github-bot committed Sep 13, 2024
1 parent 57a4279 commit db9ba21
Showing 1 changed file with 85 additions and 1 deletion.
86 changes: 85 additions & 1 deletion torchtnt/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,22 @@
import os
import re
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
from functools import total_ordering
from operator import xor
from typing import Dict, List, Literal, Optional, Pattern, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Pattern, Tuple, Union

import fsspec
import torch.distributed as dist
from fsspec.core import url_to_fs
from pyre_extensions import none_throws
from torchsnapshot.fb.storage_plugins.manifold import (
_DEFAULT_ALLOW_OVERWRITES,
_DEFAULT_IS_LOCALHOST,
_DEFAULT_MAX_PARALLEL,
_DEFAULT_TTL,
)
from torchtnt.utils.distributed import PGWrapper, rank_zero_read_and_broadcast

logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -754,3 +761,80 @@ def _metadata_exists(
fs: fsspec.AbstractFileSystem, dirpath: str, metadata_fname: str
) -> bool:
return fs.exists(os.path.join(dirpath, metadata_fname))


def parse_storage_options_for_dcp(
storage_options: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
# storage options are expecting strings from here: https://fburl.com/code/wcj155qh
# but dcp is using this API: https://fburl.com/code/hs87ix7o, https://fburl.com/code/74kpaccv
# This helper maps the TorchSnapshot storage options to the DCP OSS options

if storage_options is None:
return {}

parsed_storage_options: Dict[str, Any] = {}

tss_to_manifold = {
"apiKey": "api_key",
"directRead": "direct_read",
"readMaxParallel": "max_parallel",
"writeMaxParallel": "max_parallel",
}
for option_name, option_value in storage_options.items():
option_name = tss_to_manifold.get(option_name, option_name)
parsed_storage_options[option_name] = option_value

return parsed_storage_options


def parse_storage_options_for_dcp_modelstore(
storage_options: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
# Storage options are expecting strings from here: https://fburl.com/code/wcj155qh
# but DCP with Model Store is using these options: https://fburl.com/code/jnvbbv67
# This helper maps the TorchSnapshot storage options to the DCP Model Store options

if storage_options is None:
return {}

parsed_storage_options: Dict[str, Any] = {}

tss_to_manifold = {
"apiKey": "api_key",
"readMaxParallel": "read_concurrency",
"writeMaxParallel": "write_concurrency",
"ttl": "TTL",
"localhost": "localhost",
"port": "port",
"allowOverwrites": "allow_overwrites",
}

for option_name, option_value in storage_options.items():
if option_name in tss_to_manifold:
option_name = tss_to_manifold.get(option_name, option_name)
parsed_storage_options[option_name] = option_value

tss_defaults = {
"read_concurrency": _DEFAULT_MAX_PARALLEL,
"write_concurrency": _DEFAULT_MAX_PARALLEL,
"TTL": _DEFAULT_TTL,
"localhost": _DEFAULT_IS_LOCALHOST,
"allow_overwrites": _DEFAULT_ALLOW_OVERWRITES,
}

# If any of the defaults are not set, we need to add them
# This needs to be consistent with the default storage options
# define in torchsnapshot (https://fburl.com/code/wcj155qh) to
# ensure consistency.
for option_name, option_default in tss_defaults.items():
if option_name not in parsed_storage_options:
parsed_storage_options[option_name] = option_default

# Model Store components expect the TTL to be in total seconds
if isinstance(parsed_storage_options["TTL"], timedelta):
parsed_storage_options["TTL"] = int(
parsed_storage_options["TTL"].total_seconds()
)

return parsed_storage_options

0 comments on commit db9ba21

Please sign in to comment.