Skip to content

Commit

Permalink
[FEAT] native google cloud reader (#1271)
Browse files Browse the repository at this point in the history
* Adds GCS support to the native IO reader for URLs with `gs://` or
`gcs://` paths.
* io_config can be configured for GCS via
```
    daft.io.IOConfig(
        gcs=daft.io.GCSConfig(
            project_id="...",
            #anonymous = True
        )
    )
```
  • Loading branch information
samster25 committed Aug 15, 2023
1 parent a5c702b commit a3e38c4
Show file tree
Hide file tree
Showing 12 changed files with 654 additions and 16 deletions.
281 changes: 281 additions & 0 deletions Cargo.lock

Large diffs are not rendered by default.

13 changes: 11 additions & 2 deletions daft/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys

from daft.daft import AzureConfig, IOConfig, S3Config
from daft.daft import AzureConfig, GCSConfig, IOConfig, S3Config
from daft.io._csv import read_csv
from daft.io._json import read_json
from daft.io._parquet import read_parquet
Expand All @@ -23,4 +23,13 @@ def _set_linux_cert_paths():
if sys.platform == "linux":
_set_linux_cert_paths()

__all__ = ["read_csv", "read_json", "from_glob_path", "read_parquet", "IOConfig", "S3Config", "AzureConfig"]
# Public API of daft.io: the read entrypoints plus the per-cloud credential
# configuration classes (S3, Azure, and the newly added GCS).
__all__ = [
    "read_csv",
    "read_json",
    "from_glob_path",
    "read_parquet",
    "IOConfig",
    "S3Config",
    "AzureConfig",
    "GCSConfig",
]
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ readme = "README.rst"
requires-python = ">=3.7"

[project.optional-dependencies]
all = ["getdaft[aws, azure, ray, pandas, numpy, viz]"]
all = ["getdaft[aws, azure, gcp, ray, pandas, numpy, viz]"]
aws = ["s3fs"]
azure = ["adlfs"]
gcp = ["gcsfs"]
numpy = ["numpy"]
pandas = ["pandas"]
ray = [
Expand Down
5 changes: 5 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ urllib3<2; python_version < '3.8'
adlfs==2022.2.0; python_version < '3.8'
adlfs==2023.8.0; python_version >= '3.8'


# GCS
gcsfs==2023.1.0; python_version < '3.8'
gcsfs==2023.6.0; python_version >= '3.8'

# Documentation
myst-nb>=0.16.0
Sphinx <= 5
Expand Down
1 change: 1 addition & 0 deletions src/daft-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ bytes = {workspace = true}
common-error = {path = "../common/error", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
futures = {workspace = true}
google-cloud-storage = {version = "0.13.0", default-features = false, features = ["rustls-tls", "auth"]}
lazy_static = {workspace = true}
log = {workspace = true}
pyo3 = {workspace = true, optional = true}
Expand Down
22 changes: 21 additions & 1 deletion src/daft-io/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,29 @@ impl Display for AzureConfig {
}
}

/// Configuration for accessing Google Cloud Storage.
#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct GCSConfig {
    // NOTE(review): not read by the native GCS client in this commit —
    // presumably intended to scope requests to a GCP project; confirm usage.
    pub project_id: Option<String>,
    // When true, credentials are skipped and GCS is accessed anonymously via
    // its S3-compatible endpoint (see GCSSource::get_client in google_cloud.rs).
    pub anonymous: bool,
}

// Human-readable rendering of the GCS configuration, matching the multi-line
// Display style of the other per-cloud configs in this file.
impl Display for GCSConfig {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
        write!(
            f,
            "GCSConfig
    project_id: {:?}
    anonymous: {:?}",
            self.project_id, self.anonymous
        )
    }
}

/// Aggregated IO configuration for all supported object stores.
#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct IOConfig {
    /// Amazon S3 (and S3-compatible endpoint) settings.
    pub s3: S3Config,
    /// Azure storage settings.
    pub azure: AzureConfig,
    /// Google Cloud Storage settings.
    pub gcs: GCSConfig,
}

impl Display for IOConfig {
Expand All @@ -87,8 +106,9 @@ impl Display for IOConfig {
f,
"IOConfig:
{}
{}
{}",
self.s3, self.azure
self.s3, self.azure, self.gcs
)
}
}
247 changes: 247 additions & 0 deletions src/daft-io/src/google_cloud.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
use std::ops::Range;
use std::sync::Arc;

