diff --git a/Cargo.lock b/Cargo.lock index 60b98d0602..6c3af7a807 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -120,6 +120,19 @@ dependencies = [ "strength_reduce", ] +[[package]] +name = "async-compat" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b48b4ff0c2026db683dea961cd8ea874737f56cffca86fa84415eaddc51c00d" +dependencies = [ + "futures-core", + "futures-io", + "once_cell", + "pin-project-lite", + "tokio", +] + [[package]] name = "async-recursion" version = "1.0.4" @@ -771,6 +784,49 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" +dependencies = [ + "cfg-if", +] + [[package]] name = "crypto-common" version = "0.1.6" @@ -891,6 +947,9 @@ name = "daft-parquet" version = "0.1.0" dependencies = [ "arrow2", + "async-compat", + "async-stream", + "bytes", "common-error", "daft-core", "daft-io", @@ -900,8 +959,11 @@ dependencies = [ "parquet2", "pyo3", "pyo3-log", + "rayon", "snafu", "tokio", + "tokio-stream", + "tokio-util", ] [[package]] @@ -2134,6 +2196,28 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" +[[package]] +name = "rayon" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "num_cpus", +] + [[package]] name = "redox_syscall" version = "0.2.16" diff --git a/Cargo.toml b/Cargo.toml index 3809b891d3..5b76547f9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ members = [ ] [workspace.dependencies] +bytes = "1.4.0" futures = "0.3.28" html-escape = "0.2.13" num-derive = "0.3.3" diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml index e2807f8784..fc393cdde7 100644 --- a/src/daft-io/Cargo.toml +++ b/src/daft-io/Cargo.toml @@ -6,7 +6,7 @@ aws-credential-types = {version = "0.55.3", features = ["hardcoded-credentials"] aws-sdk-s3 = "0.28.0" aws-sig-auth = "0.55.3" aws-sigv4 = "0.55.3" -bytes = "1.4.0" +bytes = {workspace = true} common-error = {path = "../common/error", default-features = false} daft-core = {path = "../daft-core", default-features = false} futures = {workspace = true} diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 577c765c71..145e1f31ca 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -86,6 +86,12 @@ impl From for DaftError { } } +impl From for std::io::Error { + fn from(err: Error) -> std::io::Error { + std::io::Error::new(std::io::ErrorKind::Other, err) + } +} + type Result = std::result::Result; #[derive(Default)] diff --git a/src/daft-parquet/Cargo.toml b/src/daft-parquet/Cargo.toml index 5cc55a34a2..34aad9c353 100644 --- a/src/daft-parquet/Cargo.toml +++ b/src/daft-parquet/Cargo.toml @@ -1,5 +1,8 @@ [dependencies] arrow2 = {workspace = true, features = ["io_parquet", "io_parquet_compression"]} +async-compat = "0.2.1" +async-stream = "0.3.5" +bytes = {workspace = true} common-error = {path = "../common/error", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-io = {path = "../daft-io", default-features = false} @@ -9,8 +12,11 @@ log = {workspace = true} parquet2 = "0.17.2" pyo3 = {workspace = true, optional = true} pyo3-log = {workspace = true, optional = true} +rayon = "1.7.0" snafu = {workspace = true} tokio = {workspace = true} +tokio-stream = "0.1.14" +tokio-util = "0.7.8" [features] default = ["python"] diff --git a/src/daft-parquet/src/file.rs b/src/daft-parquet/src/file.rs index 59db8045ec..7e303214b7 100644 --- a/src/daft-parquet/src/file.rs +++ b/src/daft-parquet/src/file.rs @@ -5,17 +5,21 @@ use common_error::DaftResult; use daft_core::{utils::arrow::cast_array_for_daft_if_needed, Series}; use daft_io::IOClient; use daft_table::Table; -use parquet2::read::{BasicDecompressor, PageReader}; +use futures::{future::try_join_all, StreamExt}; +use parquet2::{ + page::{CompressedPage, Page}, + read::get_page_stream_from_column_start, + FallibleStreamingIterator, +}; use snafu::ResultExt; use crate::{ metadata::read_parquet_metadata, - read_planner::{self, CoalescePass, RangesContainer, ReadPlanner, SplitLargeRequestPass}, - UnableToConvertParquetPagesToArrowSnafu, UnableToOpenFileSnafu, + read_planner::{CoalescePass, RangesContainer, ReadPlanner, SplitLargeRequestPass}, + JoinSnafu, OneShotRecvSnafu, UnableToCreateParquetPageStreamSnafu, UnableToOpenFileSnafu, UnableToParseSchemaFromMetadataSnafu, }; use arrow2::io::parquet::read::column_iter_to_arrays; - pub(crate) struct ParquetReaderBuilder { uri: String, metadata: parquet2::metadata::FileMetaData, @@ -24,6 +28,61 @@ pub(crate) struct ParquetReaderBuilder { row_start_offset: usize, num_rows: usize, } +use parquet2::read::decompress; + +fn streaming_decompression>>( + input: S, +) -> impl futures::Stream> { + async_stream::stream! { + for await compressed_page in input { + let compressed_page = compressed_page?; + let (send, recv) = tokio::sync::oneshot::channel(); + + rayon::spawn(move || { + let mut buffer = vec![]; + let _ = send.send(decompress(compressed_page, &mut buffer)); + + }); + yield recv.await.expect("panic while decompressing page"); + } + } +} +pub struct VecIterator { + index: i64, + src: Vec>, +} + +impl VecIterator { + pub fn new(src: Vec>) -> Self { + VecIterator { index: -1, src } + } +} + +impl FallibleStreamingIterator for VecIterator { + type Error = parquet2::error::Error; + type Item = Page; + fn advance(&mut self) -> Result<(), Self::Error> { + self.index += 1; + if (self.index as usize) < self.src.len() { + if let Err(value) = self.src.get(self.index as usize).unwrap() { + return Err(value.clone()); + } + } + Ok(()) + } + + fn get(&self) -> Option<&Self::Item> { + if self.index < 0 || (self.index as usize) >= self.src.len() { + return None; + } + + if let Ok(val) = self.src.get(self.index as usize).unwrap() { + Some(val) + } else { + None + } + } +} impl ParquetReaderBuilder { pub async fn from_uri(uri: &str, io_client: Arc) -> super::Result { @@ -125,6 +184,7 @@ impl ParquetReaderBuilder { } } +#[derive(Copy, Clone)] struct RowGroupRange { row_group_index: usize, start: usize, @@ -133,9 +193,9 @@ struct RowGroupRange { pub(crate) struct ParquetFileReader { uri: String, - metadata: parquet2::metadata::FileMetaData, + metadata: Arc, arrow_schema: arrow2::datatypes::Schema, - row_ranges: Vec, + row_ranges: Arc>, } impl ParquetFileReader { @@ -147,9 +207,9 @@ impl ParquetFileReader { ) -> super::Result { Ok(ParquetFileReader { uri, - metadata, + metadata: Arc::new(metadata), arrow_schema, - row_ranges, + row_ranges: Arc::new(row_ranges), }) } @@ -185,7 +245,7 @@ impl ParquetFileReader { Ok(read_planner) } - pub async fn prebuffer_ranges(&self, io_client: Arc) -> DaftResult { + pub fn prebuffer_ranges(&self, io_client: Arc) -> DaftResult> { let mut read_planner = self.naive_read_plan()?; // TODO(sammy) these values should be populated by io_client read_planner.add_pass(Box::new(SplitLargeRequestPass { @@ -199,104 +259,154 @@ impl ParquetFileReader { })); read_planner.run_passes()?; - read_planner.collect(io_client).await + read_planner.collect(io_client) } - pub fn read_from_ranges(self, ranges: RangesContainer) -> DaftResult { - let all_series = self + pub async fn read_from_ranges(self, ranges: Arc) -> DaftResult
{ + let metadata = self.metadata; + let all_handles = self .arrow_schema .fields .iter() .map(|field| { - let field_series = self - .row_ranges + let owned_row_ranges = self.row_ranges.clone(); + let field_handles = owned_row_ranges .iter() .map(|row_range| { - let rg = self - .metadata - .row_groups - .get(row_range.row_group_index) - .expect("Row Group index should be in bounds"); - let columns = rg.columns(); - let field_name = &field.name; - let filtered_cols = columns - .iter() - .filter(|x| &x.descriptor().path_in_schema[0] == field_name) - .collect::>(); - - let mut decompressed_iters = Vec::with_capacity(filtered_cols.len()); - let mut ptypes = Vec::with_capacity(filtered_cols.len()); - - for col in filtered_cols { - let (start, len) = col.byte_range(); - let end = start + len; - - // should stream this instead - let range_reader: read_planner::MultiRead<'_> = - ranges.get_range_reader(start as usize..end as usize)?; - let pages = PageReader::new( - range_reader, - col, - Arc::new(|_, _| true), - vec![], - 4 * 1024 * 1024, - ); - - decompressed_iters.push(BasicDecompressor::new(pages, vec![])); - - ptypes.push(&col.descriptor().descriptor.primitive_type); - } - - let arr_iter = column_iter_to_arrays( - decompressed_iters, - ptypes, - field.clone(), - Some(4096), - rg.num_rows().min(row_range.start + row_range.num_rows), - ) - .context( - UnableToConvertParquetPagesToArrowSnafu:: { - path: self.uri.clone(), - }, - )?; - - let mut all_arrays = vec![]; - - let mut curr_index = 0; - - for arr in arr_iter { - let arr = arr?; - - if (curr_index + arr.len()) < row_range.start { - // throw arrays less than what we need - curr_index += arr.len(); - continue; - } else if curr_index < row_range.start { - let offset = row_range.start.saturating_sub(curr_index); - all_arrays.push(arr.sliced(offset, arr.len() - offset)); - curr_index += arr.len(); - } else { - curr_index += arr.len(); - all_arrays.push(arr); + let row_range = *row_range; + let field = field.clone(); + let owned_uri = self.uri.clone(); + let ranges = ranges.clone(); + let owned_metadata = metadata.clone(); + + let handle = tokio::task::spawn(async move { + let rg = owned_metadata + .row_groups + .get(row_range.row_group_index) + .expect("Row Group index should be in bounds"); + let num_rows = rg.num_rows().min(row_range.start + row_range.num_rows); + let columns = rg.columns(); + let field_name = &field.name; + let filtered_cols = columns + .iter() + .filter(|x| &x.descriptor().path_in_schema[0] == field_name) + .collect::>(); + + let mut decompressed_pages = Vec::with_capacity(filtered_cols.len()); + let mut ptypes = Vec::with_capacity(filtered_cols.len()); + + for col in filtered_cols { + let (start, len) = col.byte_range(); + let end = start + len; + + let range_reader = + ranges.get_range_reader(start as usize..end as usize)?; + + let mut pinned = Box::pin(range_reader); + let compressed_page_stream = get_page_stream_from_column_start( + col, + &mut pinned, + vec![], + Arc::new(|_, _| true), + 4 * 1024 * 1024, + ) + .await + .with_context( + |_| UnableToCreateParquetPageStreamSnafu:: { + path: owned_uri.clone(), + }, + )?; + let page_stream = streaming_decompression(compressed_page_stream); + + decompressed_pages.push(page_stream.collect::>().await); + + ptypes.push(col.descriptor().descriptor.primitive_type.clone()); } - } - - all_arrays - .into_iter() - .map(|a| { - Series::try_from(( - field.name.as_str(), - cast_array_for_daft_if_needed(a), - )) - }) - .collect::>>() + + let decompressed_iters = decompressed_pages + .into_iter() + .map(VecIterator::new) + .collect(); + + let (send, recv) = tokio::sync::oneshot::channel(); + rayon::spawn(move || { + let arr_iter = column_iter_to_arrays( + decompressed_iters, + ptypes.iter().collect(), + field.clone(), + Some(2048), + num_rows, + ); + + let ser = (|| { + let mut all_arrays = vec![]; + let mut curr_index = 0; + + for arr in arr_iter? { + let arr = arr?; + if (curr_index + arr.len()) < row_range.start { + // throw arrays less than what we need + curr_index += arr.len(); + continue; + } else if curr_index < row_range.start { + let offset = row_range.start.saturating_sub(curr_index); + all_arrays.push(arr.sliced(offset, arr.len() - offset)); + curr_index += arr.len(); + } else { + curr_index += arr.len(); + all_arrays.push(arr); + } + } + + all_arrays + .into_iter() + .map(|a| { + Series::try_from(( + field.name.as_str(), + cast_array_for_daft_if_needed(a), + )) + }) + .collect::>>() + })(); + + let _ = send.send(ser); + }); + recv.await.context(OneShotRecvSnafu {})? + }); + Ok(handle) }) .collect::>>()?; - - Series::concat(&field_series.iter().flatten().collect::>()) + let owned_uri = self.uri.clone(); + let concated_handle = tokio::task::spawn(async move { + let series_to_concat = + try_join_all(field_handles.into_iter()) + .await + .context(JoinSnafu { + path: owned_uri.to_string(), + })?; + let series_to_concat = series_to_concat + .into_iter() + .collect::>>()?; + + let (send, recv) = tokio::sync::oneshot::channel(); + rayon::spawn(move || { + let concated = + Series::concat(&series_to_concat.iter().flatten().collect::>()); + let _ = send.send(concated); + }); + recv.await.context(OneShotRecvSnafu {})? + }); + Ok(concated_handle) }) .collect::>>()?; + let all_series = try_join_all(all_handles.into_iter()) + .await + .context(JoinSnafu { + path: self.uri.to_string(), + })? + .into_iter() + .collect::>>()?; let daft_schema = daft_core::schema::Schema::try_from(&self.arrow_schema)?; Table::new(daft_schema, all_series) diff --git a/src/daft-parquet/src/lib.rs b/src/daft-parquet/src/lib.rs index 86c94cd7d0..49ac686e5d 100644 --- a/src/daft-parquet/src/lib.rs +++ b/src/daft-parquet/src/lib.rs @@ -41,6 +41,13 @@ pub enum Error { path: String, source: arrow2::error::Error, }, + + #[snafu(display("Unable to create page stream for parquet file {}: {}", path, source))] + UnableToCreateParquetPageStream { + path: String, + source: parquet2::error::Error, + }, + #[snafu(display( "Unable to parse parquet metadata to arrow schema for file {}: {}", path, @@ -96,6 +103,13 @@ pub enum Error { path: String, source: tokio::task::JoinError, }, + #[snafu(display( + "Sender of OneShot Channel Dropped before sending data over: {}", + source + ))] + OneShotRecvError { + source: tokio::sync::oneshot::error::RecvError, + }, } impl From for DaftError { diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs index b5a29b291a..b489917d65 100644 --- a/src/daft-parquet/src/read.rs +++ b/src/daft-parquet/src/read.rs @@ -23,21 +23,19 @@ pub fn read_parquet( ) -> DaftResult
{ let runtime_handle = get_runtime(true)?; let _rt_guard = runtime_handle.enter(); - let (reader, ranges) = runtime_handle.block_on(async { - let builder = ParquetReaderBuilder::from_uri(uri, io_client.clone()).await?; - - let builder = if let Some(columns) = columns { - builder.prune_columns(columns)? - } else { - builder - }; - let builder = builder.limit(start_offset, num_rows)?; - let parquet_reader = builder.build()?; - let ranges = parquet_reader.prebuffer_ranges(io_client.clone()).await?; - DaftResult::Ok((parquet_reader, ranges)) - })?; - - reader.read_from_ranges(ranges) + let builder = runtime_handle + .block_on(async { ParquetReaderBuilder::from_uri(uri, io_client.clone()).await })?; + + let builder = if let Some(columns) = columns { + builder.prune_columns(columns)? + } else { + builder + }; + let builder = builder.limit(start_offset, num_rows)?; + let parquet_reader = builder.build()?; + let ranges = parquet_reader.prebuffer_ranges(io_client)?; + + runtime_handle.block_on(async { parquet_reader.read_from_ranges(ranges).await }) } pub fn read_parquet_schema(uri: &str, io_client: Arc) -> DaftResult { @@ -85,7 +83,7 @@ pub fn read_parquet_statistics(uris: &Series, io_client: Arc) -> DaftR .into_iter() .zip(values.iter()) .map(|(t, u)| { - t.with_context(|_| JoinSnafu:: { + t.with_context(|_| JoinSnafu { path: u.unwrap().to_string(), })? }) diff --git a/src/daft-parquet/src/read_planner.rs b/src/daft-parquet/src/read_planner.rs index 537bd6ed08..8cc1757f54 100644 --- a/src/daft-parquet/src/read_planner.rs +++ b/src/daft-parquet/src/read_planner.rs @@ -1,9 +1,13 @@ -use std::{fmt::Display, io::Read, ops::Range, sync::Arc}; +use std::{fmt::Display, ops::Range, sync::Arc}; +use bytes::Bytes; use common_error::DaftResult; use daft_io::IOClient; -use futures::{StreamExt, TryStreamExt}; +use futures::StreamExt; use snafu::ResultExt; +use tokio::task::JoinHandle; + +use crate::JoinSnafu; type RangeList = Vec>; @@ -83,6 +87,38 @@ impl ReadPlanPass for SplitLargeRequestPass { } } +enum RangeCacheState { + InFlight(JoinHandle>), + Ready(Bytes), +} + +struct RangeCacheEntry { + start: usize, + end: usize, + state: tokio::sync::Mutex, +} + +impl RangeCacheEntry { + async fn get_or_wait(&self, range: Range) -> std::result::Result { + { + let mut _guard = self.state.lock().await; + match &mut (*_guard) { + RangeCacheState::InFlight(f) => { + // TODO(sammy): thread in url for join error + let v = f + .await + .context(JoinSnafu { path: "UNKNOWN" }) + .unwrap() + .unwrap(); + *_guard = RangeCacheState::Ready(v.clone()); + Ok(v.slice(range)) + } + RangeCacheState::Ready(v) => Ok(v.slice(range)), + } + } + } +} + pub(crate) struct ReadPlanner { source: String, ranges: RangeList, @@ -117,150 +153,104 @@ impl ReadPlanner { Ok(()) } - pub async fn collect(self, io_client: Arc) -> DaftResult { - let mut stored_ranges: Vec<_> = - futures::stream::iter(self.ranges.into_iter().map(|range| { - // multithread this - let owned_io_client = io_client.clone(); - let owned_url = self.source.clone(); - tokio::spawn(async move { - let get_result = owned_io_client - .single_url_get(owned_url, Some(range.clone())) - .await?; - let bytes = get_result.bytes().await?; - DaftResult::Ok((range.start, bytes.to_vec())) - }) - })) - // TODO(sammy): Use client pool in s3 client - .buffer_unordered(256) - .try_collect::>() - .await - .context(super::JoinSnafu { path: self.source })? - .into_iter() - .collect::>()?; - - stored_ranges.sort_unstable_by_key(|(start, _)| *start); - Ok(RangesContainer { - ranges: stored_ranges, - }) + pub fn collect(self, io_client: Arc) -> DaftResult> { + let mut entries = Vec::with_capacity(self.ranges.len()); + for range in self.ranges { + let owned_io_client = io_client.clone(); + let owned_url = self.source.clone(); + let start = range.start; + let end = range.end; + let join_handle = tokio::spawn(async move { + let get_result = owned_io_client + .single_url_get(owned_url, Some(range.clone())) + .await?; + Ok(get_result.bytes().await?) + }); + let state = RangeCacheState::InFlight(join_handle); + let entry = RangeCacheEntry { + start, + end, + state: tokio::sync::Mutex::new(state), + }; + entries.push(entry); + } + Ok(Arc::new(RangesContainer { ranges: entries })) } } pub(crate) struct RangesContainer { - ranges: Vec<(usize, Vec)>, + ranges: Vec, } impl RangesContainer { - pub fn get_range_reader<'a>(&'a self, range: Range) -> super::Result> { + pub fn get_range_reader( + &self, + range: Range, + ) -> DaftResult { let mut current_pos = range.start; let mut curr_index; - let start_point = self.ranges.binary_search_by_key(¤t_pos, |(v, _)| *v); + let start_point = self.ranges.binary_search_by_key(¤t_pos, |e| e.start); - let mut slice_vec: Vec<&'a [u8]> = vec![]; + let mut needed_entries = vec![]; + let mut ranges_to_slice = vec![]; match start_point { Ok(index) => { - let (byte_start, bytes_at_index) = &self.ranges[index]; - assert_eq!(*byte_start, current_pos); + let entry = &self.ranges[index]; + let len = entry.end - entry.start; + assert_eq!(entry.start, current_pos); let start_offset = 0; - let end_offset = bytes_at_index.len().min(range.end - current_pos); - let curr_slice = &bytes_at_index.as_slice()[start_offset..end_offset]; - slice_vec.push(curr_slice); - current_pos += curr_slice.len(); + let end_offset = len.min(range.end - current_pos); + + needed_entries.push(entry); + ranges_to_slice.push(start_offset..end_offset); + + current_pos += end_offset - start_offset; curr_index = index + 1; } Err(index) => { assert!( index > 0, - "range: {range:?}, start: {}, len: {}", - &self.ranges[index].0, - &self.ranges[index].1.len() + "range: {range:?}, start: {}, end: {}", + &self.ranges[index].start, + &self.ranges[index].end ); let index = index - 1; - let (byte_start, bytes_at_index) = &self.ranges[index]; - let end = byte_start + bytes_at_index.len(); - assert!(current_pos >= *byte_start && current_pos < end, "range: {range:?}, current_pos: {current_pos}, bytes_start: {byte_start}, end: {end}"); - let start_offset = current_pos - byte_start; - let end_offset = bytes_at_index.len().min(range.end - byte_start); - let curr_slice = &bytes_at_index.as_slice()[start_offset..end_offset]; - slice_vec.push(curr_slice); - current_pos += curr_slice.len(); + let entry = &self.ranges[index]; + let start = entry.start; + let end = entry.end; + let len = end - start; + assert!(current_pos >= start && current_pos < end, "range: {range:?}, current_pos: {current_pos}, bytes_start: {start}, end: {end}"); + let start_offset = current_pos - start; + let end_offset = len.min(range.end - start); + needed_entries.push(entry); + ranges_to_slice.push(start_offset..end_offset); + current_pos += end_offset - start_offset; curr_index = index + 1; } }; while current_pos < range.end && curr_index < self.ranges.len() { - let (byte_start, bytes_at_index) = &self.ranges[curr_index]; - assert_eq!(*byte_start, current_pos); + let entry = &self.ranges[curr_index]; + let start = entry.start; + let end = entry.end; + let len = end - start; + assert_eq!(start, current_pos); let start_offset = 0; - let end_offset = bytes_at_index.len().min(range.end - byte_start); - let curr_slice = &bytes_at_index.as_slice()[start_offset..end_offset]; - slice_vec.push(curr_slice); - current_pos += curr_slice.len(); + let end_offset = len.min(range.end - start); + needed_entries.push(entry); + ranges_to_slice.push(start_offset..end_offset); + current_pos += end_offset - start_offset; curr_index += 1; } assert_eq!(current_pos, range.end); - Ok(MultiRead::new(slice_vec, range.end - range.start)) - } -} - -pub(crate) struct MultiRead<'a> { - sources: Vec<&'a [u8]>, - pos_in_sources: usize, - pos_in_current: usize, - bytes_read: usize, - total_size: usize, -} - -impl<'a> MultiRead<'a> { - fn new(sources: Vec<&'a [u8]>, total_size: usize) -> MultiRead<'a> { - MultiRead { - sources, - pos_in_sources: 0, - pos_in_current: 0, - bytes_read: 0, - total_size, - } - } -} + let bytes_iter = tokio_stream::iter(needed_entries.into_iter().zip(ranges_to_slice)) + .then(|(e, r)| async move { e.get_or_wait(r).await }); -impl Read for MultiRead<'_> { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - let current = loop { - if self.pos_in_sources >= self.sources.len() { - return Ok(0); // EOF - } - let current = self.sources[self.pos_in_sources]; - if self.pos_in_current < current.len() { - break current; - } - self.pos_in_current = 0; - self.pos_in_sources += 1; - }; - let read_size = buf.len().min(current.len() - self.pos_in_current); - buf[..read_size].copy_from_slice(¤t[self.pos_in_current..][..read_size]); - self.pos_in_current += read_size; - self.bytes_read += read_size; - Ok(read_size) - } + let stream_reader = tokio_util::io::StreamReader::new(bytes_iter); + let convert = async_compat::Compat::new(stream_reader); - #[inline] - fn read_to_end(&mut self, buf: &mut Vec) -> std::io::Result { - if self.bytes_read >= self.total_size { - return Ok(0); - } - let starting_bytes_read = self.bytes_read; - buf.reserve(self.total_size - self.bytes_read); - while self.bytes_read < self.total_size { - let current = self.sources[self.pos_in_sources]; - let slice = ¤t[self.pos_in_current..]; - buf.extend_from_slice(slice); - self.pos_in_current = 0; - self.pos_in_sources += 1; - self.bytes_read += slice.len(); - } - Ok(self.bytes_read - starting_bytes_read) + Ok(convert) } }