Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
maddiedawson committed Nov 2, 2023
1 parent c493ce0 commit 8b2e560
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
21 changes: 9 additions & 12 deletions streaming/base/converters/dataframe_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,10 @@ def write_mds(iterator: Iterable):
def merge_and_log(df: DataFrame, batch_id: int):
partitions = df.collect()
if len(partitions) == 0:
logger.warning(f'[Batch #{batch_id}] No records to write')
return

if merge_index:
index_files = [
(row['mds_path_local'], row['mds_path_remote']) for row in partitions
]
lock_file_path = os.path.join(out, '.merge.lock')
# Acquire the lock.
while True:
Expand All @@ -295,20 +293,19 @@ def merge_and_log(df: DataFrame, batch_id: int):
time.sleep(1) # File already exists, wait and try again
else:
break
do_merge_index(index_files, out, keep_local=keep_local, download_timeout=60)
do_merge_index(out, keep_local=keep_local, download_timeout=60)
# Release the lock.
os.close(fd)
os.remove(lock_file_path)

sum_fail_count = 0
for row in partitions:
sum_fail_count += row['fail_count']
logger.warning(f"[Batch #{batch_id}] {row['fail_count']} failed record(s) for {row['mds_path_local']}")

if sum_fail_count > 0:
logger.warning(
f'[Batch #{batch_id}] Total failed records = {sum_fail_count}\nOverall records {dataframe.count()}'
)

mapped_df.writeStream.foreachBatch(merge_and_log).start()
mapped_df \
.writeStream \
.foreachBatch(merge_and_log) \
.start() \
.awaitTermination()
return None, 0
else:
partitions = mapped_df.collect()
Expand Down
4 changes: 3 additions & 1 deletion streaming/base/format/mds/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
""":class:`MDSWriter` writes samples to ``.mds`` files that can be read by :class:`MDSReader`."""

import json
import time
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -123,7 +124,8 @@ def get_config(self) -> Dict[str, Any]:
obj.update({
'column_names': self.column_names,
'column_encodings': self.column_encodings,
'column_sizes': self.column_sizes
'column_sizes': self.column_sizes,
'write_timestamp': time.time(),
})
return obj

Expand Down
4 changes: 3 additions & 1 deletion streaming/base/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,9 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]],

# Move merged index from temp path to local part in out
# Upload merged index to remote if out has remote part
shutil.move(merged_index_path, cu.local)
dst_index_path = os.path.join(cu.local, os.path.basename(merged_index_path))
shutil.copy(merged_index_path, dst_index_path)
os.remove(merged_index_path)
if cu.remote is not None:
cu.upload_file(index_basename)

Expand Down

0 comments on commit 8b2e560

Please sign in to comment.