use futures::StreamExt;
use futures::TryStreamExt;
use google_cloud_storage::client::ClientConfig;

use async_trait::async_trait;
use google_cloud_storage::client::Client;
use google_cloud_storage::http::objects::get::GetObjectRequest;
use google_cloud_storage::http::Error as GError;
use snafu::IntoError;
use snafu::ResultExt;
use snafu::Snafu;

use crate::config;
use crate::config::GCSConfig;
use crate::object_io::ObjectSource;
use crate::s3_like;
use crate::GetResult;

/// GCS-specific errors; lowered into the crate-level `super::Error` below.
#[derive(Debug, Snafu)]
enum Error {
    #[snafu(display("Unable to open {}: {}", path, source))]
    UnableToOpenFile { path: String, source: GError },

    #[snafu(display("Unable to read data from {}: {}", path, source))]
    UnableToReadBytes { path: String, source: GError },

    /// The input string could not be parsed as a URL (or had no host/bucket).
    #[snafu(display("Unable to parse URL: \"{}\"", path))]
    InvalidUrl {
        path: String,
        source: url::ParseError,
    },
    /// Native credential loading failed (see `GCSSource::get_client`).
    #[snafu(display("Unable to load Credentials: {}", source))]
    UnableToLoadCredentials {
        source: google_cloud_storage::client::google_cloud_auth::error::Error,
    },

    /// The URL had a bucket but no object key (e.g. `gs://bucket`).
    #[snafu(display("Not a File: \"{}\"", path))]
    NotAFile { path: String },
}

// Lowers the GCS-specific errors into the crate-wide error type, classifying
// HTTP status codes so callers can distinguish missing objects from auth
// failures regardless of which layer of the client reported them.
impl From<Error> for super::Error {
    fn from(error: Error) -> Self {
        use Error::*;
        match error {
            UnableToReadBytes { path, source } | UnableToOpenFile { path, source } => {
                match source {
                    // Errors surfaced by the underlying HTTP client.
                    // 404/410 -> object does not exist; 401 -> bad credentials.
                    GError::HttpClient(err) => match err.status().map(|s| s.as_u16()) {
                        Some(404) | Some(410) => super::Error::NotFound {
                            path,
                            source: err.into(),
                        },
                        Some(401) => super::Error::Unauthorized {
                            store: super::SourceType::GCS,
                            path,
                            source: err.into(),
                        },
                        _ => super::Error::UnableToOpenFile {
                            path,
                            source: err.into(),
                        },
                    },
                    // Structured error responses from the GCS API itself;
                    // classified by the same status-code rules as above.
                    GError::Response(err) => match err.code {
                        404 | 410 => super::Error::NotFound {
                            path,
                            source: err.into(),
                        },
                        401 => super::Error::Unauthorized {
                            store: super::SourceType::GCS,
                            path,
                            source: err.into(),
                        },
                        _ => super::Error::UnableToOpenFile {
                            path,
                            source: err.into(),
                        },
                    },
                    // Failure to obtain an auth token is a credentials problem.
                    GError::TokenSource(err) => super::Error::UnableToLoadCredentials {
                        store: super::SourceType::GCS,
                        source: err,
                    },
                }
            }
            NotAFile { path } => super::Error::NotAFile { path },
            InvalidUrl { path, source } => super::Error::InvalidUrl { path, source },
            UnableToLoadCredentials { source } => super::Error::UnableToLoadCredentials {
                store: super::SourceType::GCS,
                source: source.into(),
            },
        }
    }
}

/// Backend used to talk to GCS: either the native google-cloud-storage client
/// (credentialed path) or the S3-compatible XML endpoint (anonymous fallback,
/// see `GCSSource::get_client`).
enum GCSClientWrapper {
    Native(Client),
    S3Compat(Arc<s3_like::S3LikeSource>),
}

