[FEAT] [New Query Planner] All functional tests pass + add to CI. (#1274)

This PR ensures that all functional tests pass for the new query planner
and adds the new query planner to the CI job matrix; all tests are
covered and pass except for the Python query planner optimization tests.

The `use_new_planner` test fixture is removed, since
`DAFT_NEW_QUERY_PLANNER=1 make test` will now work with expected test
coverage.
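
For local experimentation, here is a minimal sketch of opting into the new planner from Python; the environment variable is the same one the CI matrix below exports, and the example DataFrame is purely illustrative:

```python
import os

# Select the new (Rust-backed) query planner before Daft creates its context,
# mirroring DAFT_NEW_QUERY_PLANNER=1 in the CI matrix added below.
os.environ["DAFT_NEW_QUERY_PLANNER"] = "1"

import daft
from daft.expressions import col

df = daft.from_pydict({"x": [1, 2, 3, 4, 5]})
df = df.where(col("x") > 2).select(col("x") * 2)
print(df.collect())
```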
clarkzinzow authored Aug 15, 2023
1 parent a03aaef commit a4329f1
Showing 36 changed files with 211 additions and 200 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/python-package.yml
@@ -21,6 +21,7 @@ jobs:
matrix:
python-version: ['3.7', '3.10']
daft-runner: [py, ray]
new-query-planner: [1, 0]
pyarrow-version: [6.0.1, 12.0]
exclude:
- daft-runner: ray
@@ -74,6 +75,7 @@ jobs:
# cargo llvm-cov --no-run --lcov --output-path report-output/rust-coverage-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.daft-runner }}.lcov
env:
DAFT_RUNNER: ${{ matrix.daft-runner }}
DAFT_NEW_QUERY_PLANNER: ${{ matrix.new-query-planner }}

- name: Upload coverage report
uses: actions/upload-artifact@v3
@@ -149,6 +151,7 @@ jobs:
matrix:
python-version: ['3.7']
daft-runner: [py, ray]
new-query-planner: [1, 0]
steps:
- uses: actions/checkout@v3
with:
@@ -183,6 +186,7 @@ jobs:
pytest tests/integration/test_tpch.py --durations=50
env:
DAFT_RUNNER: ${{ matrix.daft-runner }}
DAFT_NEW_QUERY_PLANNER: ${{ matrix.new-query-planner }}
- name: Send Slack notification on failure
uses: slackapi/[email protected]
if: ${{ failure() && (github.ref == 'refs/heads/main') }}
@@ -215,6 +219,7 @@ jobs:
matrix:
python-version: ['3.8'] # can't use 3.7 due to requiring anon mode for adlfs
daft-runner: [py, ray]
new-query-planner: [1, 0]
# These permissions are needed to interact with GitHub's OIDC Token endpoint.
# This is used in the step "Assume GitHub Actions AWS Credentials"
permissions:
@@ -263,6 +268,7 @@ jobs:
pytest tests/integration/io -m 'integration' --durations=50
env:
DAFT_RUNNER: ${{ matrix.daft-runner }}
DAFT_NEW_QUERY_PLANNER: ${{ matrix.new-query-planner }}
- name: Send Slack notification on failure
uses: slackapi/[email protected]
if: ${{ failure() && (github.ref == 'refs/heads/main') }}
6 changes: 3 additions & 3 deletions daft/context.py
@@ -57,7 +57,7 @@ def _get_runner_config_from_env() -> _RunnerConfig:

def _get_planner_from_env() -> bool:
"""Returns whether or not to use the new query planner."""
return bool(int(os.getenv("DAFT_DEVELOPER_RUST_QUERY_PLANNER", default="0")))
return bool(int(os.getenv("DAFT_NEW_QUERY_PLANNER", default="0")))


@dataclasses.dataclass(frozen=True)
@@ -193,7 +193,7 @@ def set_new_planner() -> DaftContext:
WARNING: The new query planner is currently experimental and only partially implemented.
Alternatively, users can set this behavior via an environment variable: DAFT_DEVELOPER_RUST_QUERY_PLANNER=1
Alternatively, users can set this behavior via an environment variable: DAFT_NEW_QUERY_PLANNER=1
Returns:
DaftContext: Daft context after enabling the new query planner.
@@ -210,7 +210,7 @@ def set_new_planner() -> DaftContext:
def set_old_planner() -> DaftContext:
"""Enable the old query planner.
Alternatively, users can set this behavior via an environment variable: DAFT_DEVELOPER_RUST_QUERY_PLANNER=0
Alternatively, users can set this behavior via an environment variable: DAFT_NEW_QUERY_PLANNER=0
Returns:
DaftContext: Daft context after enabling the old query planner.
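
Taken together, the rename leaves two equivalent ways to opt in. A hedged sketch using only names that appear in this diff (the placement of queries between the calls is illustrative):

```python
# Option 1: environment variable, read by _get_planner_from_env():
#   DAFT_NEW_QUERY_PLANNER=1 python my_script.py
#
# Option 2: programmatic toggle via the context helpers shown above.
from daft.context import set_new_planner, set_old_planner

set_new_planner()   # experimental new planner; the WARNING above applies
# ... build and run queries here ...
set_old_planner()   # revert to the old planner
```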
5 changes: 4 additions & 1 deletion daft/logical/rust_logical_plan.py
@@ -132,7 +132,7 @@ def explode(self, explode_expressions: ExpressionsProjection) -> RustLogicalPlan
def count(self) -> RustLogicalPlanBuilder:
# TODO(Clark): Add dedicated logical/physical ops when introducing metadata-based count optimizations.
first_col = col(self.schema().column_names()[0])
builder = self._builder.aggregate([first_col._count(CountMode.All)], [])
builder = self._builder.aggregate([first_col._count(CountMode.All)._expr], [])
rename_expr = ExpressionsProjection([first_col.alias("count")])
schema = rename_expr.resolve_schema(Schema._from_pyschema(builder.schema()))
builder = builder.project(rename_expr.to_inner_py_exprs(), schema._schema, ResourceRequest())
return RustLogicalPlanBuilder(builder)

def distinct(self) -> RustLogicalPlanBuilder:
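
A hedged reading of this change: the aggregate counts the first column, and the added projection renames that single output to a column literally named "count", which the row-count path is assumed to expect; either way, the observable result should match the old planner. A minimal illustration with hypothetical data:

```python
import daft

df = daft.from_pydict({"x": [1, 2, 3], "y": ["a", "b", "c"]})

# Under DAFT_NEW_QUERY_PLANNER=1, a global count like this is expected to plan
# as aggregate(count(first column)) followed by a projection aliased to "count".
assert df.count_rows() == 3
```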
15 changes: 0 additions & 15 deletions tests/conftest.py
@@ -4,14 +4,6 @@
import pyarrow as pa
import pytest

from daft.context import (
DaftContext,
_set_context,
get_context,
set_new_planner,
set_old_planner,
)


def pytest_configure(config):
config.addinivalue_line(
@@ -44,13 +36,6 @@ def uuid_ext_type() -> UuidType:
pa.unregister_extension_type(ext_type.NAME)


@pytest.fixture(params=[False, True])
def use_new_planner(request) -> DaftContext:
old_ctx = get_context()
yield set_new_planner() if request.param else set_old_planner()
_set_context(old_ctx)


def assert_df_equals(
daft_df: pd.DataFrame,
pd_df: pd.DataFrame,
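
For reference, the deleted fixture as one self-contained snippet; per-test planner toggling is no longer needed in CI, but something along these lines could be reinstated locally if it ever is:

```python
import pytest

from daft.context import (
    DaftContext,
    _set_context,
    get_context,
    set_new_planner,
    set_old_planner,
)


# Parametrized fixture that runs each test once per planner, restoring the
# original context afterwards (mirrors the code removed above).
@pytest.fixture(params=[False, True])
def use_new_planner(request) -> DaftContext:
    old_ctx = get_context()
    yield set_new_planner() if request.param else set_old_planner()
    _set_context(old_ctx)
```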
28 changes: 14 additions & 14 deletions tests/cookbook/test_aggregations.py
@@ -7,7 +7,7 @@
from tests.conftest import assert_df_equals


def test_sum(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_sum(daft_df, service_requests_csv_pd_df, repartition_nparts):
"""Sums across an entire column for the entire table"""
daft_df = daft_df.repartition(repartition_nparts).sum(col("Unique Key").alias("unique_key_sum"))
service_requests_csv_pd_df = pd.DataFrame.from_records(
@@ -17,7 +17,7 @@ def test_sum(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_pl
assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_sum")


def test_mean(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_mean(daft_df, service_requests_csv_pd_df, repartition_nparts):
"""Averages across a column for entire table"""
daft_df = daft_df.repartition(repartition_nparts).mean(col("Unique Key").alias("unique_key_mean"))
service_requests_csv_pd_df = pd.DataFrame.from_records(
@@ -27,7 +27,7 @@ def test_mean(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_p
assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_mean")


def test_min(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_min(daft_df, service_requests_csv_pd_df, repartition_nparts):
"""min across a column for entire table"""
daft_df = daft_df.repartition(repartition_nparts).min(col("Unique Key").alias("unique_key_min"))
service_requests_csv_pd_df = pd.DataFrame.from_records(
@@ -37,7 +37,7 @@ def test_min(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_pl
assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_min")


def test_max(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_max(daft_df, service_requests_csv_pd_df, repartition_nparts):
"""max across a column for entire table"""
daft_df = daft_df.repartition(repartition_nparts).max(col("Unique Key").alias("unique_key_max"))
service_requests_csv_pd_df = pd.DataFrame.from_records(
@@ -47,7 +47,7 @@ def test_max(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_pl
assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_max")


def test_count(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_count(daft_df, service_requests_csv_pd_df, repartition_nparts):
"""count a column for entire table"""
daft_df = daft_df.repartition(repartition_nparts).count(col("Unique Key").alias("unique_key_count"))
service_requests_csv_pd_df = pd.DataFrame.from_records(
@@ -58,7 +58,7 @@ def test_count(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_
assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_count")


def test_list(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_list(daft_df, service_requests_csv_pd_df, repartition_nparts):
"""list agg a column for entire table"""
daft_df = daft_df.repartition(repartition_nparts).agg_list(col("Unique Key").alias("unique_key_list")).collect()
unique_key_list = service_requests_csv_pd_df["Unique Key"].to_list()
@@ -68,7 +68,7 @@ def test_list(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_p
assert set(result_list[0]) == set(unique_key_list)


def test_global_agg(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_global_agg(daft_df, service_requests_csv_pd_df, repartition_nparts):
"""Averages across a column for entire table"""
daft_df = daft_df.repartition(repartition_nparts).agg(
[
@@ -92,7 +92,7 @@ def test_global_agg(daft_df, service_requests_csv_pd_df, repartition_nparts, use
assert_df_equals(daft_pd_df, service_requests_csv_pd_df, sort_key="unique_key_mean")


def test_filtered_sum(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_filtered_sum(daft_df, service_requests_csv_pd_df, repartition_nparts):
"""Sums across an entire column for the entire table filtered by a certain condition"""
daft_df = (
daft_df.repartition(repartition_nparts)
@@ -119,7 +119,7 @@ def test_filtered_sum(daft_df, service_requests_csv_pd_df, repartition_nparts, u
pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
],
)
def test_sum_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
def test_sum_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
"""Sums across groups"""
daft_df = daft_df.repartition(repartition_nparts).groupby(*[col(k) for k in keys]).sum(col("Unique Key"))
service_requests_csv_pd_df = service_requests_csv_pd_df.groupby(keys).sum("Unique Key").reset_index()
@@ -134,7 +134,7 @@ def test_sum_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, ke
pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
],
)
def test_mean_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
def test_mean_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
"""Sums across groups"""
daft_df = daft_df.repartition(repartition_nparts).groupby(*[col(k) for k in keys]).mean(col("Unique Key"))
service_requests_csv_pd_df = service_requests_csv_pd_df.groupby(keys).mean("Unique Key").reset_index()
@@ -149,7 +149,7 @@ def test_mean_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, k
pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
],
)
def test_count_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
def test_count_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
"""count across groups"""
daft_df = daft_df.repartition(repartition_nparts).groupby(*[col(k) for k in keys]).count()
service_requests_csv_pd_df = service_requests_csv_pd_df.groupby(keys).count().reset_index()
@@ -167,7 +167,7 @@ def test_count_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts,
pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
],
)
def test_min_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
def test_min_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
"""min across groups"""
daft_df = (
daft_df.repartition(repartition_nparts)
@@ -188,7 +188,7 @@ def test_min_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, ke
pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
],
)
def test_max_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
def test_max_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
"""max across groups"""
daft_df = (
daft_df.repartition(repartition_nparts)
@@ -209,7 +209,7 @@ def test_max_groupby(daft_df, service_requests_csv_pd_df, repartition_nparts, ke
pytest.param(["Borough", "Complaint Type"], id="NumGroupSortKeys:2"),
],
)
def test_sum_groupby_sorted(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
def test_sum_groupby_sorted(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
"""Test sorting after a groupby"""
daft_df = (
daft_df.repartition(repartition_nparts)
10 changes: 5 additions & 5 deletions tests/cookbook/test_count_rows.py
@@ -5,13 +5,13 @@
from daft.expressions import col


def test_count_rows(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_count_rows(daft_df, service_requests_csv_pd_df, repartition_nparts):
"""Count rows for the entire table"""
daft_df_row_count = daft_df.repartition(repartition_nparts).count_rows()
assert daft_df_row_count == service_requests_csv_pd_df.shape[0]


def test_filtered_count_rows(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_filtered_count_rows(daft_df, service_requests_csv_pd_df, repartition_nparts):
"""Count rows on a table filtered by a certain condition"""
daft_df_row_count = daft_df.repartition(repartition_nparts).where(col("Borough") == "BROOKLYN").count_rows()

@@ -26,20 +26,20 @@ def test_filtered_count_rows(daft_df, service_requests_csv_pd_df, repartition_np
pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
],
)
def test_groupby_count_rows(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
def test_groupby_count_rows(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
"""Count rows after group by"""
daft_df = daft_df.repartition(repartition_nparts).groupby(*[col(k) for k in keys]).sum(col("Unique Key"))
service_requests_csv_pd_df = service_requests_csv_pd_df.groupby(keys).sum("Unique Key").reset_index()
assert daft_df.count_rows() == len(service_requests_csv_pd_df)


def test_dataframe_length_after_collect(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_dataframe_length_after_collect(daft_df, service_requests_csv_pd_df, repartition_nparts):
"""Count rows after group by"""
daft_df = daft_df.repartition(repartition_nparts).collect()
assert len(daft_df) == len(service_requests_csv_pd_df)


def test_dataframe_length_before_collect(daft_df, use_new_planner):
def test_dataframe_length_before_collect(daft_df):
"""Count rows for the entire table"""
with pytest.raises(RuntimeError) as err_info:
len(daft_df)
2 changes: 1 addition & 1 deletion tests/cookbook/test_distinct.py
@@ -13,7 +13,7 @@
pytest.param(["Borough", "Complaint Type"], id="NumGroupByKeys:2"),
],
)
def test_distinct_all_columns(daft_df, service_requests_csv_pd_df, repartition_nparts, keys, use_new_planner):
def test_distinct_all_columns(daft_df, service_requests_csv_pd_df, repartition_nparts, keys):
"""Sums across groups"""
daft_df = daft_df.repartition(repartition_nparts).select(*[col(k) for k in keys]).distinct()

8 changes: 4 additions & 4 deletions tests/cookbook/test_filter.py
@@ -30,7 +30,7 @@
),
],
)
def test_filter(daft_df_ops, daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_filter(daft_df_ops, daft_df, service_requests_csv_pd_df, repartition_nparts):
"""Filter the dataframe, retrieve the top N results and select a subset of columns"""

daft_noise_complaints = daft_df_ops(daft_df.repartition(repartition_nparts))
@@ -83,7 +83,7 @@ def test_filter(daft_df_ops, daft_df, service_requests_csv_pd_df, repartition_np
),
],
)
def test_complex_filter(daft_df_ops, daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_complex_filter(daft_df_ops, daft_df, service_requests_csv_pd_df, repartition_nparts):
"""Filter the dataframe with a complex filter and select a subset of columns"""
daft_noise_complaints_brooklyn = daft_df_ops(daft_df.repartition(repartition_nparts))

@@ -127,7 +127,7 @@ def test_complex_filter(daft_df_ops, daft_df, service_requests_csv_pd_df, repart
),
],
)
def test_chain_filter(daft_df_ops, daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_chain_filter(daft_df_ops, daft_df, service_requests_csv_pd_df, repartition_nparts):
"""Filter the dataframe with a chain of filters and select a subset of columns"""
daft_noise_complaints_brooklyn = daft_df_ops(daft_df.repartition(repartition_nparts))

@@ -142,7 +142,7 @@ def test_chain_filter(daft_df_ops, daft_df, service_requests_csv_pd_df, repartit
assert_df_equals(daft_pd_df, pd_noise_complaints_brooklyn)


def test_filter_on_projection(use_new_planner):
def test_filter_on_projection():
"""Filter the dataframe with on top of a projection"""
df = daft.from_pydict({"x": [1, 1, 1, 1, 1]})
df = df.select(col("x") * 2)
8 changes: 4 additions & 4 deletions tests/cookbook/test_joins.py
@@ -4,7 +4,7 @@
from tests.conftest import assert_df_equals


def test_simple_join(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_simple_join(daft_df, service_requests_csv_pd_df, repartition_nparts):
daft_df = daft_df.repartition(repartition_nparts)
daft_df_left = daft_df.select(col("Unique Key"), col("Borough"))
daft_df_right = daft_df.select(col("Unique Key"), col("Created Date"))
@@ -21,7 +21,7 @@ def test_simple_join(daft_df, service_requests_csv_pd_df, repartition_nparts, us
assert_df_equals(daft_pd_df, service_requests_csv_pd_df)


def test_simple_self_join(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_simple_self_join(daft_df, service_requests_csv_pd_df, repartition_nparts):
daft_df = daft_df.repartition(repartition_nparts)
daft_df = daft_df.select(col("Unique Key"), col("Borough"))

@@ -38,7 +38,7 @@ def test_simple_self_join(daft_df, service_requests_csv_pd_df, repartition_npart
assert_df_equals(daft_pd_df, service_requests_csv_pd_df)


def test_simple_join_missing_rvalues(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_simple_join_missing_rvalues(daft_df, service_requests_csv_pd_df, repartition_nparts):
daft_df_right = daft_df.sort("Unique Key").limit(25).repartition(repartition_nparts)
daft_df_left = daft_df.repartition(repartition_nparts)
daft_df_left = daft_df_left.select(col("Unique Key"), col("Borough"))
@@ -58,7 +58,7 @@ def test_simple_join_missing_rvalues(daft_df, service_requests_csv_pd_df, repart
assert_df_equals(daft_pd_df, service_requests_csv_pd_df)


def test_simple_join_missing_lvalues(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner):
def test_simple_join_missing_lvalues(daft_df, service_requests_csv_pd_df, repartition_nparts):
daft_df_right = daft_df.repartition(repartition_nparts)
daft_df_left = daft_df.sort(col("Unique Key")).limit(25).repartition(repartition_nparts)
daft_df_left = daft_df_left.select(col("Unique Key"), col("Borough"))