diff --git a/Cargo.lock b/Cargo.lock
index 0b02e57cc8..35c8aed8bc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -869,6 +869,7 @@ dependencies = [
  "aws-sdk-s3",
  "aws-sig-auth",
  "aws-sigv4",
+ "bincode",
  "bytes",
  "common-error",
  "daft-core",
diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml
index e2807f8784..37819221a1 100644
--- a/src/daft-io/Cargo.toml
+++ b/src/daft-io/Cargo.toml
@@ -6,6 +6,7 @@ aws-credential-types = {version = "0.55.3", features = ["hardcoded-credentials"]}
 aws-sdk-s3 = "0.28.0"
 aws-sig-auth = "0.55.3"
 aws-sigv4 = "0.55.3"
+bincode = {workspace = true}
 bytes = "1.4.0"
 common-error = {path = "../common/error", default-features = false}
 daft-core = {path = "../daft-core", default-features = false}
diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs
index 105a30ba3c..3bda1859d3 100644
--- a/src/daft-io/src/python.rs
+++ b/src/daft-io/src/python.rs
@@ -1,5 +1,5 @@
 use crate::config::{IOConfig, S3Config};
-use pyo3::prelude::*;
+use pyo3::{prelude::*, types::PyBytes};
 
 #[derive(Clone, Default)]
 #[pyclass]
@@ -34,6 +34,20 @@ impl PyIOConfig {
             config: self.config.s3.clone(),
         })
     }
+
+    pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
+        match state.extract::<&PyBytes>(py) {
+            Ok(s) => {
+                self.config = bincode::deserialize(s.as_bytes()).unwrap();
+                Ok(())
+            }
+            Err(e) => Err(e),
+        }
+    }
+
+    pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
+        Ok(PyBytes::new(py, &bincode::serialize(&self.config).unwrap()).to_object(py))
+    }
 }
 
 #[pymethods]
diff --git a/tests/integration/io/parquet/test_remote_reads.py b/tests/integration/io/parquet/test_remote_reads.py
index d933992f4b..95463abe75 100644
--- a/tests/integration/io/parquet/test_remote_reads.py
+++ b/tests/integration/io/parquet/test_remote_reads.py
@@ -5,6 +5,7 @@
 import pytest
 from pyarrow import parquet as pq
 
+import daft
 from daft.filesystem import get_filesystem_from_path, get_protocol_from_path
 from daft.io import IOConfig, S3Config
 from daft.table import Table
@@ -177,9 +178,20 @@ def read_parquet_with_pyarrow(path) -> pa.Table:
 
 
 @pytest.mark.integration()
-def test_parquet_read(parquet_file):
+def test_parquet_read_table(parquet_file):
     _, url = parquet_file
     daft_native_read = Table.read_parquet(url, io_config=IOConfig(s3=S3Config(anonymous=True)))
     pa_read = Table.from_arrow(read_parquet_with_pyarrow(url))
     assert daft_native_read.schema() == pa_read.schema()
     pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas())
+
+
+@pytest.mark.integration()
+def test_parquet_read_df(parquet_file):
+    _, url = parquet_file
+    daft_native_read = daft.read_parquet(
+        url, io_config=IOConfig(s3=S3Config(anonymous=True)), use_native_downloader=True
+    )
+    pa_read = Table.from_arrow(read_parquet_with_pyarrow(url))
+    assert daft_native_read.schema() == pa_read.schema()
+    pd.testing.assert_frame_equal(daft_native_read.to_pandas(), pa_read.to_pandas())
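
For context: the `__getstate__`/`__setstate__` pair added to `PyIOConfig` above is what makes `IOConfig` objects picklable, by round-tripping the underlying Rust struct through bincode-encoded bytes. Below is a minimal sketch of the Python-side behavior this should enable (import paths follow the test file in this diff; the pickle round-trip itself is an assumption about how these dunder methods get exercised, not something confirmed by the patch):

```python
import pickle

from daft.io import IOConfig, S3Config

# Build a config the same way the tests above do.
config = IOConfig(s3=S3Config(anonymous=True))

# pickle.dumps invokes PyIOConfig.__getstate__, which bincode-serializes
# the Rust-side config and hands the resulting bytes back to Python.
payload = pickle.dumps(config)

# pickle.loads constructs a fresh PyIOConfig and invokes __setstate__,
# which bincode-deserializes the bytes back into the Rust-side config.
restored = pickle.loads(payload)
assert isinstance(restored, IOConfig)
```

This matters for distributed execution, where configs are shipped to worker processes by pickling; without these methods, passing an `IOConfig` through `daft.read_parquet` (as in the new `test_parquet_read_df` test) would fail on serialization.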