Moved local directory creation and existence check from CloudUploader to Writer class #520

Status: Draft. Wants to merge 3 commits into base branch `main`.
24 changes: 23 additions & 1 deletion streaming/base/format/base/writer.py
@@ -64,6 +64,9 @@ class Writer(ABC):
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
retry (int): Number of times to retry uploading a file to a remote location.
Default to ``2``.
exist_ok (bool): If the local directory exists and is not empty, whether to
proceed or raise an error. ``False`` raises an error; ``True`` logs a
warning and writes into the existing directory. Defaults to ``False``.
"""

format: str = '' # Name of the format (like "mds", "csv", "json", etc).
@@ -100,7 +103,8 @@ def __init__(self,

# Validate keyword arguments
invalid_kwargs = [
arg for arg in kwargs.keys() if arg not in ('progress_bar', 'max_workers', 'retry')
arg for arg in kwargs.keys()
if arg not in ('progress_bar', 'max_workers', 'retry', 'exist_ok')
]
if invalid_kwargs:
raise ValueError(f'Invalid Writer argument(s): {invalid_kwargs} ')
@@ -120,6 +124,18 @@ def __init__(self,
kwargs.get('retry', 2))
self.local = self.cloud_writer.local
self.remote = self.cloud_writer.remote

if os.path.exists(self.local) and len(os.listdir(self.local)) != 0:
if kwargs.get('exist_ok', False):
logger.warning(f'Directory {self.local} exists and is not empty; continuing ' +
f'since `exist_ok=True` was provided.')
else:
raise FileExistsError(f'Directory is not empty: {self.local}. To write into ' +
f'it without emptying its contents first, ' +
f'pass `exist_ok=True`.')
# Create the local directory if it does not exist.
os.makedirs(self.local, exist_ok=True)

# `max_workers`: The maximum number of threads that can be executed in parallel.
# One thread is responsible for uploading one shard file to a remote location.
self.executor = ThreadPoolExecutor(max_workers=kwargs.get('max_workers', None))
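For context on the behavior this hunk introduces, here is a minimal sketch from the caller's side; the column spec and output path are illustrative, not part of this PR:

```python
from streaming import MDSWriter

columns = {'id': 'int', 'text': 'str'}  # hypothetical schema
local = '/tmp/mds_out'                  # hypothetical output directory

# First run: Writer creates the directory if it does not exist.
with MDSWriter(out=local, columns=columns) as writer:
    writer.write({'id': 0, 'text': 'hello'})

# A second run against the now non-empty directory raises FileExistsError...
try:
    MDSWriter(out=local, columns=columns)
except FileExistsError as err:
    print(err)

# ...unless the caller opts in with `exist_ok=True`, which logs a warning
# and writes into the existing directory instead.
with MDSWriter(out=local, columns=columns, exist_ok=True) as writer:
    writer.write({'id': 1, 'text': 'world'})
```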
@@ -380,6 +396,9 @@ class JointWriter(Writer):
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
retry (int): Number of times to retry uploading a file to a remote location.
Default to ``2``.
exist_ok (bool): If the local directory exists and is not empty, whether to
proceed or raise an error. ``False`` raises an error; ``True`` logs a
warning and writes into the existing directory. Defaults to ``False``.
"""

def __init__(self,
@@ -466,6 +485,9 @@ class SplitWriter(Writer):
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
retry (int): Number of times to retry uploading a file to a remote location.
Default to ``2``.
exist_ok (bool): If the local directory exists and is not empty, whether to
proceed or raise an error. ``False`` raises an error; ``True`` logs a
warning and writes into the existing directory. Defaults to ``False``.
"""

extra_bytes_per_shard = 0
5 changes: 5 additions & 0 deletions streaming/base/format/json/writer.py
@@ -45,6 +45,11 @@ class JSONWriter(SplitWriter):
max_workers (int): Maximum number of threads used to upload output dataset files in
parallel to a remote location. One thread is responsible for uploading one shard
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
retry (int): Number of times to retry uploading a file to a remote location.
Default to ``2``.
exist_ok (bool): If the local directory exists and is not empty, whether to
proceed or raise an error. ``False`` raises an error; ``True`` logs a
warning and writes into the existing directory. Defaults to ``False``.
"""

format = 'json'
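Since the kwarg is validated and consumed in the base `Writer`, every subclass accepts it the same way; a quick sketch (schema and path illustrative):

```python
from streaming import JSONWriter

# `exist_ok=True` is handled by the shared Writer base class, so it works
# identically for JSONWriter, MDSWriter, CSVWriter, and TSVWriter.
with JSONWriter(out='/tmp/json_out', columns={'text': 'str'}, exist_ok=True) as writer:
    writer.write({'text': 'sample'})
```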
5 changes: 5 additions & 0 deletions streaming/base/format/mds/writer.py
@@ -45,6 +45,11 @@ class MDSWriter(JointWriter):
max_workers (int): Maximum number of threads used to upload output dataset files in
parallel to a remote location. One thread is responsible for uploading one shard
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
retry (int): Number of times to retry uploading a file to a remote location.
Default to ``2``.
exist_ok (bool): If the local directory exists and is not empty, whether to
proceed or raise an error. ``False`` raises an error; ``True`` logs a
warning and writes into the existing directory. Defaults to ``False``.
"""

format = 'mds'
15 changes: 15 additions & 0 deletions streaming/base/format/xsv/writer.py
@@ -46,6 +46,11 @@ class XSVWriter(SplitWriter):
max_workers (int): Maximum number of threads used to upload output dataset files in
parallel to a remote location. One thread is responsible for uploading one shard
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
retry (int): Number of times to retry uploading a file to a remote location.
Default to ``2``.
exist_ok (bool): If the local directory exists and is not empty, whether to
proceed or raise an error. ``False`` raises an error; ``True`` logs a
warning and writes into the existing directory. Defaults to ``False``.
"""

format = 'xsv'
@@ -164,6 +169,11 @@ class CSVWriter(XSVWriter):
max_workers (int): Maximum number of threads used to upload output dataset files in
parallel to a remote location. One thread is responsible for uploading one shard
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
retry (int): Number of times to retry uploading a file to a remote location.
Default to ``2``.
exist_ok (bool): If the local directory exists and is not empty, whether to
proceed or raise an error. ``False`` raises an error; ``True`` logs a
warning and writes into the existing directory. Defaults to ``False``.
"""

format = 'csv'
@@ -230,6 +240,11 @@ class TSVWriter(XSVWriter):
max_workers (int): Maximum number of threads used to upload output dataset files in
parallel to a remote location. One thread is responsible for uploading one shard
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
retry (int): Number of times to retry uploading a file to a remote location.
Default to ``2``.
exist_ok (bool): If the local directory exists and is not empty, whether to
proceed or raise an error. ``False`` raises an error; ``True`` logs a
warning and writes into the existing directory. Defaults to ``False``.
"""

format = 'tsv'
87 changes: 22 additions & 65 deletions streaming/base/storage/upload.py
@@ -57,8 +57,7 @@ def get(cls,
out: Union[str, Tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> Any:
retry: int = 2) -> Any:
"""Instantiate a cloud provider uploader or a local uploader based on remote path.

Args:
@@ -75,8 +74,6 @@ def get(cls,
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Default to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.

Returns:
CloudUploader: An instance of sub-class.
@@ -89,8 +86,8 @@
prefix = os.path.join(path.parts[0], path.parts[1])
if prefix == 'dbfs:/Volumes':
provider_prefix = prefix
return getattr(sys.modules[__name__],
UPLOADERS[provider_prefix])(out, keep_local, progress_bar, retry, exist_ok)
return getattr(sys.modules[__name__], UPLOADERS[provider_prefix])(out, keep_local,
progress_bar, retry)

def _validate(self, out: Union[str, Tuple[str, str]]) -> None:
"""Validate the `out` argument.
@@ -124,8 +121,7 @@ def __init__(self,
out: Union[str, Tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
retry: int = 2) -> None:
"""Initialize and validate local and remote path.

Args:
Expand All @@ -142,8 +138,6 @@ def __init__(self,
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Default to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.

Raises:
FileExistsError: Local directory must be empty.
@@ -166,16 +160,6 @@ def __init__(self,
self.local = out[0]
self.remote = out[1]

if os.path.exists(self.local) and len(os.listdir(self.local)) != 0:
if not exist_ok:
raise FileExistsError(f'Directory is not empty: {self.local}')
else:
logger.warning(
f'Directory {self.local} exists and not empty. But continue to mkdir since exist_ok is set to be True.'
)

os.makedirs(self.local, exist_ok=True)

def upload_file(self, filename: str):
"""Upload file from local instance to remote instance.

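With the existence check and `makedirs` call removed above, `CloudUploader` now only validates paths and uploads files; callers that use it directly own the local directory, just as `Writer` now does. A rough sketch of the new contract (bucket and paths are made up):

```python
import os

from streaming.base.storage.upload import CloudUploader

# `get` no longer accepts `exist_ok`; it only resolves the uploader subclass.
uploader = CloudUploader.get(out=('/tmp/cache', 's3://my-bucket/prefix'),
                             keep_local=False,
                             progress_bar=False,
                             retry=2)

# Creating or checking the local staging directory is now the caller's job.
os.makedirs(uploader.local, exist_ok=True)

# Uploading is unchanged: the file is read from `local` and pushed to `remote`.
uploader.upload_file('index.json')
```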
@@ -225,17 +209,14 @@ class S3Uploader(CloudUploader):
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Default to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
"""

def __init__(self,
out: Union[str, Tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
retry: int = 2) -> None:
super().__init__(out, keep_local, progress_bar, retry)

import boto3
from botocore.config import Config
@@ -346,17 +327,14 @@ class GCSUploader(CloudUploader):
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Default to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
"""

def __init__(self,
out: Union[str, Tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
retry: int = 2) -> None:
super().__init__(out, keep_local, progress_bar, retry)
if 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ:
import boto3

@@ -494,17 +472,14 @@ class OCIUploader(CloudUploader):
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Default to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
"""

def __init__(self,
out: Union[str, Tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
retry: int = 2) -> None:
super().__init__(out, keep_local, progress_bar, retry)

import oci

@@ -631,17 +606,14 @@ class AzureUploader(CloudUploader):
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Default to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
"""

def __init__(self,
out: Union[str, Tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
retry: int = 2) -> None:
super().__init__(out, keep_local, progress_bar, retry)

from azure.storage.blob import BlobServiceClient

@@ -719,17 +691,14 @@ class AzureDataLakeUploader(CloudUploader):
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Default to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
"""

def __init__(self,
out: Union[str, Tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
retry: int = 2) -> None:
super().__init__(out, keep_local, progress_bar, retry)

from azure.storage.filedatalake import DataLakeServiceClient

@@ -804,17 +773,14 @@ class DatabricksUploader(CloudUploader):
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Default to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
"""

def __init__(self,
out: Union[str, Tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
retry: int = 2) -> None:
super().__init__(out, keep_local, progress_bar, retry)
self.client = self._create_workspace_client()

def _create_workspace_client(self):
@@ -843,17 +809,14 @@ class DatabricksUnityCatalogUploader(DatabricksUploader):
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Default to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
"""

def __init__(self,
out: Union[str, Tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
retry: int = 2) -> None:
super().__init__(out, keep_local, progress_bar, retry)

def upload_file(self, filename: str):
"""Upload file from local instance to Databricks Unity Catalog.
@@ -892,17 +855,14 @@ class DBFSUploader(DatabricksUploader):
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Default to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
"""

def __init__(self,
out: Union[str, Tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
retry: int = 2) -> None:
super().__init__(out, keep_local, progress_bar, retry)
self.dbfs_path = self.remote.lstrip('dbfs:') # pyright: ignore
self.check_folder_exists()

@@ -962,17 +922,14 @@ class LocalUploader(CloudUploader):
progress_bar (bool): Display TQDM progress bars for uploading output dataset files to
a remote location. Default to ``False``.
retry (int): Number of times to retry uploading a file. Defaults to ``2``.
exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already
exists and has contents. Defaults to ``False``.
"""

def __init__(self,
out: Union[str, Tuple[str, str]],
keep_local: bool = False,
progress_bar: bool = False,
retry: int = 2,
exist_ok: bool = False) -> None:
super().__init__(out, keep_local, progress_bar, retry, exist_ok)
retry: int = 2) -> None:
super().__init__(out, keep_local, progress_bar, retry)
# Create remote directory if it doesn't exist
if self.remote:
os.makedirs(self.remote, exist_ok=True)
8 changes: 4 additions & 4 deletions streaming/base/util.py
@@ -284,7 +284,7 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]],
# This is the index json file name, e.g., it is index.json as of 0.6.0
index_basename = get_index_basename()

cu = CloudUploader.get(out, keep_local=True, exist_ok=True)
cu = CloudUploader.get(out, keep_local=True)

# Remove duplicates, and strip '/' from right if any
index_file_urls = list(OrderedDict.fromkeys(index_file_urls))
@@ -297,7 +297,7 @@

# Prepare a temp folder to download index.json from remote if necessary. Removed in the end.
with tempfile.TemporaryDirectory() as temp_root:
logging.warning(f'A temporary folder {temp_root} is created to store index files')
logging.debug(f'A temporary folder {temp_root} is created to store index files')

# Copy files to a temporary directory. Download if necessary
partitions = []
@@ -394,10 +394,10 @@ def not_merged_index(index_file_path: str, out: str):
logger.warning('No MDS dataset folder specified, no index merged')
return

cu = CloudUploader.get(out, exist_ok=True, keep_local=True)
cu = CloudUploader.get(out, keep_local=True)

local_index_files = []
cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True)
cl = CloudUploader.get(cu.local, keep_local=True)
for file in cl.list_objects():
if file.endswith('.json') and not_merged_index(file, cu.local):
local_index_files.append(file)
Expand Down