Skip to content

Commit

Permalink
[FEAT] Ranged Get Native Downloader (#1113)
Browse files Browse the repository at this point in the history
* Implements Range gets for Native Downloader
  • Loading branch information
samster25 committed Jul 1, 2023
1 parent 15b99a1 commit 5eb5318
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 25 deletions.
22 changes: 22 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions src/daft-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ default-features = false
features = ["stream", "rustls-tls"]
version = "0.11.18"

[dev-dependencies]
md5 = "0.7.0"
tempfile = "3.6.0"

[features]
default = ["python"]
python = ["dep:pyo3", "dep:pyo3-log", "common-error/python", "daft-core/python"]
Expand Down
68 changes: 63 additions & 5 deletions src/daft-io/src/http.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::sync::Arc;
use std::{ops::Range, sync::Arc};

use async_trait::async_trait;
use futures::{StreamExt, TryStreamExt};
use lazy_static::lazy_static;
use reqwest::header::RANGE;
use snafu::{IntoError, ResultExt, Snafu};

use super::object_io::{GetResult, ObjectSource};
Expand Down Expand Up @@ -75,10 +76,17 @@ impl HttpSource {

#[async_trait]
impl ObjectSource for HttpSource {
async fn get(&self, uri: &str) -> super::Result<GetResult> {
let response = self
.client
.get(uri)
async fn get(&self, uri: &str, range: Option<Range<usize>>) -> super::Result<GetResult> {
let request = self.client.get(uri);
let request = match range {
None => request,
Some(range) => request.header(
RANGE,
format!("bytes={}-{}", range.start, range.end.saturating_sub(1)),
),
};

let response = request
.send()
.await
.context(UnableToConnectSnafu::<String> { path: uri.into() })?;
Expand All @@ -98,3 +106,53 @@ impl ObjectSource for HttpSource {
Ok(GetResult::Stream(stream.boxed(), size_bytes))
}
}
#[cfg(test)]
mod tests {

use crate::object_io::ObjectSource;
use crate::HttpSource;
use crate::Result;
use tokio;

#[tokio::test]
async fn test_full_get_from_http() -> Result<()> {
let parquet_file_path = "https://daft-public-data.s3.us-west-2.amazonaws.com/test_fixtures/parquet_small/0dad4c3f-da0d-49db-90d8-98684571391b-0.parquet";
let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3";

let client = HttpSource::get_client().await?;
let parquet_file = client.get(parquet_file_path, None).await?;
let bytes = parquet_file.bytes().await?;
let all_bytes = bytes.as_ref();
let checksum = format!("{:x}", md5::compute(all_bytes));
assert_eq!(checksum, parquet_expected_md5);

let first_bytes = client
.get_range(parquet_file_path, 0..10)
.await?
.bytes()
.await?;
assert_eq!(first_bytes.len(), 10);
assert_eq!(first_bytes.as_ref(), &all_bytes[..10]);

let first_bytes = client
.get_range(parquet_file_path, 10..100)
.await?
.bytes()
.await?;
assert_eq!(first_bytes.len(), 90);
assert_eq!(first_bytes.as_ref(), &all_bytes[10..100]);

let last_bytes = client
.get_range(
parquet_file_path,
(all_bytes.len() - 10)..(all_bytes.len() + 10),
)
.await?
.bytes()
.await?;
assert_eq!(last_bytes.len(), 10);
assert_eq!(last_bytes.as_ref(), &all_bytes[(all_bytes.len() - 10)..]);

Ok(())
}
}
12 changes: 8 additions & 4 deletions src/daft-io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use config::IOConfig;
#[cfg(feature = "python")]
pub use python::register_modules;

use std::{borrow::Cow, hash::Hash, sync::Arc};
use std::{borrow::Cow, hash::Hash, ops::Range, sync::Arc};

use futures::{StreamExt, TryStreamExt};

Expand Down Expand Up @@ -138,10 +138,14 @@ fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> {
}
}

async fn single_url_get(input: String, config: &IOConfig) -> Result<GetResult> {
async fn single_url_get(
input: String,
range: Option<Range<usize>>,
config: &IOConfig,
) -> Result<GetResult> {
let (scheme, path) = parse_url(&input)?;
let source = get_source(scheme, config).await?;
source.get(path.as_ref()).await
source.get(path.as_ref(), range).await
}

async fn single_url_download(
Expand All @@ -151,7 +155,7 @@ async fn single_url_download(
config: Arc<IOConfig>,
) -> Result<Option<bytes::Bytes>> {
let value = if let Some(input) = input {
let response = single_url_get(input, config.as_ref()).await;
let response = single_url_get(input, None, config.as_ref()).await;
let res = match response {
Ok(res) => res.bytes().await,
Err(err) => Err(err),
Expand Down
123 changes: 114 additions & 9 deletions src/daft-io/src/local.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::io::{Seek, SeekFrom};
use std::ops::Range;
use std::path::PathBuf;

use super::object_io::{GetResult, ObjectSource};
Expand All @@ -6,7 +8,7 @@ use async_trait::async_trait;
use bytes::Bytes;
use snafu::{ResultExt, Snafu};
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::{AsyncReadExt, AsyncSeek, AsyncSeekExt};
use url::ParseError;
pub(crate) struct LocalSource {}

Expand All @@ -23,6 +25,12 @@ enum Error {
path: String,
source: std::io::Error,
},
#[snafu(display("Unable to seek in file {}: {}", path, source))]
UnableToSeek {
path: String,
source: std::io::Error,
},

#[snafu(display("Unable to parse URL \"{}\"", url.to_string_lossy()))]
InvalidUrl { url: PathBuf, source: ParseError },

Expand Down Expand Up @@ -66,13 +74,21 @@ impl LocalSource {
}
}

pub struct LocalFile {
path: PathBuf,
range: Option<Range<usize>>,
}

#[async_trait]
impl ObjectSource for LocalSource {
async fn get(&self, uri: &str) -> super::Result<GetResult> {
async fn get(&self, uri: &str, range: Option<Range<usize>>) -> super::Result<GetResult> {
const TO_STRIP: &str = "file://";
if let Some(p) = uri.strip_prefix(TO_STRIP) {
let path = std::path::Path::new(p);
Ok(GetResult::File(path.to_path_buf()))
Ok(GetResult::File(LocalFile {
path: path.to_path_buf(),
range,
}))
} else {
return Err(Error::InvalidFilePath {
path: uri.to_string(),
Expand All @@ -82,14 +98,103 @@ impl ObjectSource for LocalSource {
}
}

pub(crate) async fn collect_file(path: &str) -> Result<Bytes> {
pub(crate) async fn collect_file(local_file: LocalFile) -> Result<Bytes> {
let path = &local_file.path;
let mut file = tokio::fs::File::open(path)
.await
.context(UnableToOpenFileSnafu { path })?;
.context(UnableToOpenFileSnafu {
path: path.to_string_lossy(),
})?;

let mut buf = vec![];
let _ = file
.read_to_end(&mut buf)
.await
.context(UnableToReadBytesSnafu::<String> { path: path.into() })?;

match local_file.range {
None => {
let _ = file
.read_to_end(&mut buf)
.await
.context(UnableToReadBytesSnafu {
path: path.to_string_lossy(),
})?;
}
Some(range) => {
let length = range.end - range.start;
file.seek(SeekFrom::Start(range.start as u64))
.await
.context(UnableToSeekSnafu {
path: path.to_string_lossy(),
})?;
buf.reserve(length);
file.take(length as u64)
.read_to_end(&mut buf)
.await
.context(UnableToReadBytesSnafu {
path: path.to_string_lossy(),
})?;
}
}
Ok(Bytes::from(buf))
}

#[cfg(test)]

mod tests {

use std::io::Write;

use crate::object_io::ObjectSource;
use crate::Result;
use crate::{HttpSource, LocalSource};
use tokio;
#[tokio::test]
async fn test_full_get_from_local() -> Result<()> {
let mut file1 = tempfile::NamedTempFile::new().unwrap();
let parquet_file_path = "https://daft-public-data.s3.us-west-2.amazonaws.com/test_fixtures/parquet_small/0dad4c3f-da0d-49db-90d8-98684571391b-0.parquet";
let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3";

let client = HttpSource::get_client().await?;
let parquet_file = client.get(parquet_file_path, None).await?;
let bytes = parquet_file.bytes().await?;
let all_bytes = bytes.as_ref();
let checksum = format!("{:x}", md5::compute(all_bytes));
assert_eq!(checksum, parquet_expected_md5);
file1.write_all(all_bytes).unwrap();
file1.flush().unwrap();

let parquet_file_path = format!("file://{}", file1.path().to_str().unwrap());
let client = LocalSource::get_client().await?;

let try_all_bytes = client.get(&parquet_file_path, None).await?.bytes().await?;
assert_eq!(try_all_bytes.len(), all_bytes.len());
assert_eq!(try_all_bytes.as_ref(), all_bytes);

let first_bytes = client
.get_range(&parquet_file_path, 0..10)
.await?
.bytes()
.await?;
assert_eq!(first_bytes.len(), 10);
assert_eq!(first_bytes.as_ref(), &all_bytes[..10]);

let first_bytes = client
.get_range(&parquet_file_path, 10..100)
.await?
.bytes()
.await?;
assert_eq!(first_bytes.len(), 90);
assert_eq!(first_bytes.as_ref(), &all_bytes[10..100]);

let last_bytes = client
.get_range(
&parquet_file_path,
(all_bytes.len() - 10)..(all_bytes.len() + 10),
)
.await?
.bytes()
.await?;
assert_eq!(last_bytes.len(), 10);
assert_eq!(last_bytes.as_ref(), &all_bytes[(all_bytes.len() - 10)..]);

Ok(())
}
}
12 changes: 8 additions & 4 deletions src/daft-io/src/object_io.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use std::ops::Range;
use std::path::PathBuf;

use async_trait::async_trait;
use bytes::Bytes;
use futures::stream::{BoxStream, Stream};
use futures::StreamExt;

use crate::local::collect_file;
use crate::local::{collect_file, LocalFile};

pub(crate) enum GetResult {
File(PathBuf),
File(LocalFile),
Stream(BoxStream<'static, super::Result<Bytes>>, Option<usize>),
}

Expand Down Expand Up @@ -39,13 +40,16 @@ impl GetResult {
pub async fn bytes(self) -> super::Result<Bytes> {
use GetResult::*;
match self {
File(path) => collect_file(path.to_str().unwrap()).await,
File(f) => collect_file(f).await,
Stream(stream, size) => collect_bytes(stream, size).await,
}
}
}

#[async_trait]
pub(crate) trait ObjectSource: Sync + Send {
async fn get(&self, uri: &str) -> super::Result<GetResult>;
async fn get(&self, uri: &str, range: Option<Range<usize>>) -> super::Result<GetResult>;
async fn get_range(&self, uri: &str, range: Range<usize>) -> super::Result<GetResult> {
self.get(uri, Some(range)).await
}
}
Loading

0 comments on commit 5eb5318

Please sign in to comment.