Skip to content

Commit

Permalink
[BUG] Add retries to pyarrow write_dataset call (#2445)
Browse files Browse the repository at this point in the history
Multipart uploads to Cloudflare R2 intermittently fail with an
`InvalidPart` error ([more info about the error
here](https://docs.aws.amazon.com/AmazonS3/latest/API/API_CompleteMultipartUpload.html)).
I have confirmed that this error does not occur when the write is
retried, so this PR adds retry logic around the `pads.write_dataset`
call (with exponential backoff and jitter) to fix this issue.

Tested with this code:
```python
import daft
import pyarrow as pa
from tqdm import trange

def main():
    daft.context.set_runner_ray()

    df = daft.from_pydict({"a": list(range(10_000_000 * 16))})
    df = df.into_partitions(16)
    table = df.to_arrow()

    s3 = daft.io.S3Config(
        endpoint_url="...", 
        key_id="...", 
        access_key="...", 
        region_name="auto"
    )
    io_config = daft.io.IOConfig(s3=s3)

    for i in trange(1_000):
        path = f"s3://eventual-public-data/kevin-test/{i}.parquet"
        df.write_parquet(path, io_config=io_config)
        written_df = daft.read_parquet(path, io_config=io_config)
        written_df = written_df.sort("a")
        written_arrow = written_df.to_arrow()
        assert written_arrow.equals(table)

if __name__ == "__main__":
    main()
```
  • Loading branch information
kevinzwang authored Jun 28, 2024
1 parent b546ab8 commit cfc6505
Showing 1 changed file with 33 additions and 14 deletions.
47 changes: 33 additions & 14 deletions daft/table/table_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import contextlib
import math
import pathlib
import random
import time
from collections.abc import Callable, Generator
from typing import IO, TYPE_CHECKING, Any, Union
from uuid import uuid4
Expand Down Expand Up @@ -805,17 +807,34 @@ def _write_tabular_arrow_table(
else:
basename_template = f"{uuid4()}-{{i}}.{format.default_extname}"

pads.write_dataset(
arrow_table,
schema=schema,
base_dir=full_path,
basename_template=basename_template,
format=format,
partitioning=None,
file_options=opts,
file_visitor=file_visitor,
use_threads=True,
existing_data_behavior="overwrite_or_ignore",
filesystem=fs,
**kwargs,
)
NUM_TRIES = 3
JITTER_MS = 2_500
MAX_BACKOFF_MS = 20_000

for attempt in range(NUM_TRIES):
try:
pads.write_dataset(
arrow_table,
schema=schema,
base_dir=full_path,
basename_template=basename_template,
format=format,
partitioning=None,
file_options=opts,
file_visitor=file_visitor,
use_threads=True,
existing_data_behavior="overwrite_or_ignore",
filesystem=fs,
**kwargs,
)
break
except OSError as e:
if "InvalidPart" not in str(e):
raise

if attempt == NUM_TRIES - 1:
raise OSError(f"Failed to retry write to {full_path}") from e
else:
jitter = random.randint(0, (2**attempt) * JITTER_MS)
backoff = min(MAX_BACKOFF_MS, jitter)
time.sleep(backoff / 1000)

0 comments on commit cfc6505

Please sign in to comment.