Add fixes to make PyIOConfig pickleable
Jay Chia committed Jul 27, 2023
1 parent b3a0836 commit 0f1b076
Showing 4 changed files with 30 additions and 2 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions src/daft-io/Cargo.toml
@@ -6,6 +6,7 @@ aws-credential-types = {version = "0.55.3", features = ["hardcoded-credentials"]
 aws-sdk-s3 = "0.28.0"
 aws-sig-auth = "0.55.3"
 aws-sigv4 = "0.55.3"
+bincode = {workspace = true}
 bytes = "1.4.0"
 common-error = {path = "../common/error", default-features = false}
 daft-core = {path = "../daft-core", default-features = false}
16 changes: 15 additions & 1 deletion src/daft-io/src/python.rs
@@ -1,5 +1,5 @@
 use crate::config::{IOConfig, S3Config};
-use pyo3::prelude::*;
+use pyo3::{prelude::*, types::PyBytes};
 
 #[derive(Clone, Default)]
 #[pyclass]
@@ -34,6 +34,20 @@ impl PyIOConfig {
             config: self.config.s3.clone(),
         })
     }
+
+    pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+        match state.extract::<&PyBytes>(py) {
+            Ok(s) => {
+                self.config = bincode::deserialize(s.as_bytes()).unwrap();
+                Ok(())
+            }
+            Err(e) => Err(e),
+        }
+    }
+
+    pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
+        Ok(PyBytes::new(py, &bincode::serialize(&self.config).unwrap()).to_object(py))
+    }
 }
 
 #[pymethods]
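
Together, these two hooks make PyIOConfig pickleable: __getstate__ snapshots the wrapped Rust IOConfig struct into bytes via bincode, and __setstate__ restores it. A minimal round-trip sketch from the Python side, assuming the class is exposed as daft.io.IOConfig (as in the test below) and that pickle can default-construct it before calling __setstate__:

import pickle

from daft.io import IOConfig, S3Config

# pickle.dumps calls PyIOConfig.__getstate__, which bincode-encodes the
# underlying Rust IOConfig struct into a bytes payload.
config = IOConfig(s3=S3Config(anonymous=True))
payload = pickle.dumps(config)

# pickle.loads constructs a fresh PyIOConfig and calls __setstate__,
# which bincode-decodes the payload back into the Rust struct.
restored = pickle.loads(payload)
assert isinstance(restored, IOConfig)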
14 changes: 13 additions & 1 deletion tests/integration/io/parquet/test_remote_reads.py
@@ -5,6 +5,7 @@
 import pytest
 from pyarrow import parquet as pq
 
+import daft
 from daft.filesystem import get_filesystem_from_path, get_protocol_from_path
 from daft.io import IOConfig, S3Config
 from daft.table import Table
@@ -177,9 +178,20 @@ def read_parquet_with_pyarrow(path) -> pa.Table:
 
 
 @pytest.mark.integration()
-def test_parquet_read(parquet_file):
+def test_parquet_read_table(parquet_file):
     _, url = parquet_file
     daft_native_read = Table.read_parquet(url, io_config=IOConfig(s3=S3Config(anonymous=True)))
     pa_read = Table.from_arrow(read_parquet_with_pyarrow(url))
     assert daft_native_read.schema() == pa_read.schema()
     pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas())
+
+
+@pytest.mark.integration()
+def test_parquet_read_df(parquet_file):
+    _, url = parquet_file
+    daft_native_read = daft.read_parquet(
+        url, io_config=IOConfig(s3=S3Config(anonymous=True)), use_native_downloader=True
+    )
+    pa_read = Table.from_arrow(read_parquet_with_pyarrow(url))
+    assert daft_native_read.schema() == pa_read.schema()
+    pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas())
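
The new test exercises daft.read_parquet with the native downloader, a path where the IOConfig typically has to be shipped to worker processes; that is where pickleability pays off. An illustrative sketch under that assumption, with ProcessPoolExecutor standing in for whatever runner Daft actually uses:

from concurrent.futures import ProcessPoolExecutor

from daft.io import IOConfig, S3Config

def worker(config: IOConfig) -> bool:
    # config was reconstructed here from pickled bytes
    # via PyIOConfig.__getstate__/__setstate__.
    return config is not None

if __name__ == "__main__":
    config = IOConfig(s3=S3Config(anonymous=True))
    with ProcessPoolExecutor(max_workers=1) as pool:
        assert pool.submit(worker, config).result()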
