diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs
index c5433e12174f..77dbb569cb10 100644
--- a/crates/polars-core/src/frame/mod.rs
+++ b/crates/polars-core/src/frame/mod.rs
@@ -2249,6 +2249,9 @@ impl DataFrame {
         if offset == 0 && length == self.height() {
             return self.clone();
         }
+        if length == 0 {
+            return self.clear();
+        }
         let col = self
             .columns
             .iter()
diff --git a/crates/polars-io/src/csv/read/read_impl/batched_mmap.rs b/crates/polars-io/src/csv/read/read_impl/batched_mmap.rs
index cb8ce04947d8..f30f3105de51 100644
--- a/crates/polars-io/src/csv/read/read_impl/batched_mmap.rs
+++ b/crates/polars-io/src/csv/read/read_impl/batched_mmap.rs
@@ -170,7 +170,7 @@ impl<'a> CoreReader<'a> {
             to_cast: self.to_cast,
             ignore_errors: self.ignore_errors,
             truncate_ragged_lines: self.truncate_ragged_lines,
-            n_rows: self.n_rows,
+            remaining: self.n_rows.unwrap_or(usize::MAX),
             encoding: self.encoding,
             separator: self.separator,
             schema: self.schema,
@@ -197,7 +197,7 @@ pub struct BatchedCsvReaderMmap<'a> {
     truncate_ragged_lines: bool,
     to_cast: Vec<Field>,
     ignore_errors: bool,
-    n_rows: Option<usize>,
+    remaining: usize,
     encoding: CsvEncoding,
     separator: u8,
     schema: SchemaRef,
@@ -211,14 +211,9 @@ pub struct BatchedCsvReaderMmap<'a> {
 
 impl<'a> BatchedCsvReaderMmap<'a> {
     pub fn next_batches(&mut self, n: usize) -> PolarsResult<Option<Vec<DataFrame>>> {
-        if n == 0 {
+        if n == 0 || self.remaining == 0 {
             return Ok(None);
         }
-        if let Some(n_rows) = self.n_rows {
-            if self.rows_read >= n_rows as IdxSize {
-                return Ok(None);
-            }
-        }
 
         // get next `n` offset positions.
         let file_chunks_iter = (&mut self.file_chunks_iter).take(n);
@@ -274,8 +269,15 @@ impl<'a> BatchedCsvReaderMmap<'a> {
         if self.row_index.is_some() {
             update_row_counts2(&mut chunks, self.rows_read)
         }
-        for df in &chunks {
-            self.rows_read += df.height() as IdxSize;
+        for df in &mut chunks {
+            let h = df.height();
+
+            if self.remaining < h {
+                *df = df.slice(0, self.remaining)
+            };
+            self.remaining = self.remaining.saturating_sub(h);
+
+            self.rows_read += h as IdxSize;
         }
         Ok(Some(chunks))
     }
diff --git a/crates/polars-io/src/csv/read/read_impl/batched_read.rs b/crates/polars-io/src/csv/read/read_impl/batched_read.rs
index b42be05f14b4..8c405f88e3fc 100644
--- a/crates/polars-io/src/csv/read/read_impl/batched_read.rs
+++ b/crates/polars-io/src/csv/read/read_impl/batched_read.rs
@@ -246,7 +246,6 @@ impl<'a> CoreReader<'a> {
         Ok(BatchedCsvReaderRead {
             chunk_size: self.chunk_size,
-            finished: false,
             file_chunk_reader: chunk_iter,
             file_chunks: vec![],
             projection,
@@ -260,20 +259,20 @@ impl<'a> CoreReader<'a> {
             to_cast: self.to_cast,
             ignore_errors: self.ignore_errors,
             truncate_ragged_lines: self.truncate_ragged_lines,
-            n_rows: self.n_rows,
+            remaining: self.n_rows.unwrap_or(usize::MAX),
             encoding: self.encoding,
             separator: self.separator,
             schema: self.schema,
             rows_read: 0,
             _cat_lock,
             decimal_comma: self.decimal_comma,
+            finished: false,
         })
     }
 }
 
 pub struct BatchedCsvReaderRead<'a> {
     chunk_size: usize,
-    finished: bool,
     file_chunk_reader: ChunkReader<'a>,
     file_chunks: Vec<(SyncPtr<u8>, usize)>,
     projection: Vec<usize>,
@@ -287,7 +286,7 @@ pub struct BatchedCsvReaderRead<'a> {
     to_cast: Vec<Field>,
     ignore_errors: bool,
    truncate_ragged_lines: bool,
-    n_rows: Option<usize>,
+    remaining: usize,
     encoding: CsvEncoding,
     separator: u8,
     schema: SchemaRef,
@@ -297,19 +296,15 @@ pub struct BatchedCsvReaderRead<'a> {
     #[cfg(not(feature = "dtype-categorical"))]
     _cat_lock: Option<u8>,
     decimal_comma: bool,
+    finished: bool,
 }
 
 impl<'a> BatchedCsvReaderRead<'a> {
     /// `n` number of batches.
     pub fn next_batches(&mut self, n: usize) -> PolarsResult<Option<Vec<DataFrame>>> {
-        if n == 0 || self.finished {
+        if n == 0 || self.remaining == 0 || self.finished {
             return Ok(None);
         }
-        if let Some(n_rows) = self.n_rows {
-            if self.rows_read >= n_rows as IdxSize {
-                return Ok(None);
-            }
-        }
 
         // get next `n` offset positions.
@@ -331,7 +326,7 @@ impl<'a> BatchedCsvReaderRead<'a> {
                 // get the final slice
                 self.file_chunks
                     .push(self.file_chunk_reader.get_buf_remaining());
-                self.finished = true
+                self.finished = true;
             }
 
             // depleted the offsets iterator, we are done as well.
@@ -380,8 +375,15 @@ impl<'a> BatchedCsvReaderRead<'a> {
         if self.row_index.is_some() {
             update_row_counts2(&mut chunks, self.rows_read)
         }
-        for df in &chunks {
-            self.rows_read += df.height() as IdxSize;
+        for df in &mut chunks {
+            let h = df.height();
+
+            if self.remaining < h {
+                *df = df.slice(0, self.remaining)
+            };
+            self.remaining = self.remaining.saturating_sub(h);
+
+            self.rows_read += h as IdxSize;
         }
         Ok(Some(chunks))
     }
diff --git a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs
index 840429855f5b..93e6a55aa6f3 100644
--- a/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs
+++ b/crates/polars-lazy/src/physical_plan/streaming/construct_pipeline.rs
@@ -68,10 +68,12 @@ fn jit_insert_slice(
     sink_nodes: &mut Vec<(usize, Node, Rc<RefCell<u32>>)>,
     operator_offset: usize,
 ) {
-    // if the join/union has a slice, we add a new slice node
+    // if the join has a slice, we add a new slice node
     // note that we take the offset + 1, because we want to
     // slice AFTER the join has happened and the join will be an
     // operator
+    // NOTE: Don't do this for union, that doesn't work.
+    // TODO! Deal with this in the optimizer.
     use IR::*;
     let (offset, len) = match lp_arena.get(node) {
         Join { options, .. } if options.args.slice.is_some() => {
@@ -80,19 +82,11 @@ fn jit_insert_slice(
             };
             (offset, len)
         },
-        Union {
-            options:
-                UnionOptions {
-                    slice: Some((offset, len)),
-                    ..
-                },
-            ..
-        } => (*offset, *len),
         _ => return,
     };
     let slice_node = lp_arena.add(Slice {
-        input: Node::default(),
+        input: node,
         offset,
         len: len as IdxSize,
     });
@@ -178,7 +172,6 @@ pub(super) fn construct(
         },
         PipelineNode::Union(node) => {
             operator_nodes.push(node);
-            jit_insert_slice(node, lp_arena, &mut sink_nodes, operator_offset);
             let op = get_operator(node, lp_arena, expr_arena, &to_physical_piped_expr)?;
             operators.push(op);
         },
diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs
index 7ffdbd7935af..54fe0b1a68f3 100644
--- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs
+++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs
@@ -81,21 +81,6 @@ fn insert_file_sink(mut root: Node, lp_arena: &mut Arena<IR>) -> Node {
     root
 }
 
-fn insert_slice(
-    root: Node,
-    offset: i64,
-    len: IdxSize,
-    lp_arena: &mut Arena<IR>,
-    state: &mut Branch,
-) {
-    let node = lp_arena.add(IR::Slice {
-        input: root,
-        offset,
-        len: len as IdxSize,
-    });
-    state.operators_sinks.push(PipelineNode::Sink(node));
-}
-
 pub(crate) fn insert_streaming_nodes(
     root: Node,
     lp_arena: &mut Arena<IR>,
@@ -244,20 +229,8 @@ pub(crate) fn insert_streaming_nodes(
                     )
                 }
             },
-            Scan {
-                file_options: options,
-                scan_type,
-                ..
-            } if scan_type.streamable() => {
+            Scan { scan_type, .. } if scan_type.streamable() => {
                 if state.streamable {
-                    #[cfg(feature = "csv")]
-                    if matches!(scan_type, FileScan::Csv { .. }) {
-                        // the batched csv reader doesn't stop exactly at n_rows
-                        if let Some(n_rows) = options.n_rows {
-                            insert_slice(root, 0, n_rows as IdxSize, lp_arena, &mut state);
-                        }
-                    }
-
                     state.sources.push(root);
                     pipeline_trees[current_idx].push(state)
                 }
             },
@@ -320,38 +293,7 @@ pub(crate) fn insert_streaming_nodes(
                 state.sources.push(root);
                 pipeline_trees[current_idx].push(state);
             },
-            Union {
-                options:
-                    UnionOptions {
-                        slice: Some((offset, len)),
-                        ..
-                    },
-                ..
-            } if *offset >= 0 => {
-                insert_slice(root, *offset, *len as IdxSize, lp_arena, &mut state);
-                state.streamable = true;
-                let Union { inputs, .. } = lp_arena.get(root) else {
-                    unreachable!()
-                };
-                for (i, input) in inputs.iter().enumerate() {
-                    let mut state = if i == 0 {
-                        // Note the clone!
-                        let mut state = state.clone();
-                        state.join_count += inputs.len() as u32 - 1;
-                        state
-                    } else {
-                        let mut state = state.split_from_sink();
-                        state.join_count = 0;
-                        state
-                    };
-                    state.operators_sinks.push(PipelineNode::Union(root));
-                    stack.push(StackFrame::new(*input, state, current_idx));
-                }
-            },
-            Union {
-                inputs,
-                options: UnionOptions { slice: None, .. },
-            } => {
+            Union { inputs, .. } => {
                 {
                     state.streamable = true;
                     for (i, input) in inputs.iter().enumerate() {
diff --git a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs
index 18b3c9d85631..5a0975f4a654 100644
--- a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs
+++ b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs
@@ -209,7 +209,6 @@ impl SlicePushDown {
                 Ok(lp)
             }
             (Union {mut inputs, mut options }, Some(state)) => {
-                options.slice = Some((state.offset, state.len as usize));
                 if state.offset == 0 {
                     for input in &mut inputs {
                         let input_lp = lp_arena.take(*input);
@@ -217,7 +216,17 @@ impl SlicePushDown {
                         lp_arena.replace(*input, input_lp);
                     }
                 }
-                Ok(Union {inputs, options})
+                // The in-memory union node is slice aware.
+                // We still set this information, but the streaming engine will ignore it.
+                options.slice = Some((state.offset, state.len as usize));
+                let lp = Union {inputs, options};
+
+                if self.streaming {
+                    // Ensure the slice node remains.
+                    self.no_pushdown_finish_opt(lp, Some(state), lp_arena)
+                } else {
+                    Ok(lp)
+                }
             },
             (Join {
                 input_left,
diff --git a/crates/polars-utils/src/arena.rs b/crates/polars-utils/src/arena.rs
index df367b733f1f..31818eb03d86 100644
--- a/crates/polars-utils/src/arena.rs
+++ b/crates/polars-utils/src/arena.rs
@@ -104,6 +104,13 @@ impl<T> Arena<T> {
     }
 }
 
+impl<T: Clone> Arena<T> {
+    pub fn duplicate(&mut self, node: Node) -> Node {
+        let item = self.items[node.0].clone();
+        self.add(item)
+    }
+}
+
 impl<T: Default> Arena<T> {
     #[inline]
     pub fn take(&mut self, idx: Node) -> T {
diff --git a/py-polars/tests/unit/streaming/test_streaming_io.py b/py-polars/tests/unit/streaming/test_streaming_io.py
index d405fec1183c..982ba225e9d9 100644
--- a/py-polars/tests/unit/streaming/test_streaming_io.py
+++ b/py-polars/tests/unit/streaming/test_streaming_io.py
@@ -30,6 +30,11 @@ def test_scan_slice_streaming(io_files_path: Path) -> None:
     df = pl.scan_csv(foods_file_path).head(5).collect(streaming=True)
     assert df.shape == (5, 4)
 
+    # globbing
+    foods_file_path = io_files_path / "foods*.csv"
+    df = pl.scan_csv(foods_file_path).head(5).collect(streaming=True)
+    assert df.shape == (5, 4)
+
 
 @pytest.mark.parametrize("dtype", [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16])
 def test_scan_csv_overwrite_small_dtypes(
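For context, the row-capping pattern used by both batched CSV readers in this patch can be sketched in isolation. The following is a minimal standalone model, not the Polars code: `Batch` and `take_batches` are hypothetical stand-ins for a decoded chunk and the reader's batch loop. A missing limit becomes a `usize::MAX` budget, each emitted batch is truncated to the remaining budget, and the budget is then decremented, which is what lets a scan stop at exactly `n_rows` even when chunk boundaries do not line up.

```rust
/// Hypothetical stand-in for a decoded CSV chunk; only its length matters here.
struct Batch(Vec<u64>);

impl Batch {
    fn height(&self) -> usize {
        self.0.len()
    }
    /// Keep at most `len` rows starting at `offset` (offset is always 0 in this sketch).
    fn slice(&self, offset: usize, len: usize) -> Batch {
        let end = offset + len.min(self.0.len() - offset);
        Batch(self.0[offset..end].to_vec())
    }
}

/// Models the `remaining` budget introduced in this diff: `None` (no row limit)
/// becomes `usize::MAX`; every batch is capped to the budget before decrementing it.
fn take_batches(mut batches: Vec<Batch>, n_rows: Option<usize>) -> Vec<Batch> {
    let mut remaining = n_rows.unwrap_or(usize::MAX);
    for batch in &mut batches {
        if remaining == 0 {
            // Budget exhausted: later batches contribute no rows.
            *batch = batch.slice(0, 0);
            continue;
        }
        let h = batch.height();
        if remaining < h {
            // Cap the batch instead of letting extra rows leak through.
            *batch = batch.slice(0, remaining);
        }
        remaining = remaining.saturating_sub(h);
    }
    batches
}

fn main() {
    let batches = vec![Batch((0..4).collect()), Batch((0..4).collect())];
    let capped = take_batches(batches, Some(5));
    let heights: Vec<usize> = capped.iter().map(Batch::height).collect();
    assert_eq!(heights, vec![4, 1]); // 5 rows total, like `head(5)` in the test above
    println!("{heights:?}");
}
```

The saturating subtraction keeps the budget pinned at zero once the limit is reached, so a subsequent call can bail out early, mirroring the `self.remaining == 0` early return added to both `next_batches` implementations.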