Scheduler cleanup: merge logical_op_runners.py into execution_step (#1020)

Co-authored-by: Xiayue Charles Lin <[email protected]>
xcharleslin and Xiayue Charles Lin committed Jun 22, 2023
1 parent eb7e2c8 commit cd59693
Showing 3 changed files with 121 additions and 130 deletions.
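In essence, the scan and write helpers move off the runner layer and onto the instructions that use them: ReadFile and WriteFile no longer reach into daft.runners.pyrunner at run time. A minimal, hypothetical sketch of the pattern (class and method names below are illustrative stand-ins, not Daft's actual API):

# Hypothetical sketch of the refactor pattern; names are illustrative,
# not Daft's real classes.

class RunnerHelper:
    """Analogue of the deleted LogicalPartitionOpRunner."""

    def handle_scan(self, filepaths: list[str]) -> list[str]:
        return [f"read:{fp}" for fp in filepaths]


class ReadFileBefore:
    def run(self, filepaths: list[str]) -> list[str]:
        # Before: the instruction reaches across modules into the runner layer.
        return RunnerHelper().handle_scan(filepaths)


class ReadFileAfter:
    def run(self, filepaths: list[str]) -> list[str]:
        # After: the same logic is a method on the instruction itself.
        return self.handle_scan(filepaths)

    def handle_scan(self, filepaths: list[str]) -> list[str]:
        return [f"read:{fp}" for fp in filepaths]


assert ReadFileBefore().run(["a.csv"]) == ReadFileAfter().run(["a.csv"])

The same before/after shape appears twice in this diff: once for _handle_tabular_files_scan on ReadFile and once for _handle_file_write on WriteFile.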
126 changes: 121 additions & 5 deletions daft/execution/execution_step.py
@@ -10,13 +10,25 @@
else:
    from typing import Protocol

-import daft
+from daft.datasources import (
+    CSVSourceInfo,
+    JSONSourceInfo,
+    ParquetSourceInfo,
+    StorageType,
+)
from daft.expressions import Expression, ExpressionsProjection, col
from daft.logical import logical_plan
+from daft.logical.logical_plan import FileWrite, TabularFilesScan
from daft.logical.map_partition_ops import MapPartitionOp
from daft.logical.schema import Schema
from daft.resource_request import ResourceRequest
-from daft.runners.partitioning import PartialPartitionMetadata, PartitionMetadata
-from daft.table import Table
+from daft.runners.partitioning import (
+    PartialPartitionMetadata,
+    PartitionMetadata,
+    TableParseCSVOptions,
+    TableReadOptions,
+)
+from daft.table import Table, table_io

PartitionT = TypeVar("PartitionT")
ID_GEN = itertools.count()
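The new imports pull the table-level read path (table_io) and its option dataclasses into execution_step. Going only by the fields this diff exercises, their construction looks like the following sketch (other fields and defaults are not assumed):

from daft.runners.partitioning import TableParseCSVOptions, TableReadOptions

# Limit rows and project columns at read time (fields shown in this diff).
read_options = TableReadOptions(num_rows=1000, column_names=["a", "b"])

# Describe CSV parsing: delimiter plus the header row's index
# (None when the file has no header row).
csv_options = TableParseCSVOptions(delimiter=",", header_index=0)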
@@ -305,7 +317,7 @@ def run(self, inputs: list[Table]) -> list[Table]:
    def _read_file(self, inputs: list[Table]) -> list[Table]:
        assert len(inputs) == 1
        [filepaths_partition] = inputs
-        partition = daft.runners.pyrunner.LocalLogicalPartitionOpRunner()._handle_tabular_files_scan(
+        partition = self._handle_tabular_files_scan(
            inputs={self.logplan._filepaths_child.id(): filepaths_partition},
            scan=self.logplan,
            index=self.index,
@@ -327,6 +339,84 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
            )
        ]

    def _handle_tabular_files_scan(
        self, inputs: dict[int, Table], scan: TabularFilesScan, index: int | None = None
    ) -> Table:
        child_id = scan._children()[0].id()
        prev_partition = inputs[child_id]
        data = prev_partition.to_pydict()
        assert (
            scan._filepaths_column_name in data
        ), f"TabularFilesScan should be run on vPartitions with '{scan._filepaths_column_name}' column"
        filepaths = data[scan._filepaths_column_name]

        if index is not None:
            filepaths = [filepaths[index]]

        # Common options for reading vPartition
        fs = scan._fs
        schema = scan._schema
        read_options = TableReadOptions(
            num_rows=scan._limit_rows,
            column_names=scan._column_names,  # read only specified columns
        )

        if scan._source_info.scan_type() == StorageType.CSV:
            assert isinstance(scan._source_info, CSVSourceInfo)
            table = Table.concat(
                [
                    table_io.read_csv(
                        file=fp,
                        schema=schema,
                        fs=fs,
                        csv_options=TableParseCSVOptions(
                            delimiter=scan._source_info.delimiter,
                            header_index=0 if scan._source_info.has_headers else None,
                        ),
                        read_options=read_options,
                    )
                    for fp in filepaths
                ]
            )
        elif scan._source_info.scan_type() == StorageType.JSON:
            assert isinstance(scan._source_info, JSONSourceInfo)
            table = Table.concat(
                [
                    table_io.read_json(
                        file=fp,
                        schema=schema,
                        fs=fs,
                        read_options=read_options,
                    )
                    for fp in filepaths
                ]
            )
        elif scan._source_info.scan_type() == StorageType.PARQUET:
            assert isinstance(scan._source_info, ParquetSourceInfo)
            table = Table.concat(
                [
                    table_io.read_parquet(
                        file=fp,
                        schema=schema,
                        fs=fs,
                        read_options=read_options,
                    )
                    for fp in filepaths
                ]
            )
        else:
            raise NotImplementedError(f"PyRunner has not implemented scan: {scan._source_info.scan_type()}")

        expected_schema = (
            Schema._from_fields([schema[name] for name in read_options.column_names])
            if read_options.column_names is not None
            else schema
        )
        assert (
            table.schema() == expected_schema
        ), f"Expected table to have schema:\n{expected_schema}\n\nReceived instead:\n{table.schema()}"
        return table


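The method above follows a read-then-concat shape: choose a per-format reader from the scan's storage type, read every filepath, concatenate the per-file tables, and finally assert the result matches the schema projected onto the requested columns. A simplified, self-contained sketch of that dispatch (the readers and row-lists are hypothetical stand-ins, not Daft's table_io):

from typing import Callable

# Hypothetical stand-ins for table_io.read_csv / read_json / read_parquet;
# each returns a list of rows rather than a Daft Table.
def read_csv(path: str) -> list[dict]:
    return [{"src": path, "fmt": "csv"}]

def read_json(path: str) -> list[dict]:
    return [{"src": path, "fmt": "json"}]

def read_parquet(path: str) -> list[dict]:
    return [{"src": path, "fmt": "parquet"}]

READERS: dict[str, Callable[[str], list[dict]]] = {
    "CSV": read_csv,
    "JSON": read_json,
    "PARQUET": read_parquet,
}

def handle_scan(scan_type: str, filepaths: list[str]) -> list[dict]:
    if scan_type not in READERS:
        raise NotImplementedError(f"scan not implemented: {scan_type}")
    reader = READERS[scan_type]
    # Read each file, then flatten: the analogue of Table.concat(...).
    return [row for fp in filepaths for row in reader(fp)]

print(handle_scan("CSV", ["a.csv", "b.csv"]))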
@dataclass(frozen=True)
class WriteFile(SingleOutputInstruction):
@@ -338,7 +428,7 @@ def run(self, inputs: list[Table]) -> list[Table]:

    def _write_file(self, inputs: list[Table]) -> list[Table]:
        [input] = inputs
-        partition = daft.runners.pyrunner.LocalLogicalPartitionOpRunner()._handle_file_write(
+        partition = self._handle_file_write(
            inputs={self.logplan._children()[0].id(): input},
            file_write=self.logplan,
        )
@@ -353,6 +443,32 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata])
            )
        ]

    def _handle_file_write(self, inputs: dict[int, Table], file_write: FileWrite) -> Table:
        child_id = file_write._children()[0].id()
        assert file_write._storage_type == StorageType.PARQUET or file_write._storage_type == StorageType.CSV
        if file_write._storage_type == StorageType.PARQUET:
            file_names = table_io.write_parquet(
                inputs[child_id],
                path=file_write._root_dir,
                compression=file_write._compression,
                partition_cols=file_write._partition_cols,
            )
        else:
            file_names = table_io.write_csv(
                inputs[child_id],
                path=file_write._root_dir,
                compression=file_write._compression,
                partition_cols=file_write._partition_cols,
            )

        output_schema = file_write.schema()
        assert len(output_schema) == 1
        return Table.from_pydict(
            {
                output_schema.column_names()[0]: file_names,
            }
        )

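_handle_file_write makes a write step produce a partition like any other: the returned table has exactly one column, declared by the plan's schema, whose rows are the paths of the files just written. A small sketch of that contract (the writer below is a hypothetical stand-in, not Daft's table_io.write_parquet):

def write_files(rows: list[dict], root_dir: str) -> list[str]:
    # Stand-in writer: pretend each row lands in its own output file.
    return [f"{root_dir}/part-{i}.parquet" for i in range(len(rows))]

def handle_file_write(rows: list[dict], root_dir: str, column_name: str) -> dict[str, list[str]]:
    file_names = write_files(rows, root_dir)
    # The "output partition" of a write step is just the written paths,
    # exposed under the single column declared by the plan's schema.
    return {column_name: file_names}

print(handle_file_write([{"x": 1}, {"x": 2}], "/tmp/out", "file_path"))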

@dataclass(frozen=True)
class Filter(SingleOutputInstruction):
120 changes: 0 additions & 120 deletions daft/execution/logical_op_runners.py

This file was deleted.

5 changes: 0 additions & 5 deletions daft/runners/pyrunner.py
@@ -13,7 +13,6 @@
from daft.datasources import SourceInfo
from daft.execution import physical_plan, physical_plan_factory
from daft.execution.execution_step import Instruction, MaterializedResult, PartitionTask
-from daft.execution.logical_op_runners import LogicalPartitionOpRunner
from daft.filesystem import get_filesystem_from_path, glob_path_with_stats
from daft.internal.gpu import cuda_device_count
from daft.internal.rule_runner import FixedPointPolicy, Once, RuleBatch, RuleRunner
@@ -136,10 +135,6 @@ def get_schema_from_first_filepath(
        return runner_io.sample_schema(first_filepath, source_info, fs)


-class LocalLogicalPartitionOpRunner(LogicalPartitionOpRunner):
-    ...


class PyRunner(Runner[Table]):
    def __init__(self, use_thread_pool: bool | None) -> None:
        super().__init__()