impl GCSClientWrapper {
async fn get(&self, uri: &str, range: Option<Range<usize>>) -> super::Result<GetResult> {
let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?;
let bucket = match parsed.host_str() {
Some(s) => Ok(s),
None => Err(Error::InvalidUrl {
path: uri.into(),
source: url::ParseError::EmptyHost,
}),
}?;
let key = parsed.path();
let key = if let Some(key) = key.strip_prefix('/') {
key
} else {
return Err(Error::NotAFile { path: uri.into() }.into());
};

match self {
GCSClientWrapper::Native(client) => {
let req = GetObjectRequest {
bucket: bucket.into(),
object: key.into(),
..Default::default()
};
use google_cloud_storage::http::objects::download::Range as GRange;
let (grange, size) = if let Some(range) = range {
(
GRange(Some(range.start as u64), Some(range.end as u64)),
Some(range.len()),
)
} else {
(GRange::default(), None)
};
let owned_uri = uri.to_string();
let response = client
.download_streamed_object(&req, &grange)
.await
.context(UnableToOpenFileSnafu {
path: uri.to_string(),
})?;
let response = response.map_err(move |e| {
UnableToReadBytesSnafu::<String> {
path: owned_uri.clone(),
}
.into_error(e)
.into()
});
Ok(GetResult::Stream(response.boxed(), size))
}
GCSClientWrapper::S3Compat(client) => {
let uri = format!("s3://{}/{}", bucket, key);
client.get(&uri, range).await
}
}
}

async fn get_size(&self, uri: &str) -> super::Result<usize> {
let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?;
let bucket = match parsed.host_str() {
Some(s) => Ok(s),
None => Err(Error::InvalidUrl {
path: uri.into(),
source: url::ParseError::EmptyHost,
}),
}?;
let key = parsed.path();
let key = if let Some(key) = key.strip_prefix('/') {
key
} else {
return Err(Error::NotAFile { path: uri.into() }.into());
};
match self {
GCSClientWrapper::Native(client) => {
let req = GetObjectRequest {
bucket: bucket.into(),
object: key.into(),
..Default::default()
};

let response = client
.get_object(&req)
.await
.context(UnableToOpenFileSnafu {
path: uri.to_string(),
})?;
Ok(response.size as usize)
}
GCSClientWrapper::S3Compat(client) => {
let uri = format!("s3://{}/{}", bucket, key);
client.get_size(&uri).await
}
}
}
}

/// `ObjectSource` implementation backing `gs://` / `gcs://` URLs.
pub(crate) struct GCSSource {
    // Native or S3-compatible backend; see `GCSClientWrapper`.
    client: GCSClientWrapper,
}

impl GCSSource {
    /// Builds a source that reaches GCS anonymously through its S3-compatible
    /// XML endpoint (https://storage.googleapis.com).
    async fn build_s3_compat_client() -> super::Result<Arc<Self>> {
        let s3_config = config::S3Config {
            anonymous: true,
            endpoint_url: Some("https://storage.googleapis.com".to_string()),
            ..Default::default()
        };
        let s3_client = s3_like::S3LikeSource::get_client(&s3_config).await?;
        Ok(Arc::new(GCSSource {
            client: GCSClientWrapper::S3Compat(s3_client),
        }))
    }

    /// Returns a GCS source for `config`.
    ///
    /// Anonymous mode goes straight to the S3-compatible endpoint; otherwise
    /// native credentials are loaded, falling back to anonymous access (with a
    /// warning) when none can be found.
    pub async fn get_client(config: &GCSConfig) -> super::Result<Arc<Self>> {
        if config.anonymous {
            return GCSSource::build_s3_compat_client().await;
        }

        let auth_result = ClientConfig::default()
            .with_auth()
            .await
            .context(UnableToLoadCredentialsSnafu {});
        match auth_result {
            Ok(client_config) => Ok(Arc::new(GCSSource {
                client: GCSClientWrapper::Native(Client::new(client_config)),
            })),
            Err(err) => {
                log::warn!("Google Cloud Storage Credentials not provided or found when making client. Reverting to Anonymous mode.\nDetails\n{err}");
                GCSSource::build_s3_compat_client().await
            }
        }
    }
}

// Thin delegation of the ObjectSource trait onto the wrapped client.
#[async_trait]
impl ObjectSource for GCSSource {
    /// Streams the object at `uri`, optionally limited to the byte `range`.
    async fn get(&self, uri: &str, range: Option<Range<usize>>) -> super::Result<GetResult> {
        self.client.get(uri, range).await
    }

    /// Returns the size in bytes of the object at `uri`.
    async fn get_size(&self, uri: &str) -> super::Result<usize> {
        self.client.get_size(uri).await
    }
}
Loading

0 comments on commit a3e38c4

Please sign in to comment.