From d8227db500ae1c4d391982b0aa6e898f66fb5161 Mon Sep 17 00:00:00 2001
From: xiaohanzhangcmu
Date: Sun, 2 Jun 2024 00:33:14 -0700
Subject: [PATCH] update

---
 streaming/base/stream.py | 49 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/streaming/base/stream.py b/streaming/base/stream.py
index c3416de03..0e4b91b12 100644
--- a/streaming/base/stream.py
+++ b/streaming/base/stream.py
@@ -510,6 +510,50 @@ def get_index_size(self) -> int:
         filename = os.path.join(self.local, self.split, get_index_basename())
         return os.stat(filename).st_size
 
+import json
+import os
+
+def save_dict_to_file(directory, filename, dictionary):
+    """Save a dictionary as indented JSON to a file in the specified directory.
+
+    Args:
+        directory: Destination directory; created (with parents) if missing.
+            An empty string means the current working directory.
+        filename: Name of the JSON file to write inside ``directory``.
+        dictionary: JSON-serializable mapping to persist.
+    """
+    # ``exist_ok=True`` tolerates the directory being created concurrently by
+    # another process; ``os.makedirs('')`` raises, so skip it for an empty path.
+    if directory:
+        os.makedirs(directory, exist_ok=True)
+
+    file_path = os.path.join(directory, filename)
+    with open(file_path, 'w', encoding='utf-8') as file:
+        json.dump(dictionary, file, indent=4)
+    print(f"Dictionary saved to {file_path}")
+
+def load_dict_from_file(directory, filename):
+    """Load a dictionary from a JSON file in the specified directory.
+
+    Args:
+        directory: Directory containing the file.
+        filename: Name of the JSON file to read.
+
+    Returns:
+        The decoded JSON content.
+
+    Raises:
+        FileNotFoundError: If the file does not exist.
+    """
+    file_path = os.path.join(directory, filename)
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"No such file: '{file_path}'")
+
+    with open(file_path, 'r', encoding='utf-8') as file:
+        dictionary = json.load(file)
+    print(f"Dictionary loaded from {file_path}")
+    return dictionary
+
 
 class DeltaStream(Stream):
 
@@ -547,6 +591,7 @@ def generate_unique_basename(self, url: str, index: int) -> str:
         basename = '.'.join(['shard', f'{index:05}', 'mds'])
         self.url_to_basename[url] = basename
         self.basename_to_url[basename] = url
+
         return basename
 
     def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]:
@@ -647,6 +692,10 @@ def get_shards(self, world: World, allow_unsafe_types: bool) -> List[Reader]:
             shard.validate(allow_unsafe_types)
             shards.append(shard)
 
+        # NOTE(review): this writes basename_to_url.json into the current working
+        # directory ('./') -- confirm that location is intended.
+        save_dict_to_file('./', 'basename_to_url.json', self.basename_to_url)
+
         return shards
 
     def _download_file(self, from_basename: str, to_basename: Optional[str] = None) -> str: