Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIDY FIRST] Type BaseRunner class #10607

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
a0c0291
Add typing to `BaseRunner.__init__`
QMalcolm Aug 23, 2024
d678d73
Add typing to `BuildResult.get_result_status`
QMalcolm Aug 23, 2024
01219eb
Add typing for `BaseRunner.run_with_hooks`
QMalcolm Aug 23, 2024
2092a70
Add type hinting for `BaseRunner._build_run_result`
QMalcolm Aug 26, 2024
7a80aa3
Add type hinting to `error_result`, `ephemeral_result`, and `from_run…
QMalcolm Aug 26, 2024
f4dd99b
Fix typing of `node` in `BaseRunner` init and other methods
QMalcolm Aug 26, 2024
15ef78b
Add `ResultStatus` type union definition and use to fix status typing…
QMalcolm Aug 26, 2024
b909090
Add type hinting to `BaseRunner.compile_and_execute`
QMalcolm Aug 26, 2024
8cc7124
Add typing to `node` property of `ExecutionContext`
QMalcolm Aug 26, 2024
436e1e8
Add type hinting to `execute` and `run` of `BaseRunner`
QMalcolm Aug 26, 2024
486447e
Add type hinting to exception handling methods of `BaseRunner`
QMalcolm Aug 26, 2024
3bd2972
Add type hinting to `safe_run` and `_safe_release_connection` of `Bas…
QMalcolm Aug 26, 2024
d274046
Add type hinting to `before_execute` and `after_execute` of `BaseRunner`
QMalcolm Aug 26, 2024
c3512ed
Add type hinting to `BaseRunner._skip_caused_by_ephemeral_failure`
QMalcolm Aug 26, 2024
008e90a
Add type hinting to `on_skip` and fix unhandled potential runtime error
QMalcolm Aug 26, 2024
3b2481b
Add changie doc for BaseRunner type hinting improvements
QMalcolm Aug 26, 2024
f500372
Update .changes/unreleased/Under the Hood-20240826-141843.yaml
QMalcolm Aug 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240826-141843.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Add type hinting to BaseRunner class to increase stability guarantees
time: 2024-08-26T14:18:43.834018-05:00
custom:
Author: QMalcolm
Issue: "10606"
5 changes: 4 additions & 1 deletion core/dbt/artifacts/schemas/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,12 @@ class FreshnessStatus(StrEnum):
RuntimeErr = NodeStatus.RuntimeErr


ResultStatus = Union[RunStatus, TestStatus, FreshnessStatus]


@dataclass
class BaseResult(dbtClassMixin):
status: Union[RunStatus, TestStatus, FreshnessStatus]
status: ResultStatus
timing: List[TimingInfo]
thread_id: str
execution_time: float
Expand Down
129 changes: 84 additions & 45 deletions core/dbt/task/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
from contextlib import nullcontext
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Optional, Set, Union

from agate import Table

