diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index e7a2c3f053..3f8f99da07 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -1,9 +1,13 @@ from __future__ import annotations import itertools +import pathlib import sys from dataclasses import dataclass, field -from typing import Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar + +if TYPE_CHECKING: + import fsspec if sys.version_info < (3, 8): from typing_extensions import Protocol @@ -14,11 +18,11 @@ CSVSourceInfo, JSONSourceInfo, ParquetSourceInfo, + SourceInfo, StorageType, ) from daft.expressions import Expression, ExpressionsProjection, col -from daft.logical import logical_plan -from daft.logical.logical_plan import FileWrite, TabularFilesScan +from daft.logical.logical_plan import JoinType from daft.logical.map_partition_ops import MapPartitionOp from daft.logical.schema import Schema from daft.resource_request import ResourceRequest @@ -306,10 +310,16 @@ def num_outputs(self) -> int: @dataclass(frozen=True) class ReadFile(SingleOutputInstruction): - partition_id: int index: int | None - logplan: logical_plan.TabularFilesScan + # Known number of rows. file_rows: int | None + # Max number of rows to read. + limit_rows: int | None + schema: Schema + fs: fsspec.AbstractFileSystem | None + columns_to_read: list[str] | None + source_info: SourceInfo + filepaths_column_name: str def run(self, inputs: list[Table]) -> list[Table]: return self._read_file(inputs) @@ -318,9 +328,7 @@ def _read_file(self, inputs: list[Table]) -> list[Table]: assert len(inputs) == 1 [filepaths_partition] = inputs partition = self._handle_tabular_files_scan( - inputs={self.logplan._filepaths_child.id(): filepaths_partition}, - scan=self.logplan, - index=self.index, + filepaths_partition=filepaths_partition, ) return [partition] @@ -329,8 +337,8 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) num_rows = self.file_rows # Only take the file read limit into account if we know how big the file is to begin with. - if num_rows is not None and self.logplan._limit_rows is not None: - num_rows = min(num_rows, self.logplan._limit_rows) + if num_rows is not None and self.limit_rows is not None: + num_rows = min(num_rows, self.limit_rows) return [ PartialPartitionMetadata( @@ -340,79 +348,76 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) ] def _handle_tabular_files_scan( - self, inputs: dict[int, Table], scan: TabularFilesScan, index: int | None = None + self, + filepaths_partition: Table, ) -> Table: - child_id = scan._children()[0].id() - prev_partition = inputs[child_id] - data = prev_partition.to_pydict() + data = filepaths_partition.to_pydict() assert ( - scan._filepaths_column_name in data - ), f"TabularFilesScan should be ran on vPartitions with '{scan._filepaths_column_name}' column" - filepaths = data[scan._filepaths_column_name] + self.filepaths_column_name in data + ), f"TabularFilesScan should be ran on vPartitions with '{self.filepaths_column_name}' column" + filepaths = data[self.filepaths_column_name] - if index is not None: - filepaths = [filepaths[index]] + if self.index is not None: + filepaths = [filepaths[self.index]] # Common options for reading vPartition - fs = scan._fs - schema = scan._schema read_options = TableReadOptions( - num_rows=scan._limit_rows, - column_names=scan._column_names, # read only specified columns + num_rows=self.limit_rows, + column_names=self.columns_to_read, # read only specified columns ) - if scan._source_info.scan_type() == StorageType.CSV: - assert isinstance(scan._source_info, CSVSourceInfo) + if self.source_info.scan_type() == StorageType.CSV: + assert isinstance(self.source_info, CSVSourceInfo) table = Table.concat( [ table_io.read_csv( file=fp, - schema=schema, - fs=fs, + schema=self.schema, + fs=self.fs, csv_options=TableParseCSVOptions( - delimiter=scan._source_info.delimiter, - header_index=0 if scan._source_info.has_headers else None, + delimiter=self.source_info.delimiter, + header_index=0 if self.source_info.has_headers else None, ), read_options=read_options, ) for fp in filepaths ] ) - elif scan._source_info.scan_type() == StorageType.JSON: - assert isinstance(scan._source_info, JSONSourceInfo) + elif self.source_info.scan_type() == StorageType.JSON: + assert isinstance(self.source_info, JSONSourceInfo) table = Table.concat( [ table_io.read_json( file=fp, - schema=schema, - fs=fs, + schema=self.schema, + fs=self.fs, read_options=read_options, ) for fp in filepaths ] ) - elif scan._source_info.scan_type() == StorageType.PARQUET: - assert isinstance(scan._source_info, ParquetSourceInfo) + elif self.source_info.scan_type() == StorageType.PARQUET: + assert isinstance(self.source_info, ParquetSourceInfo) table = Table.concat( [ table_io.read_parquet( file=fp, - schema=schema, - fs=fs, + schema=self.schema, + fs=self.fs, read_options=read_options, - io_config=scan._source_info.io_config, - use_native_downloader=scan._source_info.use_native_downloader, + io_config=self.source_info.io_config, + use_native_downloader=self.source_info.use_native_downloader, ) for fp in filepaths ] ) else: - raise NotImplementedError(f"PyRunner has not implemented scan: {scan._source_info.scan_type()}") + raise NotImplementedError(f"PyRunner has not implemented scan: {self.source_info.scan_type()}") expected_schema = ( - Schema._from_fields([schema[name] for name in read_options.column_names]) + Schema._from_fields([self.schema[name] for name in read_options.column_names]) if read_options.column_names is not None - else schema + else self.schema ) assert ( table.schema() == expected_schema @@ -422,8 +427,11 @@ def _handle_tabular_files_scan( @dataclass(frozen=True) class WriteFile(SingleOutputInstruction): - partition_id: int - logplan: logical_plan.FileWrite + file_type: StorageType + schema: Schema + root_dir: str | pathlib.Path + compression: str | None + partition_cols: ExpressionsProjection | None def run(self, inputs: list[Table]) -> list[Table]: return self._write_file(inputs) @@ -431,8 +439,7 @@ def run(self, inputs: list[Table]) -> list[Table]: def _write_file(self, inputs: list[Table]) -> list[Table]: [input] = inputs partition = self._handle_file_write( - inputs={self.logplan._children()[0].id(): input}, - file_write=self.logplan, + input=input, ) return [partition] @@ -445,29 +452,27 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) ) ] - def _handle_file_write(self, inputs: dict[int, Table], file_write: FileWrite) -> Table: - child_id = file_write._children()[0].id() - assert file_write._storage_type == StorageType.PARQUET or file_write._storage_type == StorageType.CSV - if file_write._storage_type == StorageType.PARQUET: + def _handle_file_write(self, input: Table) -> Table: + assert self.file_type == StorageType.PARQUET or self.file_type == StorageType.CSV + if self.file_type == StorageType.PARQUET: file_names = table_io.write_parquet( - inputs[child_id], - path=file_write._root_dir, - compression=file_write._compression, - partition_cols=file_write._partition_cols, + input, + path=self.root_dir, + compression=self.compression, + partition_cols=self.partition_cols, ) else: file_names = table_io.write_csv( - inputs[child_id], - path=file_write._root_dir, - compression=file_write._compression, - partition_cols=file_write._partition_cols, + input, + path=self.root_dir, + compression=self.compression, + partition_cols=self.partition_cols, ) - output_schema = file_write.schema() - assert len(output_schema) == 1 + assert len(self.schema) == 1 return Table.from_pydict( { - output_schema.column_names()[0]: file_names, + self.schema.column_names()[0]: file_names, } ) @@ -516,7 +521,7 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) @dataclass(frozen=True) class LocalCount(SingleOutputInstruction): - logplan: logical_plan.LocalCount + schema: Schema def run(self, inputs: list[Table]) -> list[Table]: return self._count(inputs) @@ -524,7 +529,7 @@ def run(self, inputs: list[Table]) -> list[Table]: def _count(self, inputs: list[Table]) -> list[Table]: [input] = inputs partition = Table.from_pydict({"count": [len(input)]}) - assert partition.schema() == self.logplan.schema() + assert partition.schema() == self.schema return [partition] def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: @@ -629,7 +634,10 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) @dataclass(frozen=True) class Join(SingleOutputInstruction): - logplan: logical_plan.Join + left_on: ExpressionsProjection + right_on: ExpressionsProjection + output_projection: ExpressionsProjection + how: JoinType def run(self, inputs: list[Table]) -> list[Table]: return self._join(inputs) @@ -638,10 +646,10 @@ def _join(self, inputs: list[Table]) -> list[Table]: [left, right] = inputs result = left.join( right, - left_on=self.logplan._left_on, - right_on=self.logplan._right_on, - output_projection=self.logplan._output_projection, - how=self.logplan._how.value, + left_on=self.left_on, + right_on=self.right_on, + output_projection=self.output_projection, + how=self.how.value, ) return [result] diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 83786cdac8..aa683eb75f 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -14,11 +14,16 @@ from __future__ import annotations import math +import pathlib from collections import deque -from typing import Generator, Iterator, TypeVar, Union +from typing import TYPE_CHECKING, Generator, Iterator, TypeVar, Union + +if TYPE_CHECKING: + import fsspec from loguru import logger +from daft.datasources import SourceInfo, StorageType from daft.execution import execution_step from daft.execution.execution_step import ( Instruction, @@ -28,7 +33,9 @@ ReduceInstruction, SingleOutputPartitionTask, ) -from daft.logical import logical_plan +from daft.expressions import ExpressionsProjection +from daft.logical.logical_plan import JoinType +from daft.logical.schema import Schema from daft.resource_request import ResourceRequest from daft.runners.partitioning import PartialPartitionMetadata @@ -59,7 +66,13 @@ def partition_read( def file_read( child_plan: InProgressPhysicalPlan[PartitionT], - scan_info: logical_plan.TabularFilesScan, + # Max number of rows to read. + limit_rows: int | None, + schema: Schema, + fs: fsspec.AbstractFileSystem | None, + columns_to_read: list[str] | None, + source_info: SourceInfo, + filepaths_column_name: str, ) -> InProgressPhysicalPlan[PartitionT]: """child_plan represents partitions with filenames. @@ -86,10 +99,14 @@ def file_read( partial_metadatas=[done_task.partition_metadata()], ).add_instruction( instruction=execution_step.ReadFile( - partition_id=output_partition_index, - logplan=scan_info, index=i, file_rows=file_rows[i], + limit_rows=limit_rows, + schema=schema, + fs=fs, + columns_to_read=columns_to_read, + source_info=source_info, + filepaths_column_name=filepaths_column_name, ), # Set the filesize as the memory request. # (Note: this is very conservative; file readers empirically use much more peak memory than 1x file size.) @@ -116,17 +133,27 @@ def file_read( def file_write( child_plan: InProgressPhysicalPlan[PartitionT], - write_info: logical_plan.FileWrite, + file_type: StorageType, + schema: Schema, + root_dir: str | pathlib.Path, + compression: str | None, + partition_cols: ExpressionsProjection | None, ) -> InProgressPhysicalPlan[PartitionT]: """Write the results of `child_plan` into files described by `write_info`.""" yield from ( step.add_instruction( - execution_step.WriteFile(partition_id=index, logplan=write_info), + execution_step.WriteFile( + file_type=file_type, + schema=schema, + root_dir=root_dir, + compression=compression, + partition_cols=partition_cols, + ), ) if isinstance(step, PartitionTaskBuilder) else step - for index, step in enumerate_open_executions(child_plan) + for step in child_plan ) @@ -146,7 +173,10 @@ def pipeline_instruction( def join( left_plan: InProgressPhysicalPlan[PartitionT], right_plan: InProgressPhysicalPlan[PartitionT], - join: logical_plan.Join, + left_on: ExpressionsProjection, + right_on: ExpressionsProjection, + output_projection: ExpressionsProjection, + how: JoinType, ) -> InProgressPhysicalPlan[PartitionT]: """Pairwise join the partitions from `left_child_plan` and `right_child_plan` together.""" @@ -171,7 +201,14 @@ def join( resource_request=ResourceRequest( memory_bytes=next_left.partition_metadata().size_bytes + next_right.partition_metadata().size_bytes ), - ).add_instruction(instruction=execution_step.Join(join)) + ).add_instruction( + instruction=execution_step.Join( + left_on=left_on, + right_on=right_on, + output_projection=output_projection, + how=how, + ) + ) yield join_step # Exhausted all ready inputs; execute a single child step to get more join inputs. @@ -244,13 +281,14 @@ def local_limit( def global_limit( child_plan: InProgressPhysicalPlan[PartitionT], - global_limit: logical_plan.GlobalLimit, + limit_rows: int, + num_partitions: int, ) -> InProgressPhysicalPlan[PartitionT]: """Return the first n rows from the `child_plan`.""" - remaining_rows = global_limit._num + remaining_rows = limit_rows assert remaining_rows >= 0, f"Invalid value for limit: {remaining_rows}" - remaining_partitions = global_limit.num_partitions() + remaining_partitions = num_partitions materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque() @@ -434,18 +472,19 @@ def split( def coalesce( child_plan: InProgressPhysicalPlan[PartitionT], - coalesce: logical_plan.Coalesce, + from_num_partitions: int, + to_num_partitions: int, ) -> InProgressPhysicalPlan[PartitionT]: """Coalesce the results of the child_plan into fewer partitions. The current implementation only does partition merging, no rebalancing. """ - coalesce_from = coalesce._children()[0].num_partitions() - coalesce_to = coalesce.num_partitions() - assert coalesce_to <= coalesce_from, f"Cannot coalesce upwards from {coalesce_from} to {coalesce_to} partitions." + assert ( + to_num_partitions <= from_num_partitions + ), f"Cannot coalesce upwards from {from_num_partitions} to {to_num_partitions} partitions." - boundaries = [math.ceil((coalesce_from / coalesce_to) * i) for i in range(coalesce_to + 1)] + boundaries = [math.ceil((from_num_partitions / to_num_partitions) * i) for i in range(to_num_partitions + 1)] starts, stops = boundaries[:-1], boundaries[1:] # For each output partition, the number of input partitions to merge in. merges_per_result = deque([stop - start for start, stop in zip(starts, stops)]) @@ -536,7 +575,9 @@ def reduce( def sort( child_plan: InProgressPhysicalPlan[PartitionT], - sort_info: logical_plan.Sort, + sort_by: ExpressionsProjection, + descending: list[bool], + num_partitions: int, ) -> InProgressPhysicalPlan[PartitionT]: """Sort the result of `child_plan` according to `sort_info`.""" @@ -561,7 +602,7 @@ def sort( partial_metadatas=None, ) .add_instruction( - instruction=execution_step.Sample(sort_by=sort_info._sort_by), + instruction=execution_step.Sample(sort_by=sort_by), ) .finalize_partition_task_single_output() ) @@ -582,9 +623,9 @@ def sort( ) .add_instruction( execution_step.ReduceToQuantiles( - num_quantiles=sort_info.num_partitions(), - sort_by=sort_info._sort_by, - descending=sort_info._descending, + num_quantiles=num_partitions, + sort_by=sort_by, + descending=descending, ), ) .finalize_partition_task_single_output() @@ -606,9 +647,9 @@ def sort( ), ).add_instruction( instruction=execution_step.FanoutRange[PartitionT]( - _num_outputs=sort_info.num_partitions(), - sort_by=sort_info._sort_by, - descending=sort_info._descending, + _num_outputs=num_partitions, + sort_by=sort_by, + descending=descending, ), ) for source in consume_deque(source_materializations) @@ -617,20 +658,20 @@ def sort( # Execute a sorting reduce on it. yield from reduce( fanout_plan=range_fanout_plan, - num_partitions=sort_info.num_partitions(), + num_partitions=num_partitions, reduce_instruction=execution_step.ReduceMergeAndSort( - sort_by=sort_info._sort_by, - descending=sort_info._descending, + sort_by=sort_by, + descending=descending, ), ) -def fanout_random(child_plan: InProgressPhysicalPlan[PartitionT], node: logical_plan.Repartition): +def fanout_random(child_plan: InProgressPhysicalPlan[PartitionT], num_partitions: int): """Splits the results of `child_plan` randomly into a list of `node.num_partitions()` number of partitions""" seed = 0 for step in child_plan: if isinstance(step, PartitionTaskBuilder): - instruction = execution_step.FanoutRandom(node.num_partitions(), seed) + instruction = execution_step.FanoutRandom(num_partitions, seed) step = step.add_instruction(instruction) yield step seed += 1 diff --git a/daft/execution/physical_plan_factory.py b/daft/execution/physical_plan_factory.py index f3b277a386..a76d326828 100644 --- a/daft/execution/physical_plan_factory.py +++ b/daft/execution/physical_plan_factory.py @@ -34,7 +34,15 @@ def _get_physical_plan(node: LogicalPlan, psets: dict[str, list[PartitionT]]) -> child_plan = _get_physical_plan(child_node, psets) if isinstance(node, logical_plan.TabularFilesScan): - return physical_plan.file_read(child_plan=child_plan, scan_info=node) + return physical_plan.file_read( + child_plan=child_plan, + limit_rows=node._limit_rows, + schema=node._schema, + fs=node._fs, + columns_to_read=node._column_names, + source_info=node._source_info, + filepaths_column_name=node._filepaths_column_name, + ) elif isinstance(node, logical_plan.Filter): return physical_plan.pipeline_instruction( @@ -67,7 +75,7 @@ def _get_physical_plan(node: LogicalPlan, psets: dict[str, list[PartitionT]]) -> elif isinstance(node, logical_plan.LocalCount): return physical_plan.pipeline_instruction( child_plan=child_plan, - pipeable_instruction=execution_step.LocalCount(logplan=node), + pipeable_instruction=execution_step.LocalCount(schema=node.schema()), resource_request=node.resource_request(), ) @@ -79,14 +87,25 @@ def _get_physical_plan(node: LogicalPlan, psets: dict[str, list[PartitionT]]) -> ) elif isinstance(node, logical_plan.FileWrite): - return physical_plan.file_write(child_plan, node) + return physical_plan.file_write( + child_plan=child_plan, + file_type=node._storage_type, + schema=node.schema(), + root_dir=node._root_dir, + compression=node._compression, + partition_cols=node._partition_cols, + ) elif isinstance(node, logical_plan.LocalLimit): # Note that the GlobalLimit physical plan also dynamically dispatches its own LocalLimit instructions. return physical_plan.local_limit(child_plan, node._num) elif isinstance(node, logical_plan.GlobalLimit): - return physical_plan.global_limit(child_plan, node) + return physical_plan.global_limit( + child_plan=child_plan, + limit_rows=node._num, + num_partitions=node.num_partitions(), + ) elif isinstance(node, logical_plan.Repartition): # Case: simple repartition (split) @@ -104,7 +123,10 @@ def _get_physical_plan(node: LogicalPlan, psets: dict[str, list[PartitionT]]) -> # Do the fanout. fanout_plan: physical_plan.InProgressPhysicalPlan if node._scheme == PartitionScheme.RANDOM: - fanout_plan = physical_plan.fanout_random(child_plan, node) + fanout_plan = physical_plan.fanout_random( + child_plan=child_plan, + num_partitions=node.num_partitions(), + ) elif node._scheme == PartitionScheme.HASH: fanout_instruction = execution_step.FanoutHash( _num_outputs=node.num_partitions(), @@ -126,10 +148,19 @@ def _get_physical_plan(node: LogicalPlan, psets: dict[str, list[PartitionT]]) -> ) elif isinstance(node, logical_plan.Sort): - return physical_plan.sort(child_plan, node) + return physical_plan.sort( + child_plan=child_plan, + sort_by=node._sort_by, + descending=node._descending, + num_partitions=node.num_partitions(), + ) elif isinstance(node, logical_plan.Coalesce): - return physical_plan.coalesce(child_plan, node) + return physical_plan.coalesce( + child_plan=child_plan, + from_num_partitions=node._children()[0].num_partitions(), + to_num_partitions=node.num_partitions(), + ) else: raise NotImplementedError(f"Unsupported plan type {node}") @@ -142,7 +173,10 @@ def _get_physical_plan(node: LogicalPlan, psets: dict[str, list[PartitionT]]) -> return physical_plan.join( left_plan=_get_physical_plan(left_child, psets), right_plan=_get_physical_plan(right_child, psets), - join=node, + left_on=node._left_on, + right_on=node._right_on, + output_projection=node._output_projection, + how=node._how, ) elif isinstance(node, logical_plan.Concat):