Skip to content

Commit

Permalink
[FEAT] Add a IOConfig.http with initial option for user_agent (#2449)
Browse files Browse the repository at this point in the history
Adds a user-agent header for making API requests via HTTP

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Jun 29, 2024
1 parent cfc6505 commit 8a0aefa
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 7 deletions.
10 changes: 10 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,13 @@ class FileInfos:

def __len__(self) -> int: ...

class HTTPConfig:
"""
I/O configuration for accessing HTTP systems
"""

user_agent: str | None

class S3Config:
"""
I/O configuration for accessing an S3-compatible system.
Expand Down Expand Up @@ -599,12 +606,14 @@ class IOConfig:
s3: S3Config
azure: AzureConfig
gcs: GCSConfig
http: HTTPConfig

def __init__(
self,
s3: S3Config | None = None,
azure: AzureConfig | None = None,
gcs: GCSConfig | None = None,
http: HTTPConfig | None = None,
): ...
@staticmethod
def from_json(input: str) -> IOConfig:
Expand All @@ -618,6 +627,7 @@ class IOConfig:
s3: S3Config | None = None,
azure: AzureConfig | None = None,
gcs: GCSConfig | None = None,
http: HTTPConfig | None = None,
) -> IOConfig:
"""Replaces values if provided, returning a new IOConfig"""
...
Expand Down
2 changes: 2 additions & 0 deletions daft/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
AzureConfig,
GCSConfig,
IOConfig,
HTTPConfig,
S3Config,
S3Credentials,
set_io_pool_num_threads,
Expand Down Expand Up @@ -52,6 +53,7 @@ def _set_linux_cert_paths():
"S3Credentials",
"AzureConfig",
"GCSConfig",
"HTTPConfig",
"set_io_pool_num_threads",
"DataCatalogType",
"DataCatalogTable",
Expand Down
9 changes: 8 additions & 1 deletion src/common/io-config/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ use std::fmt::Formatter;
use serde::Deserialize;
use serde::Serialize;

use crate::HTTPConfig;
use crate::{AzureConfig, GCSConfig, S3Config};
#[derive(Clone, Default, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct IOConfig {
pub s3: S3Config,
pub azure: AzureConfig,
pub gcs: GCSConfig,
pub http: HTTPConfig,
}

impl IOConfig {
Expand All @@ -27,6 +29,10 @@ impl IOConfig {
"GCS config = {{ {} }}",
self.gcs.multiline_display().join(", ")
));
res.push(format!(
"HTTP config = {{ {} }}",
self.http.multiline_display().join(", ")
));
res
}
}
Expand All @@ -38,8 +44,9 @@ impl Display for IOConfig {
"IOConfig:
{}
{}
{}
{}",
self.s3, self.azure, self.gcs
self.s3, self.azure, self.gcs, self.http,
)
}
}
35 changes: 35 additions & 0 deletions src/common/io-config/src/http.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use std::fmt::Display;
use std::fmt::Formatter;

use serde::Deserialize;
use serde::Serialize;

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct HTTPConfig {
pub user_agent: String,
}

impl Default for HTTPConfig {
fn default() -> Self {
HTTPConfig {
user_agent: "daft/0.0.1".to_string(), // NOTE: Ideally we grab the version of Daft, but that requires a dependency on daft-core
}
}
}

impl HTTPConfig {
pub fn multiline_display(&self) -> Vec<String> {
vec![format!("user_agent = {}", self.user_agent)]
}
}

impl Display for HTTPConfig {
fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
write!(
f,
"HTTPConfig
user_agent: {}",
self.user_agent,
)
}
}
4 changes: 3 additions & 1 deletion src/common/io-config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ pub mod python;
mod azure;
mod config;
mod gcs;
mod http;
mod s3;

pub use crate::{
azure::AzureConfig, config::IOConfig, gcs::GCSConfig, s3::S3Config, s3::S3Credentials,
azure::AzureConfig, config::IOConfig, gcs::GCSConfig, http::HTTPConfig, s3::S3Config,
s3::S3Credentials,
};
35 changes: 34 additions & 1 deletion src/common/io-config/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,35 @@ pub struct IOConfig {
pub config: config::IOConfig,
}

/// Create configurations to be used when accessing HTTP URLs.
///
/// Args:
/// user_agent (str, optional): The value for the user-agent header, defaults to "daft/{__version__}" if not provided
///
/// Example:
/// >>> io_config = IOConfig(http=HTTPConfig(user_agent="my_application/0.0.1"))
/// >>> daft.read_parquet("http://some-path", io_config=io_config)
#[derive(Clone, Default)]
#[pyclass]
pub struct HTTPConfig {
pub config: crate::HTTPConfig,
}

#[pymethods]
impl IOConfig {
#[new]
pub fn new(s3: Option<S3Config>, azure: Option<AzureConfig>, gcs: Option<GCSConfig>) -> Self {
pub fn new(
s3: Option<S3Config>,
azure: Option<AzureConfig>,
gcs: Option<GCSConfig>,
http: Option<HTTPConfig>,
) -> Self {
IOConfig {
config: config::IOConfig {
s3: s3.unwrap_or_default().config,
azure: azure.unwrap_or_default().config,
gcs: gcs.unwrap_or_default().config,
http: http.unwrap_or_default().config,
},
}
}
Expand All @@ -145,6 +165,7 @@ impl IOConfig {
s3: Option<S3Config>,
azure: Option<AzureConfig>,
gcs: Option<GCSConfig>,
http: Option<HTTPConfig>,
) -> Self {
IOConfig {
config: config::IOConfig {
Expand All @@ -153,6 +174,9 @@ impl IOConfig {
.map(|azure| azure.config)
.unwrap_or(self.config.azure.clone()),
gcs: gcs.map(|gcs| gcs.config).unwrap_or(self.config.gcs.clone()),
http: http
.map(|http| http.config)
.unwrap_or(self.config.http.clone()),
},
}
}
Expand Down Expand Up @@ -185,6 +209,14 @@ impl IOConfig {
})
}

/// Configuration to be used when accessing Azure URLs
#[getter]
pub fn http(&self) -> PyResult<HTTPConfig> {
Ok(HTTPConfig {
config: self.config.http.clone(),
})
}

#[staticmethod]
pub fn from_json(input: &str) -> PyResult<Self> {
let config: config::IOConfig = serde_json::from_str(input).map_err(DaftError::from)?;
Expand Down Expand Up @@ -826,6 +858,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
parent.add_class::<AzureConfig>()?;
parent.add_class::<GCSConfig>()?;
parent.add_class::<S3Config>()?;
parent.add_class::<HTTPConfig>()?;
parent.add_class::<S3Credentials>()?;
parent.add_class::<IOConfig>()?;
Ok(())
Expand Down
19 changes: 17 additions & 2 deletions src/daft-io/src/http.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::{num::ParseIntError, ops::Range, string::FromUtf8Error, sync::Arc};

use async_trait::async_trait;
use common_io_config::HTTPConfig;
use futures::{stream::BoxStream, TryStreamExt};

use hyper::header;
use lazy_static::lazy_static;
use regex::Regex;
use reqwest::header::{CONTENT_LENGTH, RANGE};
Expand Down Expand Up @@ -74,6 +76,9 @@ enum Error {
"Unable to parse data as Integer while reading header for file: {path}. {source}"
))]
UnableToParseInteger { path: String, source: ParseIntError },

#[snafu(display("Unable to create HTTP header: {source}"))]
UnableToCreateHeader { source: header::InvalidHeaderValue },
}

/// Finds and retrieves FileMetadata from HTML text
Expand Down Expand Up @@ -162,10 +167,18 @@ impl From<Error> for super::Error {
}

impl HttpSource {
pub async fn get_client() -> super::Result<Arc<Self>> {
pub async fn get_client(config: &HTTPConfig) -> super::Result<Arc<Self>> {
let mut default_headers = header::HeaderMap::new();
default_headers.append(
"user-agent",
header::HeaderValue::from_str(config.user_agent.as_str())
.context(UnableToCreateHeaderSnafu)?,
);

Ok(HttpSource {
client: reqwest::ClientBuilder::default()
.pool_max_idle_per_host(70)
.default_headers(default_headers)
.build()
.context(UnableToCreateClientSnafu)?,
}
Expand Down Expand Up @@ -327,6 +340,8 @@ impl ObjectSource for HttpSource {
#[cfg(test)]
mod tests {

use std::default;

use crate::object_io::ObjectSource;
use crate::HttpSource;
use crate::Result;
Expand All @@ -336,7 +351,7 @@ mod tests {
let parquet_file_path = "https://daft-public-data.s3.us-west-2.amazonaws.com/test_fixtures/parquet_small/0dad4c3f-da0d-49db-90d8-98684571391b-0.parquet";
let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3";

let client = HttpSource::get_client().await?;
let client = HttpSource::get_client(&default::Default::default()).await?;
let parquet_file = client.get(parquet_file_path, None, None).await?;
let bytes = parquet_file.bytes().await?;
let all_bytes = bytes.as_ref();
Expand Down
4 changes: 3 additions & 1 deletion src/daft-io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ impl IOClient {

let new_source = match source_type {
SourceType::File => LocalSource::get_client().await? as Arc<dyn ObjectSource>,
SourceType::Http => HttpSource::get_client().await? as Arc<dyn ObjectSource>,
SourceType::Http => {
HttpSource::get_client(&self.config.http).await? as Arc<dyn ObjectSource>
}
SourceType::S3 => {
S3LikeSource::get_client(&self.config.s3).await? as Arc<dyn ObjectSource>
}
Expand Down
5 changes: 4 additions & 1 deletion src/daft-io/src/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,11 @@ pub(crate) async fn collect_file(local_file: LocalFile) -> Result<Bytes> {
#[cfg(test)]

mod tests {
use std::default;
use std::io::Write;

use common_io_config::HTTPConfig;

use crate::object_io::{FileMetadata, FileType, ObjectSource};
use crate::Result;
use crate::{HttpSource, LocalSource};
Expand All @@ -344,7 +347,7 @@ mod tests {
let parquet_file_path = "https://daft-public-data.s3.us-west-2.amazonaws.com/test_fixtures/parquet_small/0dad4c3f-da0d-49db-90d8-98684571391b-0.parquet";
let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3";

let client = HttpSource::get_client().await?;
let client = HttpSource::get_client(&default::Default::default()).await?;
let parquet_file = client.get(parquet_file_path, None, None).await?;
let bytes = parquet_file.bytes().await?;
let all_bytes = bytes.as_ref();
Expand Down

0 comments on commit 8a0aefa

Please sign in to comment.