
Commit

Init

maddiedawson committed Oct 20, 2023
1 parent 8827d7a commit c493ce0
Showing 1 changed file with 58 additions and 16 deletions.
74 changes: 58 additions & 16 deletions streaming/base/converters/dataframe_to_mds.py
@@ -6,6 +6,7 @@
 import logging
 import os
 import shutil
+import time
 from collections.abc import Iterable
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple

@@ -224,8 +225,8 @@ def write_mds(iterator: Iterable):
             ],
             axis=1)

-    if dataframe is None or dataframe.isEmpty():
-        raise ValueError(f'Input dataframe is None or Empty!')
+    if dataframe is None:
+        raise ValueError('Input dataframe must be provided')

     if not mds_kwargs:
         mds_kwargs = {}
@@ -261,6 +262,9 @@ def write_mds(iterator: Iterable):
     if cu.remote is None:
         mds_path = (cu.local, '')
     else:
+        if dataframe.isStreaming:
+            raise ValueError(
+                'dataframe_to_mds currently only supports outputting to a local directory')
         mds_path = (cu.local, cu.remote)

     # Prepare partition schema
@@ -269,21 +273,59 @@
         StructField('mds_path_remote', StringType(), False),
         StructField('fail_count', IntegerType(), False)
     ])
-    partitions = dataframe.mapInPandas(func=write_mds, schema=result_schema).collect()
+    mapped_df = dataframe.mapInPandas(func=write_mds, schema=result_schema)
+
+    if mapped_df.isStreaming:
+
+        def merge_and_log(df: DataFrame, batch_id: int):
+            partitions = df.collect()
+            if len(partitions) == 0:
+                return
+
+            if merge_index:
+                index_files = [
+                    (row['mds_path_local'], row['mds_path_remote']) for row in partitions
+                ]
+                lock_file_path = os.path.join(out, '.merge.lock')
+                # Acquire the lock: O_EXCL makes os.open fail while the file
+                # exists, so only one concurrent batch can merge at a time.
+                while True:
+                    try:
+                        fd = os.open(lock_file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
+                    except OSError:
+                        time.sleep(1)  # File already exists; wait and try again.
+                    else:
+                        break
+                do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60)
+                # Release the lock and delete the file so later batches can reacquire it.
+                os.close(fd)
+                os.remove(lock_file_path)
+
+            sum_fail_count = 0
+            for row in partitions:
+                sum_fail_count += row['fail_count']
+
+            if sum_fail_count > 0:
+                # count() is unsupported on a streaming DataFrame, so only the
+                # per-batch failure total is logged here.
+                logger.warning(f'[Batch #{batch_id}] Total failed records = {sum_fail_count}')
+
+        mapped_df.writeStream.foreachBatch(merge_and_log).start()
+        return None, 0
+    else:
+        partitions = mapped_df.collect()

-    if merge_index:
-        index_files = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions]
-        do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60)
+        if merge_index:
+            index_files = [(row['mds_path_local'], row['mds_path_remote']) for row in partitions]
+            do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60)

-    if cu.remote is not None:
-        if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False:
-            shutil.rmtree(cu.local, ignore_errors=True)
+        if cu.remote is not None:
+            if 'keep_local' in mds_kwargs and mds_kwargs['keep_local'] == False:
+                shutil.rmtree(cu.local, ignore_errors=True)

-    sum_fail_count = 0
-    for row in partitions:
-        sum_fail_count += row['fail_count']
+        sum_fail_count = 0
+        for row in partitions:
+            sum_fail_count += row['fail_count']

-    if sum_fail_count > 0:
-        logger.warning(
-            f'Total failed records = {sum_fail_count}\nOverall records {dataframe.count()}')
-    return mds_path, sum_fail_count
+        if sum_fail_count > 0:
+            logger.warning(
+                f'Total failed records = {sum_fail_count}\nOverall records {dataframe.count()}')
+        return mds_path, sum_fail_count
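For readers skimming the diff, the conversion hinges on Spark's mapInPandas: each partition arrives at a worker as an iterator of pandas DataFrames, and the worker yields a single summary row matching result_schema. Below is a minimal sketch of that pattern; the shard path, column spec, and write step are illustrative placeholders, not the file's actual write_mds logic.

from typing import Iterator

import pandas as pd
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

result_schema = StructType([
    StructField('mds_path_local', StringType(), False),
    StructField('mds_path_remote', StringType(), False),
    StructField('fail_count', IntegerType(), False),
])

def write_partition(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    """Consume one Spark partition as pandas batches; yield one summary row."""
    fail_count = 0
    for pdf in iterator:
        # A real implementation would write pdf's records to an MDS shard here
        # and increment fail_count for any record that cannot be serialized.
        pass
    yield pd.DataFrame({
        'mds_path_local': ['/tmp/mds/part-0'],  # hypothetical shard directory
        'mds_path_remote': [''],
        'fail_count': [fail_count],
    })

# Applied as: mapped_df = dataframe.mapInPandas(func=write_partition, schema=result_schema)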

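The merge step in merge_and_log serializes concurrent micro-batches with a lock file created via os.O_CREAT | os.O_EXCL, which atomically fails while the file exists. The same idiom can be packaged as a context manager; this is a sketch assuming all writers share the lock file's filesystem, and the helper name is invented for illustration.

import contextlib
import os
import time

@contextlib.contextmanager
def exclusive_file_lock(lock_file_path: str, poll_interval: float = 1.0):
    """Block until the lock file can be created, then hold it for the with-body."""
    while True:
        try:
            # O_EXCL guarantees the open fails if the file already exists,
            # so exactly one process wins the race to create it.
            fd = os.open(lock_file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            break
        except OSError:
            time.sleep(poll_interval)  # Another process holds the lock.
    try:
        yield
    finally:
        os.close(fd)
        os.remove(lock_file_path)  # Delete so the next waiter can acquire it.

# Usage: with exclusive_file_lock('/data/mds/.merge.lock'): do_merge_index(...)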
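Finally, a hedged end-to-end sketch of how the batch and streaming paths would be invoked after this change, using only the parameters visible in this file (dataframe, merge_index, mds_kwargs). The import path, input paths, and mds_kwargs keys are assumptions for illustration, not confirmed API.

from pyspark.sql import SparkSession

from streaming.base.converters import dataframe_to_mds  # assumed import path

spark = SparkSession.builder.getOrCreate()

# Batch path: blocks, optionally merges the index, returns (mds_path, fail_count).
batch_df = spark.read.parquet('/data/in')  # placeholder input
mds_path, fail_count = dataframe_to_mds(
    batch_df,
    merge_index=True,
    mds_kwargs={'out': '/data/mds', 'columns': {'text': 'str'}})  # assumed keys

# Streaming path (new in this commit): starts a foreachBatch query that writes
# and merges each micro-batch, and immediately returns (None, 0).
stream_df = spark.readStream.schema(batch_df.schema).parquet('/data/in_stream')
result, _ = dataframe_to_mds(
    stream_df,
    merge_index=True,
    mds_kwargs={'out': '/data/mds_stream', 'columns': {'text': 'str'}})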