diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py
index 34c7c1e695..b14d424d0c 100644
--- a/daft/execution/rust_physical_plan_shim.py
+++ b/daft/execution/rust_physical_plan_shim.py
@@ -2,6 +2,7 @@
 
 from typing import TYPE_CHECKING
 
+from daft.context import get_context
 from daft.daft import (
     FileFormat,
     IOConfig,
@@ -32,15 +33,17 @@ def scan_with_tasks(
     """
    # TODO(Clark): Currently hardcoded to have 1 file per instruction
    # We can instead right-size and bundle the ScanTask into single-instruction bulk reads.
+
+    cfg = get_context().daft_execution_config
+
     for scan_task in scan_tasks:
         scan_step = execution_step.PartitionTaskBuilder[PartitionT](
             inputs=[],
             partial_metadatas=None,
         ).add_instruction(
             instruction=execution_step.ScanWithTask(scan_task),
-            # Set the filesize as the memory request.
-            # (Note: this is very conservative; file readers empirically use much more peak memory than 1x file size.)
-            resource_request=ResourceRequest(memory_bytes=scan_task.size_bytes()),
+            # Set the estimated in-memory size as the memory request.
+            resource_request=ResourceRequest(memory_bytes=scan_task.estimate_in_memory_size_bytes(cfg)),
         )
         yield scan_step
diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs
index e36517aa2a..019bb5e50c 100644
--- a/src/daft-plan/src/physical_planner/translate.rs
+++ b/src/daft-plan/src/physical_planner/translate.rs
@@ -57,11 +57,7 @@ pub(super) fn translate_single_logical_node(
             );
 
             // Apply transformations on the ScanTasks to optimize
-            let scan_tasks = daft_scan::scan_task_iters::merge_by_sizes(
-                scan_tasks,
-                cfg.scan_tasks_min_size_bytes,
-                cfg.scan_tasks_max_size_bytes,
-            );
+            let scan_tasks = daft_scan::scan_task_iters::merge_by_sizes(scan_tasks, cfg);
             let scan_tasks = scan_tasks.collect::<DaftResult<Vec<_>>>()?;
             if scan_tasks.is_empty() {
                 let clustering_spec =
diff --git a/src/daft-scan/src/scan_task_iters.rs b/src/daft-scan/src/scan_task_iters.rs
index 2c1666c95b..3d14ea38dd 100644
--- a/src/daft-scan/src/scan_task_iters.rs
+++ b/src/daft-scan/src/scan_task_iters.rs
@@ -1,5 +1,6 @@
 use std::sync::Arc;
 
+use common_daft_config::DaftExecutionConfig;
 use common_error::DaftResult;
 use daft_io::IOStatsContext;
 use daft_parquet::read::read_parquet_metadata;
@@ -10,7 +11,7 @@ use crate::{
     ChunkSpec, DataFileSource, ScanTask, ScanTaskRef,
 };
 
-type BoxScanTaskIter = Box<dyn Iterator<Item = DaftResult<ScanTaskRef>>>;
+type BoxScanTaskIter<'a> = Box<dyn Iterator<Item = DaftResult<ScanTaskRef>> + 'a>;
 
 /// Coalesces ScanTasks by their [`ScanTask::size_bytes()`]
 ///
@@ -24,33 +25,30 @@ type BoxScanTaskIter = Box<dyn Iterator<Item = DaftResult<ScanTaskRef>>>;
 /// * `scan_tasks`: A Boxed Iterator of ScanTaskRefs to perform merging on
 /// * `min_size_bytes`: Minimum size in bytes of a ScanTask, after which no more merging will be performed
 /// * `max_size_bytes`: Maximum size in bytes of a ScanTask, capping the maximum size of a merged ScanTask
-pub fn merge_by_sizes(
-    scan_tasks: BoxScanTaskIter,
-    min_size_bytes: usize,
-    max_size_bytes: usize,
-) -> BoxScanTaskIter {
+pub fn merge_by_sizes<'a>(
+    scan_tasks: BoxScanTaskIter<'a>,
+    cfg: &'a DaftExecutionConfig,
+) -> BoxScanTaskIter<'a> {
     Box::new(MergeByFileSize {
         iter: scan_tasks,
-        min_size_bytes,
-        max_size_bytes,
+        cfg,
         accumulator: None,
     })
 }
 
-struct MergeByFileSize {
-    iter: BoxScanTaskIter,
-    min_size_bytes: usize,
-    max_size_bytes: usize,
+struct MergeByFileSize<'a> {
+    iter: BoxScanTaskIter<'a>,
+    cfg: &'a DaftExecutionConfig,
     // Current element being accumulated on
     accumulator: Option<ScanTaskRef>,
 }
 
-impl MergeByFileSize {
+impl<'a> MergeByFileSize<'a> {
     fn accumulator_ready(&self) -> bool {
         if let Some(acc) = &self.accumulator
-            && let Some(acc_bytes) = acc.size_bytes()
-            && acc_bytes >= self.min_size_bytes
+            && let Some(acc_bytes) = acc.estimate_in_memory_size_bytes(Some(self.cfg))
+            && acc_bytes >= self.cfg.scan_tasks_min_size_bytes
         {
             true
         } else {
@@ -69,10 +67,12 @@ impl MergeByFileSize {
             && other.storage_config == accumulator.storage_config
             && other.pushdowns == accumulator.pushdowns;
 
-        let sum_smaller_than_max_size_bytes = if let Some(child_bytes) = other.size_bytes()
-            && let Some(accumulator_bytes) = accumulator.size_bytes()
+        let sum_smaller_than_max_size_bytes = if let Some(child_bytes) =
+            other.estimate_in_memory_size_bytes(Some(self.cfg))
+            && let Some(accumulator_bytes) =
+                accumulator.estimate_in_memory_size_bytes(Some(self.cfg))
         {
-            child_bytes + accumulator_bytes <= self.max_size_bytes
+            child_bytes + accumulator_bytes <= self.cfg.scan_tasks_max_size_bytes
         } else {
             false
         };
@@ -81,7 +81,7 @@ impl MergeByFileSize {
     }
 }
 
-impl Iterator for MergeByFileSize {
+impl<'a> Iterator for MergeByFileSize<'a> {
     type Item = DaftResult<ScanTaskRef>;
 
     fn next(&mut self) -> Option<Self::Item> {
@@ -104,7 +104,11 @@ impl Iterator for MergeByFileSize {
             None => return self.accumulator.take().map(Ok),
         };
 
-        if next_item.size_bytes().is_none() || !self.can_merge(&next_item) {
+        if next_item
+            .estimate_in_memory_size_bytes(Some(self.cfg))
+            .is_none()
+            || !self.can_merge(&next_item)
+        {
             return self.accumulator.replace(next_item).map(Ok);
         }
diff --git a/tests/io/test_merge_scan_tasks.py b/tests/io/test_merge_scan_tasks.py
index e69d1a1105..d52b2e0c02 100644
--- a/tests/io/test_merge_scan_tasks.py
+++ b/tests/io/test_merge_scan_tasks.py
@@ -41,7 +41,7 @@ def test_merge_scan_task_exceed_max(csv_files):
 
 
 def test_merge_scan_task_below_max(csv_files):
-    with override_merge_scan_tasks_configs(21, 22):
+    with override_merge_scan_tasks_configs(11, 12):
         df = daft.read_csv(str(csv_files))
         assert (
             df.num_partitions() == 2
@@ -49,7 +49,7 @@ def test_merge_scan_task_below_max(csv_files):
 
 
 def test_merge_scan_task_above_min(csv_files):
-    with override_merge_scan_tasks_configs(19, 40):
+    with override_merge_scan_tasks_configs(9, 20):
         df = daft.read_csv(str(csv_files))
         assert (
             df.num_partitions() == 2
@@ -57,7 +57,7 @@ def test_merge_scan_task_above_min(csv_files):
 
 
 def test_merge_scan_task_below_min(csv_files):
-    with override_merge_scan_tasks_configs(35, 40):
+    with override_merge_scan_tasks_configs(17, 20):
         df = daft.read_csv(str(csv_files))
         assert (
             df.num_partitions() == 1
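The control flow of the `MergeByFileSize` iterator above is compact but easy to misread. Below is a minimal Python sketch of the same greedy coalescing strategy, under simplifying assumptions: `Task` and `merge` are toy stand-ins for the Rust-side `ScanTask`/`ScanTask::merge`, and the schema/source/pushdown compatibility checks in `can_merge` are omitted.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, Optional

# Toy stand-in for illustration only; the real type is daft-scan's ScanTask.
@dataclass
class Task:
    est_bytes: Optional[int]  # estimated in-memory size; None if unknown

def merge(a: Task, b: Task) -> Task:
    # The real merge concatenates file sources; here we only sum the estimates.
    return Task(a.est_bytes + b.est_bytes)

def merge_by_estimated_size(
    tasks: Iterable[Task], min_size_bytes: int, max_size_bytes: int
) -> Iterator[Task]:
    """Greedy coalescing modeled on the MergeByFileSize iterator above
    (simplified: no schema/source/pushdown compatibility checks)."""
    acc: Optional[Task] = None
    for task in tasks:
        if acc is None:
            acc = task
            continue
        # An unknown estimate refuses merging, like the is_none() check above.
        can_merge = (
            task.est_bytes is not None
            and acc.est_bytes is not None
            and acc.est_bytes + task.est_bytes <= max_size_bytes
        )
        # Flush once the accumulator clears the floor, or if the pair won't fit.
        if not can_merge or acc.est_bytes >= min_size_bytes:
            yield acc
            acc = task
        else:
            acc = merge(acc, task)
    if acc is not None:
        yield acc

# Six ~5 MB tasks with a 10 MB floor and 20 MB cap coalesce into three.
merged = list(
    merge_by_estimated_size((Task(5_000_000) for _ in range(6)), 10_000_000, 20_000_000)
)
print([t.est_bytes for t in merged])  # [10000000, 10000000, 10000000]
```

The floor avoids a long tail of tiny partitions (per-task scheduling overhead), while the cap bounds peak memory per merged task, which is also what the Python shim now requests via `ResourceRequest`.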
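For reviewers who want to exercise this from Python: the two thresholds consulted above (`cfg.scan_tasks_min_size_bytes` / `cfg.scan_tasks_max_size_bytes`) are now compared against each ScanTask's estimated in-memory size rather than its on-disk file size, which is why the test thresholds shrank. A minimal sketch, assuming `daft.set_execution_config` exposes these knobs (the kwarg names mirror the config fields; the path and byte values are made up):

```python
import daft

# Hypothetical values for illustration; real defaults come from
# DaftExecutionConfig. Both knobs are interpreted against the estimated
# in-memory size of each ScanTask, not the file size on disk.
daft.set_execution_config(
    scan_tasks_min_size_bytes=96 * 1024 * 1024,   # keep merging until ~96 MiB estimated
    scan_tasks_max_size_bytes=384 * 1024 * 1024,  # never merge past ~384 MiB estimated
)

df = daft.read_csv("data/*.csv")  # hypothetical input path
print(df.num_partitions())        # fewer, larger partitions as thresholds grow
```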