import dbt.exceptions
import dbt_common.exceptions.base
from dbt import tracking
from dbt.adapters.base.impl import BaseAdapter
from dbt.artifacts.resources.types import NodeType
from dbt.artifacts.schemas.results import (
NodeStatus,
ResultStatus,
RunningStatus,
RunStatus,
TimingInfo,
Expand All @@ -26,6 +30,8 @@
from dbt.config.profile import read_profile
from dbt.constants import DBT_PROJECT_FILE_NAME
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.results import BaseResult
from dbt.events.types import (
CatchableExceptionOnRun,
GenericExceptionOnRun,
Expand Down Expand Up @@ -159,24 +165,31 @@
timing information and the newest (compiled vs executed) form of the node.
"""

def __init__(self, node) -> None:
def __init__(self, node: ResultNode) -> None:
self.timing: List[TimingInfo] = []
self.node = node
self.node: ResultNode = node


class BaseRunner(metaclass=ABCMeta):
def __init__(self, config, adapter, node, node_index, num_nodes) -> None:
self.config = config
self.compiler = Compiler(config)
self.adapter = adapter
self.node = node
self.node_index = node_index
self.num_nodes = num_nodes

self.skip = False
def __init__(
self,
config: RuntimeConfig,
adapter: BaseAdapter,
node: ResultNode,
node_index: int,
num_nodes: int,
) -> None:
self.config: RuntimeConfig = config
self.compiler: Compiler = Compiler(config)
self.adapter: BaseAdapter = adapter
self.node: ResultNode = node
self.node_index: int = node_index
self.num_nodes: int = num_nodes

self.skip: bool = False
self.skip_cause: Optional[RunResult] = None

self.run_ephemeral_models = False
self.run_ephemeral_models: bool = False

@abstractmethod
def compile(self, manifest: Manifest) -> Any:
Expand All @@ -185,7 +198,7 @@
def _node_build_path(self) -> Optional[str]:
return self.node.build_path if hasattr(self.node, "build_path") else None

def get_result_status(self, result) -> Dict[str, str]:
def get_result_status(self, result: BaseResult) -> Dict[str, str]:
if result.status == NodeStatus.Error:
return {"node_status": "error", "node_error": str(result.message)}
elif result.status == NodeStatus.Skipped:
Expand All @@ -197,7 +210,7 @@
else:
return {"node_status": "passed"}

def run_with_hooks(self, manifest):
def run_with_hooks(self, manifest: Manifest):
if self.skip:
return self.on_skip()

Expand All @@ -217,15 +230,15 @@

def _build_run_result(
self,
node,
start_time,
status,
timing_info,
message,
agate_table=None,
adapter_response=None,
failures=None,
):
node: ResultNode,
start_time: float,
status: ResultStatus,
timing_info: List[TimingInfo],
message: Optional[str] = None,
agate_table: Optional[Table] = None,
adapter_response: Optional[Dict[str, Any]] = None,
failures: Optional[int] = None,
) -> RunResult:
execution_time = time.time() - start_time
thread_id = threading.current_thread().name
if adapter_response is None:
Expand All @@ -242,7 +255,13 @@
failures=failures,
)

def error_result(self, node, message, start_time, timing_info):
def error_result(
self,
node: ResultNode,
message: str,
start_time: float,
timing_info: List[TimingInfo],
) -> RunResult:
return self._build_run_result(
node=node,
start_time=start_time,
Expand All @@ -251,7 +270,12 @@
message=message,
)

def ephemeral_result(self, node, start_time, timing_info):
def ephemeral_result(
self,
node: ResultNode,
start_time: float,
timing_info: List[TimingInfo],
) -> RunResult:
return self._build_run_result(
node=node,
start_time=start_time,
Expand All @@ -260,7 +284,12 @@
message=None,
)

def from_run_result(self, result, start_time, timing_info):
def from_run_result(
self,
result: RunResult,
start_time: float,
timing_info: List[TimingInfo],
) -> RunResult:
return self._build_run_result(
node=result.node,
start_time=start_time,
Expand All @@ -272,7 +301,11 @@
failures=result.failures,
)

def compile_and_execute(self, manifest, ctx):
def compile_and_execute(
self,
manifest: Manifest,
ctx: ExecutionContext,
) -> Optional[RunResult]:
result = None
with (
self.adapter.connection_named(self.node.unique_id, self.node)
Expand Down Expand Up @@ -305,7 +338,11 @@

return result

def _handle_catchable_exception(self, e, ctx):
def _handle_catchable_exception(
self,
e: Union[CompilationError, DbtRuntimeError],
ctx: ExecutionContext,
) -> str:
if e.node is None:
e.add_node(ctx.node)

Expand All @@ -316,15 +353,15 @@
)
return str(e)

def _handle_internal_exception(self, e, ctx):
def _handle_internal_exception(self, e: DbtInternalError, ctx: ExecutionContext) -> str:
fire_event(
InternalErrorOnRun(
build_path=self._node_build_path(), exc=str(e), node_info=get_node_info()
)
)
return str(e)

def _handle_generic_exception(self, e, ctx):
def _handle_generic_exception(self, e: Exception, ctx: ExecutionContext) -> str:
fire_event(
GenericExceptionOnRun(
build_path=self._node_build_path(),
Expand All @@ -337,7 +374,7 @@

return str(e)

def handle_exception(self, e, ctx):
def handle_exception(self, e, ctx) -> str:
catchable_errors = (CompilationError, DbtRuntimeError)
if isinstance(e, catchable_errors):
error = self._handle_catchable_exception(e, ctx)
Expand All @@ -347,7 +384,7 @@
error = self._handle_generic_exception(e, ctx)
return error

def safe_run(self, manifest):
def safe_run(self, manifest: Manifest) -> RunResult:
started = time.time()
ctx = ExecutionContext(self.node)
error = None
Expand Down Expand Up @@ -378,7 +415,7 @@
result = self.ephemeral_result(ctx.node, started, ctx.timing)
return result

def _safe_release_connection(self):
def _safe_release_connection(self) -> Optional[str]:
"""Try to release a connection. If an exception is hit, log and return
the error string.
"""
Expand All @@ -394,24 +431,24 @@

return None

def before_execute(self):
raise NotImplementedError()
def before_execute(self) -> None:
raise NotImplementedError("The `before_execute` function hasn't been implemented")

Check warning on line 435 in core/dbt/task/base.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/task/base.py#L435

Added line #L435 was not covered by tests

def execute(self, compiled_node, manifest):
raise NotImplementedError()
def execute(self, compiled_node: ResultNode, manifest: Manifest) -> RunResult:
raise NotImplementedError(msg="The `execute` function hasn't been implemented")

Check warning on line 438 in core/dbt/task/base.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/task/base.py#L438

Added line #L438 was not covered by tests

def run(self, compiled_node, manifest):
def run(self, compiled_node: ResultNode, manifest: Manifest) -> RunResult:
return self.execute(compiled_node, manifest)

def after_execute(self, result):
raise NotImplementedError()
def after_execute(self, result: RunResult) -> None:
raise NotImplementedError("The `after_execute` function hasn't been implemented")

Check warning on line 444 in core/dbt/task/base.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/task/base.py#L444

Added line #L444 was not covered by tests

def _skip_caused_by_ephemeral_failure(self):
def _skip_caused_by_ephemeral_failure(self) -> bool:
if self.skip_cause is None or self.skip_cause.node is None:
return False
return self.skip_cause.node.is_ephemeral_model

def on_skip(self):
def on_skip(self) -> RunResult:
schema_name = getattr(self.node, "schema", "")
node_name = self.node.name

Expand All @@ -427,7 +464,9 @@
relation=node_name,
index=self.node_index,
total=self.num_nodes,
status=self.skip_cause.status,
status=(
self.skip_cause.status if self.skip_cause is not None else "unknown"
),
)
)
# skip_cause here should be the run_result from the ephemeral model
Expand Down Expand Up @@ -461,7 +500,7 @@
node_result = RunResult.from_node(self.node, RunStatus.Skipped, error_message)
return node_result

def do_skip(self, cause=None):
def do_skip(self, cause: Optional[RunResult] = None) -> None:
self.skip = True
self.skip_cause = cause

Expand Down
Loading