diff --git a/daft/daft.pyi b/daft/daft.pyi index f495af1e56..63acd596a2 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -427,6 +427,13 @@ class FileInfos: def __len__(self) -> int: ... +class HTTPConfig: + """ + I/O configuration for accessing HTTP systems + """ + + user_agent: str | None + class S3Config: """ I/O configuration for accessing an S3-compatible system. @@ -599,12 +606,14 @@ class IOConfig: s3: S3Config azure: AzureConfig gcs: GCSConfig + http: HTTPConfig def __init__( self, s3: S3Config | None = None, azure: AzureConfig | None = None, gcs: GCSConfig | None = None, + http: HTTPConfig | None = None, ): ... @staticmethod def from_json(input: str) -> IOConfig: @@ -618,6 +627,7 @@ class IOConfig: s3: S3Config | None = None, azure: AzureConfig | None = None, gcs: GCSConfig | None = None, + http: HTTPConfig | None = None, ) -> IOConfig: """Replaces values if provided, returning a new IOConfig""" ... diff --git a/daft/io/__init__.py b/daft/io/__init__.py index 43ab01aac6..a7d1f1a868 100644 --- a/daft/io/__init__.py +++ b/daft/io/__init__.py @@ -6,6 +6,7 @@ AzureConfig, GCSConfig, IOConfig, + HTTPConfig, S3Config, S3Credentials, set_io_pool_num_threads, @@ -52,6 +53,7 @@ def _set_linux_cert_paths(): "S3Credentials", "AzureConfig", "GCSConfig", + "HTTPConfig", "set_io_pool_num_threads", "DataCatalogType", "DataCatalogTable", diff --git a/src/common/io-config/src/config.rs b/src/common/io-config/src/config.rs index e2b1d4e142..94f97e1cad 100644 --- a/src/common/io-config/src/config.rs +++ b/src/common/io-config/src/config.rs @@ -4,12 +4,14 @@ use std::fmt::Formatter; use serde::Deserialize; use serde::Serialize; +use crate::HTTPConfig; use crate::{AzureConfig, GCSConfig, S3Config}; #[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] pub struct IOConfig { pub s3: S3Config, pub azure: AzureConfig, pub gcs: GCSConfig, + pub http: HTTPConfig, } impl IOConfig { @@ -27,6 +29,10 @@ impl IOConfig { "GCS config = {{ {} }}", 
self.gcs.multiline_display().join(", ") )); + res.push(format!( + "HTTP config = {{ {} }}", + self.http.multiline_display().join(", ") + )); res } } @@ -38,8 +44,9 @@ impl Display for IOConfig { "IOConfig: {} {} +{} {}", - self.s3, self.azure, self.gcs + self.s3, self.azure, self.gcs, self.http, ) } } diff --git a/src/common/io-config/src/http.rs b/src/common/io-config/src/http.rs new file mode 100644 index 0000000000..c619c0b7d8 --- /dev/null +++ b/src/common/io-config/src/http.rs @@ -0,0 +1,35 @@ +use std::fmt::Display; +use std::fmt::Formatter; + +use serde::Deserialize; +use serde::Serialize; + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct HTTPConfig { + pub user_agent: String, +} + +impl Default for HTTPConfig { + fn default() -> Self { + HTTPConfig { + user_agent: "daft/0.0.1".to_string(), // NOTE: Ideally we grab the version of Daft, but that requires a dependency on daft-core + } + } +} + +impl HTTPConfig { + pub fn multiline_display(&self) -> Vec { + vec![format!("user_agent = {}", self.user_agent)] + } +} + +impl Display for HTTPConfig { + fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + write!( + f, + "HTTPConfig + user_agent: {}", + self.user_agent, + ) + } +} diff --git a/src/common/io-config/src/lib.rs b/src/common/io-config/src/lib.rs index d0952541b4..58c4f23978 100644 --- a/src/common/io-config/src/lib.rs +++ b/src/common/io-config/src/lib.rs @@ -4,8 +4,10 @@ pub mod python; mod azure; mod config; mod gcs; +mod http; mod s3; pub use crate::{ - azure::AzureConfig, config::IOConfig, gcs::GCSConfig, s3::S3Config, s3::S3Credentials, + azure::AzureConfig, config::IOConfig, gcs::GCSConfig, http::HTTPConfig, s3::S3Config, + s3::S3Credentials, }; diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs index 6817ce26ca..dce9459875 100644 --- a/src/common/io-config/src/python.rs +++ b/src/common/io-config/src/python.rs @@ -127,15 +127,35 @@ pub struct 
IOConfig { pub config: config::IOConfig, } +/// Create configurations to be used when accessing HTTP URLs. +/// +/// Args: +/// user_agent (str, optional): The value for the user-agent header, defaults to "daft/0.0.1" if not provided +/// +/// Example: +/// >>> io_config = IOConfig(http=HTTPConfig(user_agent="my_application/0.0.1")) +/// >>> daft.read_parquet("http://some-path", io_config=io_config) +#[derive(Clone, Default)] +#[pyclass] +pub struct HTTPConfig { + pub config: crate::HTTPConfig, +} + #[pymethods] impl IOConfig { #[new] - pub fn new(s3: Option, azure: Option, gcs: Option) -> Self { + pub fn new( + s3: Option, + azure: Option, + gcs: Option, + http: Option, + ) -> Self { IOConfig { config: config::IOConfig { s3: s3.unwrap_or_default().config, azure: azure.unwrap_or_default().config, gcs: gcs.unwrap_or_default().config, + http: http.unwrap_or_default().config, }, } } @@ -145,6 +165,7 @@ impl IOConfig { s3: Option, azure: Option, gcs: Option, + http: Option, ) -> Self { IOConfig { config: config::IOConfig { @@ -153,6 +174,9 @@ impl IOConfig { .map(|azure| azure.config) .unwrap_or(self.config.azure.clone()), gcs: gcs.map(|gcs| gcs.config).unwrap_or(self.config.gcs.clone()), + http: http + .map(|http| http.config) + .unwrap_or(self.config.http.clone()), }, } } @@ -185,6 +209,14 @@ impl IOConfig { }) } + /// Configuration to be used when accessing HTTP URLs + #[getter] + pub fn http(&self) -> PyResult { + Ok(HTTPConfig { + config: self.config.http.clone(), + }) + } + #[staticmethod] pub fn from_json(input: &str) -> PyResult { let config: config::IOConfig = serde_json::from_str(input).map_err(DaftError::from)?; @@ -826,6 +858,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; + parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; Ok(()) diff --git a/src/daft-io/src/http.rs b/src/daft-io/src/http.rs index f455050ede..2fb2cf643d 100644 ---
a/src/daft-io/src/http.rs +++ b/src/daft-io/src/http.rs @@ -1,8 +1,10 @@ use std::{num::ParseIntError, ops::Range, string::FromUtf8Error, sync::Arc}; use async_trait::async_trait; +use common_io_config::HTTPConfig; use futures::{stream::BoxStream, TryStreamExt}; +use hyper::header; use lazy_static::lazy_static; use regex::Regex; use reqwest::header::{CONTENT_LENGTH, RANGE}; @@ -74,6 +76,9 @@ enum Error { "Unable to parse data as Integer while reading header for file: {path}. {source}" ))] UnableToParseInteger { path: String, source: ParseIntError }, + + #[snafu(display("Unable to create HTTP header: {source}"))] + UnableToCreateHeader { source: header::InvalidHeaderValue }, } /// Finds and retrieves FileMetadata from HTML text @@ -162,10 +167,18 @@ impl From for super::Error { } impl HttpSource { - pub async fn get_client() -> super::Result> { + pub async fn get_client(config: &HTTPConfig) -> super::Result> { + let mut default_headers = header::HeaderMap::new(); + default_headers.append( + "user-agent", + header::HeaderValue::from_str(config.user_agent.as_str()) + .context(UnableToCreateHeaderSnafu)?, + ); + Ok(HttpSource { client: reqwest::ClientBuilder::default() .pool_max_idle_per_host(70) + .default_headers(default_headers) .build() .context(UnableToCreateClientSnafu)?, } @@ -327,6 +340,8 @@ impl ObjectSource for HttpSource { #[cfg(test)] mod tests { + use std::default; + use crate::object_io::ObjectSource; use crate::HttpSource; use crate::Result; @@ -336,7 +351,7 @@ mod tests { let parquet_file_path = "https://daft-public-data.s3.us-west-2.amazonaws.com/test_fixtures/parquet_small/0dad4c3f-da0d-49db-90d8-98684571391b-0.parquet"; let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3"; - let client = HttpSource::get_client().await?; + let client = HttpSource::get_client(&default::Default::default()).await?; let parquet_file = client.get(parquet_file_path, None, None).await?; let bytes = parquet_file.bytes().await?; let all_bytes = bytes.as_ref(); diff 
--git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 087f280888..a6726e278f 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -191,7 +191,9 @@ impl IOClient { let new_source = match source_type { SourceType::File => LocalSource::get_client().await? as Arc, - SourceType::Http => HttpSource::get_client().await? as Arc, + SourceType::Http => { + HttpSource::get_client(&self.config.http).await? as Arc + } SourceType::S3 => { S3LikeSource::get_client(&self.config.s3).await? as Arc } diff --git a/src/daft-io/src/local.rs b/src/daft-io/src/local.rs index 72fd001b0b..58eab5d7be 100644 --- a/src/daft-io/src/local.rs +++ b/src/daft-io/src/local.rs @@ -332,8 +332,11 @@ pub(crate) async fn collect_file(local_file: LocalFile) -> Result { #[cfg(test)] mod tests { + use std::default; use std::io::Write; + use common_io_config::HTTPConfig; + use crate::object_io::{FileMetadata, FileType, ObjectSource}; use crate::Result; use crate::{HttpSource, LocalSource}; @@ -344,7 +347,7 @@ mod tests { let parquet_file_path = "https://daft-public-data.s3.us-west-2.amazonaws.com/test_fixtures/parquet_small/0dad4c3f-da0d-49db-90d8-98684571391b-0.parquet"; let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3"; - let client = HttpSource::get_client().await?; + let client = HttpSource::get_client(&default::Default::default()).await?; let parquet_file = client.get(parquet_file_path, None, None).await?; let bytes = parquet_file.bytes().await?; let all_bytes = bytes.as_ref();