From 5eb53183f174400b162e5cfa026029eedbb291df Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Sat, 1 Jul 2023 16:11:42 -0700 Subject: [PATCH] [FEAT] Ranged Get Native Downloader (#1113) * Implements Range gets for Native Downloader --- Cargo.lock | 22 +++++++ src/daft-io/Cargo.toml | 4 ++ src/daft-io/src/http.rs | 68 +++++++++++++++++-- src/daft-io/src/lib.rs | 12 ++-- src/daft-io/src/local.rs | 123 ++++++++++++++++++++++++++++++++--- src/daft-io/src/object_io.rs | 12 ++-- src/daft-io/src/s3_like.rs | 68 ++++++++++++++++++- 7 files changed, 284 insertions(+), 25 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9521b3b712..e06f51ae84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -766,11 +766,13 @@ dependencies = [ "futures", "lazy_static", "log", + "md5", "pyo3", "pyo3-log", "reqwest", "serde", "snafu", + "tempfile", "tokio", "url", ] @@ -1465,6 +1467,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.5.0" @@ -2359,6 +2367,20 @@ version = "0.12.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b1c7f239eb94671427157bd93b3694320f3668d4e1eff08c7285366fd777fac" +[[package]] +name = "tempfile" +version = "3.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6" +dependencies = [ + "autocfg", + "cfg-if", + "fastrand", + "redox_syscall 0.3.5", + "rustix", + "windows-sys 0.48.0", +] + [[package]] name = "term" version = "0.7.0" diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml index 56132b1ed0..bb0201eb1c 100644 --- a/src/daft-io/Cargo.toml +++ b/src/daft-io/Cargo.toml @@ -23,6 +23,10 @@ default-features = false features = ["stream", "rustls-tls"] version = "0.11.18" +[dev-dependencies] +md5 = "0.7.0" +tempfile = "3.6.0" + [features] default = ["python"] python = ["dep:pyo3", "dep:pyo3-log", "common-error/python", "daft-core/python"] diff --git a/src/daft-io/src/http.rs b/src/daft-io/src/http.rs index 428da1ce8c..b99070bc15 100644 --- a/src/daft-io/src/http.rs +++ b/src/daft-io/src/http.rs @@ -1,8 +1,9 @@ -use std::sync::Arc; +use std::{ops::Range, sync::Arc}; use async_trait::async_trait; use futures::{StreamExt, TryStreamExt}; use lazy_static::lazy_static; +use reqwest::header::RANGE; use snafu::{IntoError, ResultExt, Snafu}; use super::object_io::{GetResult, ObjectSource}; @@ -75,10 +76,17 @@ impl HttpSource { #[async_trait] impl ObjectSource for HttpSource { - async fn get(&self, uri: &str) -> super::Result { - let response = self - .client - .get(uri) + async fn get(&self, uri: &str, range: Option>) -> super::Result { + let request = self.client.get(uri); + let request = match range { + None => request, + Some(range) => request.header( + RANGE, + format!("bytes={}-{}", range.start, range.end.saturating_sub(1)), + ), + }; + + let response = request .send() .await .context(UnableToConnectSnafu:: { path: uri.into() })?; @@ -98,3 +106,53 @@ impl ObjectSource for HttpSource { Ok(GetResult::Stream(stream.boxed(), size_bytes)) } } +#[cfg(test)] +mod tests { + + use crate::object_io::ObjectSource; + use crate::HttpSource; + use crate::Result; + use tokio; + + #[tokio::test] + async fn test_full_get_from_http() -> Result<()> { + let parquet_file_path = "https://daft-public-data.s3.us-west-2.amazonaws.com/test_fixtures/parquet_small/0dad4c3f-da0d-49db-90d8-98684571391b-0.parquet"; + let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3"; + + let client = HttpSource::get_client().await?; + let parquet_file = client.get(parquet_file_path, None).await?; + let bytes = parquet_file.bytes().await?; + let all_bytes = bytes.as_ref(); + let checksum = format!("{:x}", md5::compute(all_bytes)); + assert_eq!(checksum, parquet_expected_md5); + + let first_bytes = client + .get_range(parquet_file_path, 0..10) + .await? + .bytes() + .await?; + assert_eq!(first_bytes.len(), 10); + assert_eq!(first_bytes.as_ref(), &all_bytes[..10]); + + let first_bytes = client + .get_range(parquet_file_path, 10..100) + .await? + .bytes() + .await?; + assert_eq!(first_bytes.len(), 90); + assert_eq!(first_bytes.as_ref(), &all_bytes[10..100]); + + let last_bytes = client + .get_range( + parquet_file_path, + (all_bytes.len() - 10)..(all_bytes.len() + 10), + ) + .await? + .bytes() + .await?; + assert_eq!(last_bytes.len(), 10); + assert_eq!(last_bytes.as_ref(), &all_bytes[(all_bytes.len() - 10)..]); + + Ok(()) + } +} diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index ea48e005f6..3a9ebcae15 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -13,7 +13,7 @@ use config::IOConfig; #[cfg(feature = "python")] pub use python::register_modules; -use std::{borrow::Cow, hash::Hash, sync::Arc}; +use std::{borrow::Cow, hash::Hash, ops::Range, sync::Arc}; use futures::{StreamExt, TryStreamExt}; @@ -138,10 +138,14 @@ fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> { } } -async fn single_url_get(input: String, config: &IOConfig) -> Result { +async fn single_url_get( + input: String, + range: Option>, + config: &IOConfig, +) -> Result { let (scheme, path) = parse_url(&input)?; let source = get_source(scheme, config).await?; - source.get(path.as_ref()).await + source.get(path.as_ref(), range).await } async fn single_url_download( @@ -151,7 +155,7 @@ async fn single_url_download( config: Arc, ) -> Result> { let value = if let Some(input) = input { - let response = single_url_get(input, config.as_ref()).await; + let response = single_url_get(input, None, config.as_ref()).await; let res = match response { Ok(res) => res.bytes().await, Err(err) => Err(err), diff --git a/src/daft-io/src/local.rs b/src/daft-io/src/local.rs index 61e749b031..9e5355cd95 100644 --- a/src/daft-io/src/local.rs +++ b/src/daft-io/src/local.rs @@ -1,3 +1,5 @@ +use std::io::{Seek, SeekFrom}; +use std::ops::Range; use std::path::PathBuf; use super::object_io::{GetResult, ObjectSource}; @@ -6,7 +8,7 @@ use async_trait::async_trait; use bytes::Bytes; use snafu::{ResultExt, Snafu}; use std::sync::Arc; -use tokio::io::AsyncReadExt; +use tokio::io::{AsyncReadExt, AsyncSeek, AsyncSeekExt}; use url::ParseError; pub(crate) struct LocalSource {} @@ -23,6 +25,12 @@ enum Error { path: String, source: std::io::Error, }, + #[snafu(display("Unable to seek in file {}: {}", path, source))] + UnableToSeek { + path: String, + source: std::io::Error, + }, + #[snafu(display("Unable to parse URL \"{}\"", url.to_string_lossy()))] InvalidUrl { url: PathBuf, source: ParseError }, @@ -66,13 +74,21 @@ impl LocalSource { } } +pub struct LocalFile { + path: PathBuf, + range: Option>, +} + #[async_trait] impl ObjectSource for LocalSource { - async fn get(&self, uri: &str) -> super::Result { + async fn get(&self, uri: &str, range: Option>) -> super::Result { const TO_STRIP: &str = "file://"; if let Some(p) = uri.strip_prefix(TO_STRIP) { let path = std::path::Path::new(p); - Ok(GetResult::File(path.to_path_buf())) + Ok(GetResult::File(LocalFile { + path: path.to_path_buf(), + range, + })) } else { return Err(Error::InvalidFilePath { path: uri.to_string(), @@ -82,14 +98,103 @@ impl ObjectSource for LocalSource { } } -pub(crate) async fn collect_file(path: &str) -> Result { +pub(crate) async fn collect_file(local_file: LocalFile) -> Result { + let path = &local_file.path; let mut file = tokio::fs::File::open(path) .await - .context(UnableToOpenFileSnafu { path })?; + .context(UnableToOpenFileSnafu { + path: path.to_string_lossy(), + })?; + let mut buf = vec![]; - let _ = file - .read_to_end(&mut buf) - .await - .context(UnableToReadBytesSnafu:: { path: path.into() })?; + + match local_file.range { + None => { + let _ = file + .read_to_end(&mut buf) + .await + .context(UnableToReadBytesSnafu { + path: path.to_string_lossy(), + })?; + } + Some(range) => { + let length = range.end - range.start; + file.seek(SeekFrom::Start(range.start as u64)) + .await + .context(UnableToSeekSnafu { + path: path.to_string_lossy(), + })?; + buf.reserve(length); + file.take(length as u64) + .read_to_end(&mut buf) + .await + .context(UnableToReadBytesSnafu { + path: path.to_string_lossy(), + })?; + } + } Ok(Bytes::from(buf)) } + +#[cfg(test)] + +mod tests { + + use std::io::Write; + + use crate::object_io::ObjectSource; + use crate::Result; + use crate::{HttpSource, LocalSource}; + use tokio; + #[tokio::test] + async fn test_full_get_from_local() -> Result<()> { + let mut file1 = tempfile::NamedTempFile::new().unwrap(); + let parquet_file_path = "https://daft-public-data.s3.us-west-2.amazonaws.com/test_fixtures/parquet_small/0dad4c3f-da0d-49db-90d8-98684571391b-0.parquet"; + let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3"; + + let client = HttpSource::get_client().await?; + let parquet_file = client.get(parquet_file_path, None).await?; + let bytes = parquet_file.bytes().await?; + let all_bytes = bytes.as_ref(); + let checksum = format!("{:x}", md5::compute(all_bytes)); + assert_eq!(checksum, parquet_expected_md5); + file1.write_all(all_bytes).unwrap(); + file1.flush().unwrap(); + + let parquet_file_path = format!("file://{}", file1.path().to_str().unwrap()); + let client = LocalSource::get_client().await?; + + let try_all_bytes = client.get(&parquet_file_path, None).await?.bytes().await?; + assert_eq!(try_all_bytes.len(), all_bytes.len()); + assert_eq!(try_all_bytes.as_ref(), all_bytes); + + let first_bytes = client + .get_range(&parquet_file_path, 0..10) + .await? + .bytes() + .await?; + assert_eq!(first_bytes.len(), 10); + assert_eq!(first_bytes.as_ref(), &all_bytes[..10]); + + let first_bytes = client + .get_range(&parquet_file_path, 10..100) + .await? + .bytes() + .await?; + assert_eq!(first_bytes.len(), 90); + assert_eq!(first_bytes.as_ref(), &all_bytes[10..100]); + + let last_bytes = client + .get_range( + &parquet_file_path, + (all_bytes.len() - 10)..(all_bytes.len() + 10), + ) + .await? + .bytes() + .await?; + assert_eq!(last_bytes.len(), 10); + assert_eq!(last_bytes.as_ref(), &all_bytes[(all_bytes.len() - 10)..]); + + Ok(()) + } +} diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs index 1cbb8e5246..d385466cbd 100644 --- a/src/daft-io/src/object_io.rs +++ b/src/daft-io/src/object_io.rs @@ -1,3 +1,4 @@ +use std::ops::Range; use std::path::PathBuf; use async_trait::async_trait; @@ -5,10 +6,10 @@ use bytes::Bytes; use futures::stream::{BoxStream, Stream}; use futures::StreamExt; -use crate::local::collect_file; +use crate::local::{collect_file, LocalFile}; pub(crate) enum GetResult { - File(PathBuf), + File(LocalFile), Stream(BoxStream<'static, super::Result>, Option), } @@ -39,7 +40,7 @@ impl GetResult { pub async fn bytes(self) -> super::Result { use GetResult::*; match self { - File(path) => collect_file(path.to_str().unwrap()).await, + File(f) => collect_file(f).await, Stream(stream, size) => collect_bytes(stream, size).await, } } @@ -47,5 +48,8 @@ impl GetResult { #[async_trait] pub(crate) trait ObjectSource: Sync + Send { - async fn get(&self, uri: &str) -> super::Result; + async fn get(&self, uri: &str, range: Option>) -> super::Result; + async fn get_range(&self, uri: &str, range: Range) -> super::Result { + self.get(uri, Some(range)).await + } } diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index cb43ffc46a..48cf12f9d5 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -7,7 +7,7 @@ use aws_credential_types::cache::ProvideCachedCredentials; use aws_credential_types::provider::error::CredentialsError; use aws_sig_auth::signer::SigningRequirements; use futures::{StreamExt, TryStreamExt}; -use s3::client::customize::{Operation, Response}; +use s3::client::customize::Response; use s3::config::{Credentials, Region}; use s3::error::{ProvideErrorMetadata, SdkError}; use s3::operation::get_object::GetObjectError; @@ -19,6 +19,7 @@ use aws_sdk_s3 as s3; use aws_sdk_s3::primitives::ByteStreamError; use lazy_static::lazy_static; use std::collections::HashMap; +use std::ops::Range; use std::string::FromUtf8Error; use std::sync::{Arc, RwLock}; @@ -201,7 +202,7 @@ impl S3LikeSource { #[async_trait] impl ObjectSource for S3LikeSource { - async fn get(&self, uri: &str) -> super::Result { + async fn get(&self, uri: &str, range: Option>) -> super::Result { let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; let bucket = match parsed.host_str() { Some(s) => Ok(s), @@ -213,6 +214,14 @@ impl ObjectSource for S3LikeSource { let key = parsed.path(); if let Some(key) = key.strip_prefix('/') { let request = self.s3_client.get_object().bucket(bucket).key(key); + let request = match &range { + None => request, + Some(range) => request.range(format!( + "bytes={}-{}", + range.start, + range.end.saturating_sub(1) + )), + }; let response = if self.anonymous { request @@ -281,7 +290,7 @@ impl ObjectSource for S3LikeSource { let new_client = S3LikeSource::get_client(&new_config).await?; log::warn!("Correct S3 Region of {uri} found: {:?}. Attempting GET in that region with new client", new_client.s3_client.conf().region().map_or("", |v| v.as_ref())); - return new_client.get(uri).await; + return new_client.get(uri, range).await; } _ => Err(UnableToOpenFileSnafu { path: uri } .into_error(SdkError::ServiceError(err)) @@ -298,3 +307,56 @@ impl ObjectSource for S3LikeSource { } } } + +#[cfg(test)] +mod tests { + + use crate::object_io::ObjectSource; + use crate::S3LikeSource; + use crate::{config::S3Config, Result}; + use tokio; + + #[tokio::test] + async fn test_full_get_from_s3() -> Result<()> { + let parquet_file_path = "s3://daft-public-data/test_fixtures/parquet_small/0dad4c3f-da0d-49db-90d8-98684571391b-0.parquet"; + let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3"; + + let mut config = S3Config::default(); + config.anonymous = true; + let client = S3LikeSource::get_client(&config).await?; + let parquet_file = client.get(parquet_file_path, None).await?; + let bytes = parquet_file.bytes().await?; + let all_bytes = bytes.as_ref(); + let checksum = format!("{:x}", md5::compute(all_bytes)); + assert_eq!(checksum, parquet_expected_md5); + + let first_bytes = client + .get_range(parquet_file_path, 0..10) + .await? + .bytes() + .await?; + assert_eq!(first_bytes.len(), 10); + assert_eq!(first_bytes.as_ref(), &all_bytes[..10]); + + let first_bytes = client + .get_range(parquet_file_path, 10..100) + .await? + .bytes() + .await?; + assert_eq!(first_bytes.len(), 90); + assert_eq!(first_bytes.as_ref(), &all_bytes[10..100]); + + let last_bytes = client + .get_range( + parquet_file_path, + (all_bytes.len() - 10)..(all_bytes.len() + 10), + ) + .await? + .bytes() + .await?; + assert_eq!(last_bytes.len(), 10); + assert_eq!(last_bytes.as_ref(), &all_bytes[(all_bytes.len() - 10)..]); + + Ok(()) + } +}