diff --git a/Cargo.lock b/Cargo.lock index 6e351373f5..ff4a611e01 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -674,6 +674,12 @@ dependencies = [ "vsimd", ] +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bincode" version = "1.3.3" @@ -849,6 +855,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "const-oid" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4c78c047431fee22c1a7bb92e00ad095a02a983affe4d8a72e2a2c62c1b94f3" + [[package]] name = "core-foundation" version = "0.9.3" @@ -935,6 +947,16 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crypto-bigint" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03c6a1d5fa1de37e071642dfa44ec552ca5b299adb128fab16138e24b548fd21" +dependencies = [ + "generic-array", + "subtle", +] + [[package]] name = "crypto-common" version = "0.1.6" @@ -1041,6 +1063,7 @@ dependencies = [ "common-error", "daft-core", "futures", + "google-cloud-storage", "lazy_static", "log", "md5", @@ -1111,6 +1134,17 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "der" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6919815d73839e7ad218de758883aae3a257ba6759ce7a9992501efbb53d705c" +dependencies = [ + "const-oid", + "crypto-bigint", + "pem-rfc7468", +] + [[package]] name = "digest" version = "0.10.7" @@ -1430,6 +1464,78 @@ version = "0.27.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" +[[package]] +name = "google-cloud-auth" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "931bedb2264cb00f914b0a6a5c304e34865c34306632d3932e0951a073e4a67d" +dependencies = [ + "async-trait", 
+ "base64 0.21.2", + "google-cloud-metadata", + "google-cloud-token", + "home", + "jsonwebtoken", + "reqwest", + "serde", + "serde_json", + "thiserror", + "time 0.3.23", + "tokio", + "tracing", + "urlencoding", +] + +[[package]] +name = "google-cloud-metadata" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96e4ad0802d3f416f62e7ce01ac1460898ee0efc98f8b45cd4aab7611607012f" +dependencies = [ + "reqwest", + "thiserror", + "tokio", +] + +[[package]] +name = "google-cloud-storage" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6e1438784cd168c9094d37bbde12006326cae5b478a70b88aabc165c16c93a4" +dependencies = [ + "async-stream", + "base64 0.21.2", + "bytes", + "futures-util", + "google-cloud-auth", + "google-cloud-metadata", + "google-cloud-token", + "hex", + "once_cell", + "percent-encoding", + "regex", + "reqwest", + "ring", + "rsa", + "serde", + "serde_json", + "sha2", + "thiserror", + "time 0.3.23", + "tokio", + "tracing", + "url", +] + +[[package]] +name = "google-cloud-token" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fcd62eb34e3de2f085bcc33a09c3e17c4f65650f36d53eb328b00d63bcb536a" +dependencies = [ + "async-trait", +] + [[package]] name = "h2" version = "0.3.20" @@ -1500,6 +1606,15 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb" +dependencies = [ + "windows-sys", +] + [[package]] name = "html-escape" version = "0.2.13" @@ -1761,11 +1876,28 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonwebtoken" +version = "8.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378" +dependencies = [ + "base64 0.21.2", + "pem", + "ring", + 
"serde", + "serde_json", + "simple_asn1", +] + [[package]] name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +dependencies = [ + "spin", +] [[package]] name = "lexical-core" @@ -1837,6 +1969,12 @@ version = "0.2.147" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +[[package]] +name = "libm" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" + [[package]] name = "linux-raw-sys" version = "0.4.3" @@ -1925,6 +2063,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4192263c238a5f0d0c6bfd21f336a313a4ce1c450542449ca191bb657b4642ef" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "miniz_oxide" version = "0.7.1" @@ -1981,6 +2129,34 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "num-bigint" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-bigint-dig" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc84195820f291c7697304f3cbdadd1cb7199c0efc917ff5eafd71225c136151" +dependencies = [ + "byteorder", + "lazy_static", + "libm", + "num-integer", + "num-iter", + "num-traits", + "rand 0.8.5", + "smallvec", + "zeroize", +] + [[package]] name = "num-complex" version = "0.4.3" @@ -2011,6 +2187,17 @@ dependencies = [ 
"num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-rational" version = "0.4.1" @@ -2029,6 +2216,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" dependencies = [ "autocfg", + "libm", ] [[package]] @@ -2165,6 +2353,24 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" +[[package]] +name = "pem" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8835c273a76a90455d7344889b0964598e3316e2a79ede8e36f16bdcf2228b8" +dependencies = [ + "base64 0.13.1", +] + +[[package]] +name = "pem-rfc7468" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01de5d978f34aa4b2296576379fcc416034702fd94117c56ffd8a1a767cefb30" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.0" @@ -2241,6 +2447,28 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a78f66c04ccc83dd4486fd46c33896f4e17b24a7a3a6400dedc48ed0ddd72320" +dependencies = [ + "der", + "pkcs8", + "zeroize", +] + +[[package]] +name = "pkcs8" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7cabda3fb821068a9a4fab19a683eac3af12edf0f34b94a8be53c4972b8149d0" +dependencies = [ + "der", + "spki", + "zeroize", +] + [[package]] name = "pkg-config" version = "0.3.27" @@ 
-2566,6 +2794,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "once_cell", "percent-encoding", "pin-project-lite", @@ -2602,6 +2831,26 @@ dependencies = [ "winapi", ] +[[package]] +name = "rsa" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cf22754c49613d2b3b119f0e5d46e34a2c628a937e3024b8762de4e7d8c710b" +dependencies = [ + "byteorder", + "digest", + "num-bigint-dig", + "num-integer", + "num-iter", + "num-traits", + "pkcs1", + "pkcs8", + "rand_core 0.6.4", + "smallvec", + "subtle", + "zeroize", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -2860,6 +3109,18 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" +[[package]] +name = "simple_asn1" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" +dependencies = [ + "num-bigint", + "num-traits", + "thiserror", + "time 0.3.23", +] + [[package]] name = "siphasher" version = "0.3.10" @@ -2925,6 +3186,16 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "spki" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d01ac02a6ccf3e07db148d2be087da624fea0221a16152ed01f0496a6b0a27" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -3114,6 +3385,7 @@ dependencies = [ "libc", "mio", "num_cpus", + "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", @@ -3251,6 +3523,15 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +[[package]] +name = "unicase" +version = "2.6.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f37be617794602aabbeee0be4f259dc1778fabe05e2d67ee8f79326d5cb4f6" +dependencies = [ + "version_check", +] + [[package]] name = "unicode-bidi" version = "0.3.13" diff --git a/daft/io/__init__.py b/daft/io/__init__.py index 5b58fb17d9..8cd278a879 100644 --- a/daft/io/__init__.py +++ b/daft/io/__init__.py @@ -2,7 +2,7 @@ import sys -from daft.daft import AzureConfig, IOConfig, S3Config +from daft.daft import AzureConfig, GCSConfig, IOConfig, S3Config from daft.io._csv import read_csv from daft.io._json import read_json from daft.io._parquet import read_parquet @@ -23,4 +23,13 @@ def _set_linux_cert_paths(): if sys.platform == "linux": _set_linux_cert_paths() -__all__ = ["read_csv", "read_json", "from_glob_path", "read_parquet", "IOConfig", "S3Config", "AzureConfig"] +__all__ = [ + "read_csv", + "read_json", + "from_glob_path", + "read_parquet", + "IOConfig", + "S3Config", + "AzureConfig", + "GCSConfig", +] diff --git a/pyproject.toml b/pyproject.toml index ac7d293597..11f0a0ea3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,10 @@ readme = "README.rst" requires-python = ">=3.7" [project.optional-dependencies] -all = ["getdaft[aws, azure, ray, pandas, numpy, viz]"] +all = ["getdaft[aws, azure, gcp, ray, pandas, numpy, viz]"] aws = ["s3fs"] azure = ["adlfs"] +gcp = ["gcsfs"] numpy = ["numpy"] pandas = ["pandas"] ray = [ diff --git a/requirements-dev.txt b/requirements-dev.txt index c85d13127e..903d787cc4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -43,6 +43,11 @@ urllib3<2; python_version < '3.8' adlfs==2022.2.0; python_version < '3.8' adlfs==2023.8.0; python_version >= '3.8' + +# GCS +gcsfs==2023.1.0; python_version < '3.8' +gcsfs==2023.6.0; python_version >= '3.8' + # Documentation myst-nb>=0.16.0 Sphinx <= 5 diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml index 9ec0dc4255..6158afe73b 100644 --- a/src/daft-io/Cargo.toml +++ 
b/src/daft-io/Cargo.toml @@ -13,6 +13,7 @@ bytes = {workspace = true} common-error = {path = "../common/error", default-features = false} daft-core = {path = "../daft-core", default-features = false} futures = {workspace = true} +google-cloud-storage = {version = "0.13.0", default-features = false, features = ["rustls-tls", "auth"]} lazy_static = {workspace = true} log = {workspace = true} pyo3 = {workspace = true, optional = true} diff --git a/src/daft-io/src/config.rs b/src/daft-io/src/config.rs index 082dfc6643..98f7e5c8d0 100644 --- a/src/daft-io/src/config.rs +++ b/src/daft-io/src/config.rs @@ -75,10 +75,29 @@ impl Display for AzureConfig { } } +#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct GCSConfig { + pub project_id: Option, + pub anonymous: bool, +} + +impl Display for GCSConfig { + fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + write!( + f, + "GCSConfig + project_id: {:?} + anonymous: {:?}", + self.project_id, self.anonymous + ) + } +} + #[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct IOConfig { pub s3: S3Config, pub azure: AzureConfig, + pub gcs: GCSConfig, } impl Display for IOConfig { @@ -87,8 +106,9 @@ impl Display for IOConfig { f, "IOConfig: {} +{} {}", - self.s3, self.azure + self.s3, self.azure, self.gcs ) } } diff --git a/src/daft-io/src/google_cloud.rs b/src/daft-io/src/google_cloud.rs new file mode 100644 index 0000000000..260e67adbe --- /dev/null +++ b/src/daft-io/src/google_cloud.rs @@ -0,0 +1,247 @@ +use std::ops::Range; +use std::sync::Arc; + +use futures::StreamExt; +use futures::TryStreamExt; +use google_cloud_storage::client::ClientConfig; + +use async_trait::async_trait; +use google_cloud_storage::client::Client; +use google_cloud_storage::http::objects::get::GetObjectRequest; +use google_cloud_storage::http::Error as GError; +use snafu::IntoError; +use snafu::ResultExt; +use snafu::Snafu; + +use crate::config; +use 
crate::config::GCSConfig; +use crate::object_io::ObjectSource; +use crate::s3_like; +use crate::GetResult; + +#[derive(Debug, Snafu)] +enum Error { + #[snafu(display("Unable to open {}: {}", path, source))] + UnableToOpenFile { path: String, source: GError }, + + #[snafu(display("Unable to read data from {}: {}", path, source))] + UnableToReadBytes { path: String, source: GError }, + + #[snafu(display("Unable to parse URL: \"{}\"", path))] + InvalidUrl { + path: String, + source: url::ParseError, + }, + #[snafu(display("Unable to load Credentials: {}", source))] + UnableToLoadCredentials { + source: google_cloud_storage::client::google_cloud_auth::error::Error, + }, + + #[snafu(display("Not a File: \"{}\"", path))] + NotAFile { path: String }, +} + +impl From for super::Error { + fn from(error: Error) -> Self { + use Error::*; + match error { + UnableToReadBytes { path, source } | UnableToOpenFile { path, source } => { + match source { + GError::HttpClient(err) => match err.status().map(|s| s.as_u16()) { + Some(404) | Some(410) => super::Error::NotFound { + path, + source: err.into(), + }, + Some(401) => super::Error::Unauthorized { + store: super::SourceType::GCS, + path, + source: err.into(), + }, + _ => super::Error::UnableToOpenFile { + path, + source: err.into(), + }, + }, + GError::Response(err) => match err.code { + 404 | 410 => super::Error::NotFound { + path, + source: err.into(), + }, + 401 => super::Error::Unauthorized { + store: super::SourceType::GCS, + path, + source: err.into(), + }, + _ => super::Error::UnableToOpenFile { + path, + source: err.into(), + }, + }, + GError::TokenSource(err) => super::Error::UnableToLoadCredentials { + store: super::SourceType::GCS, + source: err, + }, + } + } + NotAFile { path } => super::Error::NotAFile { path }, + InvalidUrl { path, source } => super::Error::InvalidUrl { path, source }, + UnableToLoadCredentials { source } => super::Error::UnableToLoadCredentials { + store: super::SourceType::GCS, + source: 
source.into(), + }, + } + } +} + +enum GCSClientWrapper { + Native(Client), + S3Compat(Arc), +} + +impl GCSClientWrapper { + async fn get(&self, uri: &str, range: Option>) -> super::Result { + let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; + let bucket = match parsed.host_str() { + Some(s) => Ok(s), + None => Err(Error::InvalidUrl { + path: uri.into(), + source: url::ParseError::EmptyHost, + }), + }?; + let key = parsed.path(); + let key = if let Some(key) = key.strip_prefix('/') { + key + } else { + return Err(Error::NotAFile { path: uri.into() }.into()); + }; + + match self { + GCSClientWrapper::Native(client) => { + let req = GetObjectRequest { + bucket: bucket.into(), + object: key.into(), + ..Default::default() + }; + use google_cloud_storage::http::objects::download::Range as GRange; + let (grange, size) = if let Some(range) = range { + ( + GRange(Some(range.start as u64), Some(range.end as u64)), + Some(range.len()), + ) + } else { + (GRange::default(), None) + }; + let owned_uri = uri.to_string(); + let response = client + .download_streamed_object(&req, &grange) + .await + .context(UnableToOpenFileSnafu { + path: uri.to_string(), + })?; + let response = response.map_err(move |e| { + UnableToReadBytesSnafu:: { + path: owned_uri.clone(), + } + .into_error(e) + .into() + }); + Ok(GetResult::Stream(response.boxed(), size)) + } + GCSClientWrapper::S3Compat(client) => { + let uri = format!("s3://{}/{}", bucket, key); + client.get(&uri, range).await + } + } + } + + async fn get_size(&self, uri: &str) -> super::Result { + let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; + let bucket = match parsed.host_str() { + Some(s) => Ok(s), + None => Err(Error::InvalidUrl { + path: uri.into(), + source: url::ParseError::EmptyHost, + }), + }?; + let key = parsed.path(); + let key = if let Some(key) = key.strip_prefix('/') { + key + } else { + return Err(Error::NotAFile { path: uri.into() }.into()); + }; + 
match self { + GCSClientWrapper::Native(client) => { + let req = GetObjectRequest { + bucket: bucket.into(), + object: key.into(), + ..Default::default() + }; + + let response = client + .get_object(&req) + .await + .context(UnableToOpenFileSnafu { + path: uri.to_string(), + })?; + Ok(response.size as usize) + } + GCSClientWrapper::S3Compat(client) => { + let uri = format!("s3://{}/{}", bucket, key); + client.get_size(&uri).await + } + } + } +} + +pub(crate) struct GCSSource { + client: GCSClientWrapper, +} + +impl GCSSource { + async fn build_s3_compat_client() -> super::Result> { + let s3_config = config::S3Config { + anonymous: true, + endpoint_url: Some("https://storage.googleapis.com".to_string()), + ..Default::default() + }; + let s3_client = s3_like::S3LikeSource::get_client(&s3_config).await?; + Ok(GCSSource { + client: GCSClientWrapper::S3Compat(s3_client), + } + .into()) + } + pub async fn get_client(config: &GCSConfig) -> super::Result> { + if config.anonymous { + GCSSource::build_s3_compat_client().await + } else { + let config = ClientConfig::default() + .with_auth() + .await + .context(UnableToLoadCredentialsSnafu {}); + match config { + Ok(config) => { + let client = Client::new(config); + Ok(GCSSource { + client: GCSClientWrapper::Native(client), + } + .into()) + } + Err(err) => { + log::warn!("Google Cloud Storage Credentials not provided or found when making client. 
Reverting to Anonymous mode.\nDetails\n{err}"); + GCSSource::build_s3_compat_client().await + } + } + } + } +} + +#[async_trait] +impl ObjectSource for GCSSource { + async fn get(&self, uri: &str, range: Option>) -> super::Result { + self.client.get(uri, range).await + } + + async fn get_size(&self, uri: &str) -> super::Result { + self.client.get_size(uri).await + } +} diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 513664ba6b..5a812bcd0c 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -2,11 +2,13 @@ mod azure_blob; pub mod config; +mod google_cloud; mod http; mod local; mod object_io; mod s3_like; use azure_blob::AzureBlobSource; +use google_cloud::GCSSource; use lazy_static::lazy_static; #[cfg(feature = "python")] pub mod python; @@ -41,7 +43,7 @@ pub enum Error { #[snafu(display("Generic {} error: {}", store, source))] Generic { store: SourceType, source: DynError }, - #[snafu(display("Object at location {} not found: {}", path, source))] + #[snafu(display("Object at location {} not found\nDetails:\n{}", path, source))] NotFound { path: String, source: DynError }, #[snafu(display("Invalid Argument: {:?}", msg))] @@ -65,10 +67,10 @@ pub enum Error { #[snafu(display("Not a File: \"{}\"", path))] NotAFile { path: String }, - #[snafu(display("Unable to load Credentials for store: {store} {source}"))] + #[snafu(display("Unable to load Credentials for store: {store}\nDetails:\n{source:?}"))] UnableToLoadCredentials { store: SourceType, source: DynError }, - #[snafu(display("Failed to load Credentials for store: {store} {source}"))] + #[snafu(display("Failed to load Credentials for store: {store}\nDetails:\n{source:?}"))] UnableToCreateClient { store: SourceType, source: DynError }, #[snafu(display("Unauthorized to access store: {store} for file: {path}\nYou may need to set valid Credentials\n{source}"))] @@ -138,6 +140,10 @@ impl IOClient { SourceType::AzureBlob => { AzureBlobSource::get_client(&self.config.azure).await? 
as Arc } + + SourceType::GCS => { + GCSSource::get_client(&self.config.gcs).await? as Arc + } }; if w_handle.get(source_type).is_none() { @@ -202,6 +208,7 @@ pub enum SourceType { Http, S3, AzureBlob, + GCS, } impl std::fmt::Display for SourceType { @@ -211,6 +218,7 @@ impl std::fmt::Display for SourceType { SourceType::Http => write!(f, "http"), SourceType::S3 => write!(f, "s3"), SourceType::AzureBlob => write!(f, "AzureBlob"), + SourceType::GCS => write!(f, "gcs"), } } } @@ -234,6 +242,7 @@ fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> { "http" | "https" => Ok((SourceType::Http, fixed_input)), "s3" => Ok((SourceType::S3, fixed_input)), "az" | "abfs" => Ok((SourceType::AzureBlob, fixed_input)), + "gcs" | "gs" => Ok((SourceType::GCS, fixed_input)), _ => Err(Error::NotImplementedSource { store: scheme }), } } diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs index 3e4999d827..9843d0ffb9 100644 --- a/src/daft-io/src/python.rs +++ b/src/daft-io/src/python.rs @@ -39,14 +39,30 @@ pub struct AzureConfig { pub config: config::AzureConfig, } +/// Create configurations to be used when accessing Google Cloud Storage +/// +/// Args: +/// project_id: Google Project ID, defaults to reading credentials file or Google Cloud metadata service +/// anonymous: Whether or not to use "anonymous mode", which will access Google Storage without any credentials +/// +/// Example: +/// >>> io_config = IOConfig(gcs=GCSConfig(anonymous=True)) +/// >>> daft.read_parquet("gs://some-path", io_config=io_config) +#[derive(Clone, Default)] +#[pyclass] +pub struct GCSConfig { + pub config: config::GCSConfig, +} + /// Create configurations to be used when accessing storage /// /// Args: -/// s3: Configurations to use when accessing URLs with the `s3://` scheme -/// azure: Configurations to use when accessing URLs with the `az://` or `abfs://` scheme +/// s3: Configuration to use when accessing URLs with the `s3://` scheme +/// azure: Configuration to use when 
accessing URLs with the `az://` or `abfs://` scheme +/// gcs: Configuration to use when accessing URLs with the `gs://` or `gcs://` scheme /// Example: -/// >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx", num_tries=10), azure=AzureConfig(anonymous=True)) -/// >>> daft.read_parquet(["s3://some-path", "az://some-other-path"], io_config=io_config) +/// >>> io_config = IOConfig(s3=S3Config(key_id="xxx", access_key="xxx", num_tries=10), azure=AzureConfig(anonymous=True), gcs=GCSConfig(...)) +/// >>> daft.read_parquet(["s3://some-path", "az://some-other-path", "gs://path3"], io_config=io_config) #[derive(Clone, Default)] #[pyclass] pub struct IOConfig { @@ -56,11 +72,12 @@ pub struct IOConfig { #[pymethods] impl IOConfig { #[new] - pub fn new(s3: Option, azure: Option) -> Self { + pub fn new(s3: Option, azure: Option, gcs: Option) -> Self { IOConfig { config: config::IOConfig { s3: s3.unwrap_or_default().config, azure: azure.unwrap_or_default().config, + gcs: gcs.unwrap_or_default().config, }, } } @@ -69,7 +86,7 @@ impl IOConfig { Ok(format!("{}", self.config)) } - /// Configurations to be used when accessing s3 URLs + /// Configuration to be used when accessing s3 URLs #[getter] pub fn s3(&self) -> PyResult { Ok(S3Config { @@ -77,7 +94,7 @@ impl IOConfig { }) } - /// Configurations to be used when accessing Azure URLs + /// Configuration to be used when accessing Azure URLs #[getter] pub fn azure(&self) -> PyResult { Ok(AzureConfig { @@ -85,6 +102,14 @@ impl IOConfig { }) } + /// Configuration to be used when accessing Google Cloud Storage URLs + #[getter] + pub fn gcs(&self) -> PyResult { + Ok(GCSConfig { + config: self.config.gcs.clone(), + }) + } + #[staticmethod] pub fn from_json(input: &str) -> PyResult { let config: config::IOConfig = serde_json::from_str(input).map_err(DaftError::from)?; @@ -216,6 +241,31 @@ impl AzureConfig { } } +#[pymethods] +impl GCSConfig { + #[allow(clippy::too_many_arguments)] + #[new] + pub fn new(project_id: Option, anonymous: 
Option) -> Self { + let def = config::GCSConfig::default(); + GCSConfig { + config: config::GCSConfig { + project_id: project_id.or(def.project_id), + anonymous: anonymous.unwrap_or(def.anonymous), + }, + } + } + + pub fn __repr__(&self) -> PyResult { + Ok(format!("{}", self.config)) + } + + /// Project ID to use when accessing Google Cloud Storage + #[getter] + pub fn project_id(&self) -> PyResult> { + Ok(self.config.project_id.clone()) + } +} + impl From for IOConfig { fn from(config: config::IOConfig) -> Self { Self { config } @@ -224,6 +274,7 @@ impl From for IOConfig { pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { parent.add_class::()?; + parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index ed3da3b361..3208f70ad4 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -97,6 +97,16 @@ impl From for super::Error { source: err.into(), }, }, + UnableToHeadFile { path, source } => match source.into_service_error() { + HeadObjectError::NotFound(no_such_key) => super::Error::NotFound { + path, + source: no_such_key.into(), + }, + err => super::Error::UnableToOpenFile { + path, + source: err.into(), + }, + }, InvalidUrl { path, source } => super::Error::InvalidUrl { path, source }, UnableToReadBytes { path, source } => super::Error::UnableToReadBytes { path, diff --git a/tests/integration/io/parquet/test_reads_local_fixtures.py b/tests/integration/io/parquet/test_reads_local_fixtures.py index 15ef9d6f82..cab98d5f41 100644 --- a/tests/integration/io/parquet/test_reads_local_fixtures.py +++ b/tests/integration/io/parquet/test_reads_local_fixtures.py @@ -14,8 +14,7 @@ @pytest.mark.parametrize("bucket", BUCKETS) def test_non_retryable_errors(retry_server_s3_config, status_code: int, bucket: str): data_path = f"s3://{bucket}/{status_code}/1/{uuid.uuid4()}" - - with pytest.raises(ValueError): + with pytest.raises((FileNotFoundError, 
ValueError)): Table.read_parquet(data_path, io_config=retry_server_s3_config) diff --git a/tests/integration/io/parquet/test_reads_public_data.py b/tests/integration/io/parquet/test_reads_public_data.py index 9c95fbeebb..c444a43f73 100644 --- a/tests/integration/io/parquet/test_reads_public_data.py +++ b/tests/integration/io/parquet/test_reads_public_data.py @@ -161,6 +161,10 @@ "azure/mvp", "az://public-anonymous/mvp.parquet", ), + ( + "gcs/mvp", + "gs://daft-public-data-gs/mvp.parquet", + ), ] @@ -169,6 +173,7 @@ def public_storage_io_config() -> daft.io.IOConfig: return daft.io.IOConfig( azure=daft.io.AzureConfig(storage_account="dafttestdata", anonymous=True), s3=daft.io.S3Config(region_name="us-west-2", anonymous=True), + gcs=daft.io.GCSConfig(anonymous=True), )