diff --git a/streaming/base/converters/dataframe_to_mds.py b/streaming/base/converters/dataframe_to_mds.py
index 1d62ef4e7..1adf07eee 100644
--- a/streaming/base/converters/dataframe_to_mds.py
+++ b/streaming/base/converters/dataframe_to_mds.py
@@ -6,6 +6,7 @@
 import logging
 import os
 import shutil
+import time
 from collections.abc import Iterable
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple
 
@@ -224,8 +225,8 @@ def write_mds(iterator: Iterable):
         ],
                               axis=1)
 
-    if dataframe is None or dataframe.isEmpty():
-        raise ValueError(f'Input dataframe is None or Empty!')
+    if dataframe is None:
+        raise ValueError('Input dataframe must be provided')
 
     if not mds_kwargs:
         mds_kwargs = {}
@@ -261,6 +262,9 @@ def write_mds(iterator: Iterable):
     if cu.remote is None:
         mds_path = (cu.local, '')
     else:
+        if dataframe.isStreaming:
+            raise ValueError(
+                'dataframe_to_mds currently only supports outputting to a local directory')
         mds_path = (cu.local, cu.remote)
 
     # Prepare partition schema
@@ -269,21 +273,60 @@
         StructField('mds_path_remote', StringType(), False),
         StructField('fail_count', IntegerType(), False)
     ])
-    partitions = dataframe.mapInPandas(func=write_mds, schema=result_schema).collect()
+    mapped_df = dataframe.mapInPandas(func=write_mds, schema=result_schema)
+
+    if mapped_df.isStreaming:
+
+        def merge_and_log(df: DataFrame, batch_id: int):
+            partitions = df.collect()
+            if len(partitions) == 0:
+                return
+
+            if merge_index:
+                index_files = [
+                    (row['mds_path_local'], row['mds_path_remote']) for row in partitions
+                ]
+                lock_file_path = os.path.join(out, '.merge.lock')
+                # Acquire the lock: O_EXCL makes os.open fail while the lock file exists.
+                while True:
+                    try:
+                        fd = os.open(lock_file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
+                    except OSError:
+                        time.sleep(1)  # File already exists, wait and try again
+                    else:
+                        break
+                do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60)
+                # Release the lock: delete the file so later batches can acquire it.
+                os.close(fd)
+                os.remove(lock_file_path)
+
+            sum_fail_count = 0
+            for row in partitions:
+                sum_fail_count += row['fail_count']
+
+            if sum_fail_count > 0:
+                logger.warning(
+                    f'[Batch #{batch_id}] Total failed records = {sum_fail_count}\n'
+                    f'Overall records {len(partitions)}')
+
+        mapped_df.writeStream.foreachBatch(merge_and_log).start()
+        return None, 0
+    else:
+        partitions = mapped_df.collect()
 
-    if merge_index:
-        index_files = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions]
-        do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60)
+        if merge_index:
+            index_files = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions]
+            do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60)
 
-    if cu.remote is not None:
-        if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False:
-            shutil.rmtree(cu.local, ignore_errors=True)
+        if cu.remote is not None:
+            if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False:
+                shutil.rmtree(cu.local, ignore_errors=True)
 
-    sum_fail_count = 0
-    for row in partitions:
-        sum_fail_count += row['fail_count']
+        sum_fail_count = 0
+        for row in partitions:
+            sum_fail_count += row['fail_count']
 
-    if sum_fail_count > 0:
-        logger.warning(
-            f'Total failed records = {sum_fail_count}\nOverall records {dataframe.count()}')
-    return mds_path, sum_fail_count
+        if sum_fail_count > 0:
+            logger.warning(
+                f'Total failed records = {sum_fail_count}\nOverall records {dataframe.count()}')
+        return mds_path, sum_fail_count
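
A hedged usage sketch (not part of the diff): how the new streaming branch might be exercised end to end. It assumes `dataframe_to_mds` is importable from `streaming.base.converters`, that a Spark session with a streamable source is available, and that the source table name, output path, and column mapping below are hypothetical placeholders.

from pyspark.sql import SparkSession

from streaming.base.converters import dataframe_to_mds

spark = SparkSession.builder.getOrCreate()

# A streaming DataFrame: with this change, dataframe_to_mds no longer calls
# isEmpty()/collect() eagerly; it starts a foreachBatch query and returns
# (None, 0) immediately. Remote output paths raise, so 'out' must be local.
stream_df = spark.readStream.table('catalog.schema.events')  # hypothetical source table

mds_path, fail_count = dataframe_to_mds(
    stream_df,
    merge_index=True,  # per micro-batch, index files are merged under a .merge.lock file
    mds_kwargs={
        'out': '/tmp/mds_out',       # local-only while the input is streaming
        'columns': {'text': 'str'},  # assumed column-to-MDS-type mapping
        'keep_local': True,
    },
)
assert mds_path is None and fail_count == 0  # streaming branch returns before completion

# Block until the background streaming query finishes (or is stopped).
spark.streams.awaitAnyTermination()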