Skip to content

Commit

Permalink
[BUG] Check env in benchmarking script (#3297)
Browse files Browse the repository at this point in the history
Using `ctx.get_or_create_runner` in the benchmarking warmup code and metrics
builder causes subsequent `ray.init` calls to crash. Check the
`DAFT_RUNNER` environment variable instead, which should already be set.

Tested:
- local ->
https://github.com/Eventual-Inc/daft-benchmarking/actions/runs/11838323155
- remote ->
https://github.com/Eventual-Inc/daft-benchmarking/actions/runs/11838783067

---------

Co-authored-by: Colin Ho <[email protected]>
  • Loading branch information
colin-ho and Colin Ho authored Nov 14, 2024
1 parent 05048d9 commit 25c3b26
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions benchmarking/tpch/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
import subprocess
import warnings
from datetime import datetime, timezone
from typing import Any, Callable
from typing import Any, Callable, Literal

import ray

import daft
from benchmarking.tpch import answers, data_generation
from daft import DataFrame
from daft.context import get_context
from daft.runners.profiler import profiler

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -130,8 +129,7 @@ def run_all_benchmarks(
):
get_df = get_df_with_parquet_folder(parquet_folder)

daft_context = get_context()
metrics_builder = MetricsBuilder(daft_context.get_or_create_runner().name)
metrics_builder = MetricsBuilder(get_daft_benchmark_runner_name())

for i in questions:
# Run as a Ray Job if dashboard URL is provided
Expand Down Expand Up @@ -194,6 +192,16 @@ def get_daft_version() -> str:
return daft.get_version()


def get_daft_benchmark_runner_name() -> Literal["ray", "py", "native"]:
    """Return the name of the runner selected for this benchmarking run.

    The name is read from the ``DAFT_RUNNER`` environment variable and
    lower-cased before it is validated, so ``RAY`` and ``ray`` are equivalent.
    Only the environment is consulted.

    Returns:
        One of ``"ray"``, ``"py"`` or ``"native"``.

    Raises:
        AssertionError: If ``DAFT_RUNNER`` is unset, or if its lower-cased
            value is not one of the recognized runner names.
    """
    name = os.getenv("DAFT_RUNNER")

    # Raise explicitly rather than via ``assert`` statements: asserts are
    # stripped under ``python -O``, which would turn a missing variable into an
    # AttributeError on ``None.lower()`` and let unknown names through.
    # AssertionError is kept so the failure type seen by callers is unchanged.
    if name is None:
        raise AssertionError("Tests must be run with $DAFT_RUNNER env var")

    name = name.lower()
    if name not in {"ray", "py", "native"}:
        raise AssertionError(f"Runner name not recognized: {name}")
    return name


def get_ray_runtime_env(requirements: str | None) -> dict:
runtime_env = {
"py_modules": [daft],
Expand All @@ -210,13 +218,10 @@ def get_ray_runtime_env(requirements: str | None) -> dict:

def warmup_environment(requirements: str | None, parquet_folder: str):
"""Performs necessary setup of Daft on the current benchmarking environment"""
ctx = daft.context.get_context()

if ctx.get_or_create_runner().name == "ray":
if get_daft_benchmark_runner_name() == "ray":
runtime_env = get_ray_runtime_env(requirements)

ray.init(
address=ctx._runner.ray_address,
runtime_env=runtime_env,
)

Expand Down

0 comments on commit 25c3b26

Please sign in to comment.