[CHORE] Remove existing LogicalPlan from all execution concepts (#1208)
Currently, many query fragment executors are parameterized by a LogicalPlan
node object, because that node holds the subfields they need. This is
unnecessary; the executors can be parameterized by those subfields directly.

This PR removes LogicalPlan from all query execution concepts. This has two
benefits (a toy sketch of the before/after pattern follows the list):

1. It is a prerequisite for switching to a new logical plan
implementation.
2. It provides an immediate performance boost, since redundant
information (often an entire plan subtree) is no longer serialized with the
execution information.
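
To make the pattern concrete, here is a minimal, self-contained sketch using toy classes. `FilterNode`, `FilterInstructionOld`, and `FilterInstructionNew` are illustrative stand-ins, not Daft's actual types:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class FilterNode:
    """Toy logical plan node; `child` stands in for an entire upstream subtree."""

    column: str
    threshold: int
    child: Optional[FilterNode] = None


@dataclass(frozen=True)
class FilterInstructionOld:
    """Old style: parameterized by the plan node itself."""

    logplan: FilterNode  # carries the whole plan subtree along

    def run(self, rows: list[dict]) -> list[dict]:
        return [r for r in rows if r[self.logplan.column] > self.logplan.threshold]


@dataclass(frozen=True)
class FilterInstructionNew:
    """New style: parameterized by only the subfields it actually reads."""

    column: str
    threshold: int

    def run(self, rows: list[dict]) -> list[dict]:
        return [r for r in rows if r[self.column] > self.threshold]
```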

---------

Co-authored-by: Xiayue Charles Lin <[email protected]>
xcharleslin and Xiayue Charles Lin committed Aug 2, 2023
1 parent 60512a0 commit fdd74f2
Showing 3 changed files with 190 additions and 107 deletions.
144 changes: 76 additions & 68 deletions daft/execution/execution_step.py
@@ -1,9 +1,13 @@
 from __future__ import annotations
 
 import itertools
+import pathlib
 import sys
 from dataclasses import dataclass, field
-from typing import Generic, TypeVar
+from typing import TYPE_CHECKING, Generic, TypeVar
+
+if TYPE_CHECKING:
+    import fsspec
 
 if sys.version_info < (3, 8):
     from typing_extensions import Protocol
@@ -14,11 +18,11 @@
     CSVSourceInfo,
     JSONSourceInfo,
     ParquetSourceInfo,
+    SourceInfo,
     StorageType,
 )
 from daft.expressions import Expression, ExpressionsProjection, col
-from daft.logical import logical_plan
-from daft.logical.logical_plan import FileWrite, TabularFilesScan
+from daft.logical.logical_plan import JoinType
 from daft.logical.map_partition_ops import MapPartitionOp
 from daft.logical.schema import Schema
 from daft.resource_request import ResourceRequest
@@ -306,10 +310,16 @@ def num_outputs(self) -> int:
 
 @dataclass(frozen=True)
 class ReadFile(SingleOutputInstruction):
-    partition_id: int
     index: int | None
-    logplan: logical_plan.TabularFilesScan
+    # Known number of rows.
     file_rows: int | None
+    # Max number of rows to read.
+    limit_rows: int | None
+    schema: Schema
+    fs: fsspec.AbstractFileSystem | None
+    columns_to_read: list[str] | None
+    source_info: SourceInfo
+    filepaths_column_name: str
 
     def run(self, inputs: list[Table]) -> list[Table]:
         return self._read_file(inputs)
@@ -318,9 +328,7 @@ def _read_file(self, inputs: list[Table]) -> list[Table]:
         assert len(inputs) == 1
         [filepaths_partition] = inputs
         partition = self._handle_tabular_files_scan(
-            inputs={self.logplan._filepaths_child.id(): filepaths_partition},
-            scan=self.logplan,
-            index=self.index,
+            filepaths_partition=filepaths_partition,
         )
         return [partition]
 
@@ -329,8 +337,8 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
 
         num_rows = self.file_rows
         # Only take the file read limit into account if we know how big the file is to begin with.
-        if num_rows is not None and self.logplan._limit_rows is not None:
-            num_rows = min(num_rows, self.logplan._limit_rows)
+        if num_rows is not None and self.limit_rows is not None:
+            num_rows = min(num_rows, self.limit_rows)
 
         return [
             PartialPartitionMetadata(
@@ -340,79 +348,76 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
         ]
 
     def _handle_tabular_files_scan(
-        self, inputs: dict[int, Table], scan: TabularFilesScan, index: int | None = None
+        self,
+        filepaths_partition: Table,
     ) -> Table:
-        child_id = scan._children()[0].id()
-        prev_partition = inputs[child_id]
-        data = prev_partition.to_pydict()
+        data = filepaths_partition.to_pydict()
         assert (
-            scan._filepaths_column_name in data
-        ), f"TabularFilesScan should be ran on vPartitions with '{scan._filepaths_column_name}' column"
-        filepaths = data[scan._filepaths_column_name]
+            self.filepaths_column_name in data
+        ), f"TabularFilesScan should be ran on vPartitions with '{self.filepaths_column_name}' column"
+        filepaths = data[self.filepaths_column_name]
 
-        if index is not None:
-            filepaths = [filepaths[index]]
+        if self.index is not None:
+            filepaths = [filepaths[self.index]]
 
         # Common options for reading vPartition
-        fs = scan._fs
-        schema = scan._schema
         read_options = TableReadOptions(
-            num_rows=scan._limit_rows,
-            column_names=scan._column_names,  # read only specified columns
+            num_rows=self.limit_rows,
+            column_names=self.columns_to_read,  # read only specified columns
         )
 
-        if scan._source_info.scan_type() == StorageType.CSV:
-            assert isinstance(scan._source_info, CSVSourceInfo)
+        if self.source_info.scan_type() == StorageType.CSV:
+            assert isinstance(self.source_info, CSVSourceInfo)
             table = Table.concat(
                 [
                     table_io.read_csv(
                         file=fp,
-                        schema=schema,
-                        fs=fs,
+                        schema=self.schema,
+                        fs=self.fs,
                         csv_options=TableParseCSVOptions(
-                            delimiter=scan._source_info.delimiter,
-                            header_index=0 if scan._source_info.has_headers else None,
+                            delimiter=self.source_info.delimiter,
+                            header_index=0 if self.source_info.has_headers else None,
                         ),
                         read_options=read_options,
                     )
                     for fp in filepaths
                 ]
             )
-        elif scan._source_info.scan_type() == StorageType.JSON:
-            assert isinstance(scan._source_info, JSONSourceInfo)
+        elif self.source_info.scan_type() == StorageType.JSON:
+            assert isinstance(self.source_info, JSONSourceInfo)
            table = Table.concat(
                 [
                     table_io.read_json(
                         file=fp,
-                        schema=schema,
-                        fs=fs,
+                        schema=self.schema,
+                        fs=self.fs,
                         read_options=read_options,
                     )
                     for fp in filepaths
                 ]
             )
-        elif scan._source_info.scan_type() == StorageType.PARQUET:
-            assert isinstance(scan._source_info, ParquetSourceInfo)
+        elif self.source_info.scan_type() == StorageType.PARQUET:
+            assert isinstance(self.source_info, ParquetSourceInfo)
             table = Table.concat(
                 [
                     table_io.read_parquet(
                         file=fp,
-                        schema=schema,
-                        fs=fs,
+                        schema=self.schema,
+                        fs=self.fs,
                         read_options=read_options,
-                        io_config=scan._source_info.io_config,
-                        use_native_downloader=scan._source_info.use_native_downloader,
+                        io_config=self.source_info.io_config,
+                        use_native_downloader=self.source_info.use_native_downloader,
                     )
                     for fp in filepaths
                 ]
             )
         else:
-            raise NotImplementedError(f"PyRunner has not implemented scan: {scan._source_info.scan_type()}")
+            raise NotImplementedError(f"PyRunner has not implemented scan: {self.source_info.scan_type()}")
 
         expected_schema = (
-            Schema._from_fields([schema[name] for name in read_options.column_names])
+            Schema._from_fields([self.schema[name] for name in read_options.column_names])
             if read_options.column_names is not None
-            else schema
+            else self.schema
         )
         assert (
             table.schema() == expected_schema
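
As a hedged illustration of the serialization point claimed in the description (toy classes again, not Daft's types): pickling an instruction that references a plan node also pickles every node reachable from it, while a field-only instruction stays small regardless of plan depth.

```python
from __future__ import annotations

import pickle
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Node:
    """Toy stand-in for a LogicalPlan node with an upstream subtree."""

    value: int
    child: Optional[Node] = None


@dataclass(frozen=True)
class HoldsPlan:
    logplan: Node  # old style: pickling this pickles every ancestor node


@dataclass(frozen=True)
class HoldsFields:
    value: int  # new style: only the data the instruction needs


# A 100-deep chain stands in for "an entire plan subtree".
plan: Optional[Node] = None
for i in range(100):
    plan = Node(value=i, child=plan)

print(len(pickle.dumps(HoldsPlan(plan))))   # grows with plan depth
print(len(pickle.dumps(HoldsFields(99))))   # small and constant
```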
@@ -422,17 +427,19 @@ def _handle_tabular_files_scan(
 
 @dataclass(frozen=True)
 class WriteFile(SingleOutputInstruction):
-    partition_id: int
-    logplan: logical_plan.FileWrite
+    file_type: StorageType
+    schema: Schema
+    root_dir: str | pathlib.Path
+    compression: str | None
+    partition_cols: ExpressionsProjection | None
 
     def run(self, inputs: list[Table]) -> list[Table]:
         return self._write_file(inputs)
 
     def _write_file(self, inputs: list[Table]) -> list[Table]:
         [input] = inputs
         partition = self._handle_file_write(
-            inputs={self.logplan._children()[0].id(): input},
-            file_write=self.logplan,
+            input=input,
         )
         return [partition]
 
@@ -445,29 +452,27 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
             )
         ]
 
-    def _handle_file_write(self, inputs: dict[int, Table], file_write: FileWrite) -> Table:
-        child_id = file_write._children()[0].id()
-        assert file_write._storage_type == StorageType.PARQUET or file_write._storage_type == StorageType.CSV
-        if file_write._storage_type == StorageType.PARQUET:
+    def _handle_file_write(self, input: Table) -> Table:
+        assert self.file_type == StorageType.PARQUET or self.file_type == StorageType.CSV
+        if self.file_type == StorageType.PARQUET:
             file_names = table_io.write_parquet(
-                inputs[child_id],
-                path=file_write._root_dir,
-                compression=file_write._compression,
-                partition_cols=file_write._partition_cols,
+                input,
+                path=self.root_dir,
+                compression=self.compression,
+                partition_cols=self.partition_cols,
             )
         else:
             file_names = table_io.write_csv(
-                inputs[child_id],
-                path=file_write._root_dir,
-                compression=file_write._compression,
-                partition_cols=file_write._partition_cols,
+                input,
+                path=self.root_dir,
+                compression=self.compression,
+                partition_cols=self.partition_cols,
             )
 
-        output_schema = file_write.schema()
-        assert len(output_schema) == 1
+        assert len(self.schema) == 1
         return Table.from_pydict(
             {
-                output_schema.column_names()[0]: file_names,
+                self.schema.column_names()[0]: file_names,
             }
         )

@@ -516,15 +521,15 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
 
 @dataclass(frozen=True)
 class LocalCount(SingleOutputInstruction):
-    logplan: logical_plan.LocalCount
+    schema: Schema
 
     def run(self, inputs: list[Table]) -> list[Table]:
         return self._count(inputs)
 
     def _count(self, inputs: list[Table]) -> list[Table]:
         [input] = inputs
         partition = Table.from_pydict({"count": [len(input)]})
-        assert partition.schema() == self.logplan.schema()
+        assert partition.schema() == self.schema
         return [partition]
 
     def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]:
@@ -629,7 +634,10 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
 
 @dataclass(frozen=True)
 class Join(SingleOutputInstruction):
-    logplan: logical_plan.Join
+    left_on: ExpressionsProjection
+    right_on: ExpressionsProjection
+    output_projection: ExpressionsProjection
+    how: JoinType
 
     def run(self, inputs: list[Table]) -> list[Table]:
         return self._join(inputs)
@@ -638,10 +646,10 @@ def _join(self, inputs: list[Table]) -> list[Table]:
         [left, right] = inputs
         result = left.join(
             right,
-            left_on=self.logplan._left_on,
-            right_on=self.logplan._right_on,
-            output_projection=self.logplan._output_projection,
-            how=self.logplan._how.value,
+            left_on=self.left_on,
+            right_on=self.right_on,
+            output_projection=self.output_projection,
+            how=self.how.value,
        )
         return [result]
(Diffs for the 2 remaining changed files are not shown.)

