Resolved MyPy errors when adding Airflow pre-commit dependency. (astronomer#434)

## Description

Provides fixes for the new MyPy errors introduced by adding "apache-airflow"
as a dependency of the MyPy hook within the pre-commit config.

## Checklist

- [ ] I have made corresponding changes to the documentation (if
required)
- [ ] I have added tests that prove my fix is effective or that my
feature works

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Abhishek Mohan and pre-commit-ci[bot] committed Aug 2, 2023
1 parent 8c49661 commit ef84e43
Showing 12 changed files with 43 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -71,7 +71,7 @@ repos:
     hooks:
       - id: mypy
         name: mypy-python-sdk
-        additional_dependencies: [types-PyYAML, types-attrs, attrs, types-requests, types-python-dateutil]
+        additional_dependencies: [types-PyYAML, types-attrs, attrs, types-requests, types-python-dateutil, apache-airflow]
         files: ^cosmos

 ci:
2 changes: 1 addition & 1 deletion cosmos/airflow/dag.py
@@ -10,7 +10,7 @@
 from cosmos.converter import airflow_kwargs, specific_kwargs, DbtToAirflowConverter


-class DbtDag(DAG, DbtToAirflowConverter):  # type: ignore[misc] # ignores subclass MyPy error
+class DbtDag(DAG, DbtToAirflowConverter):
     """
     Render a dbt project as an Airflow DAG.
     """
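The `misc` ignore was needed while `apache-airflow` was absent from the hook's environment: MyPy then resolved `DAG` to `Any`, and subclassing an `Any` base is a `misc` error. The same cleanup recurs in `task_group.py`, `hooks/subprocess.py`, and `operators/base.py` below. A minimal sketch (illustrative names, not Cosmos code):

```python
# With apache-airflow importable by MyPy, DAG is a real, typed class, so a
# multiple-inheritance definition type-checks without a "misc" ignore.
from airflow import DAG


class Converter:
    """Illustrative stand-in for DbtToAirflowConverter."""


class MyDag(DAG, Converter):  # previously needed: # type: ignore[misc]
    pass
```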
2 changes: 2 additions & 0 deletions cosmos/airflow/graph.py
@@ -11,6 +11,7 @@
 from cosmos.core.graph.entities import Task as TaskMetadata
 from cosmos.dataset import get_dbt_dataset
 from cosmos.dbt.graph import DbtNode
+from airflow.models import BaseOperator


 logger = logging.getLogger(__name__)
@@ -159,6 +160,7 @@ def build_airflow_graph(
     :param emit_datasets: Decides if Cosmos should add outlets to model classes or not.
     """
     tasks_map = {}
+    task_or_group: TaskGroup | BaseOperator

     # In most cases, we'll map one DBT node to one Airflow task
     # The exception are the test nodes, since it would be too slow to run test tasks individually.
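Pre-declaring `task_or_group` with a union annotation lets both branches of the node-mapping logic assign to it; without the annotation, MyPy pins the variable to the type of its first assignment and rejects the other branch. The pattern in isolation (stand-in classes, not the project's code):

```python
from __future__ import annotations


class Group:
    """Stand-in for airflow.utils.task_group.TaskGroup."""


class Operator:
    """Stand-in for airflow.models.BaseOperator."""


def map_node(has_tests: bool) -> Group | Operator:
    task_or_group: Group | Operator  # declared up front so both assignments check
    if has_tests:
        task_or_group = Group()
    else:
        task_or_group = Operator()
    return task_or_group
```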
6 changes: 3 additions & 3 deletions cosmos/airflow/task_group.py
@@ -9,7 +9,7 @@
 from cosmos.converter import airflow_kwargs, specific_kwargs, DbtToAirflowConverter


-class DbtTaskGroup(TaskGroup, DbtToAirflowConverter):  # type: ignore[misc] # ignores subclass MyPy error
+class DbtTaskGroup(TaskGroup, DbtToAirflowConverter):
     """
     Render a dbt project as an Airflow Task Group.
     """
@@ -20,7 +20,7 @@ def __init__(
         *args: Any,
         **kwargs: Any,
     ) -> None:
-        group_id = group_id
-        TaskGroup.__init__(self, group_id=group_id, *args, **airflow_kwargs(**kwargs))
+        kwargs["group_id"] = group_id
+        TaskGroup.__init__(self, *args, **airflow_kwargs(**kwargs))
         kwargs["task_group"] = self
         DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs))
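Folding `group_id` into `kwargs` (and dropping the no-op `group_id = group_id`) avoids mixing an explicit keyword argument with `*args` in the same call, which the newly typed `TaskGroup.__init__` signature flagged, and keeps one consistent argument set flowing to both parent initializers. A sketch of the forwarding pattern (illustrative names):

```python
from typing import Any


def parent_init(*args: Any, **kwargs: Any) -> dict:
    # Stand-in for a parent __init__ receiving the forwarded arguments.
    return {"args": args, "kwargs": kwargs}


def child_init(group_id: str, *args: Any, **kwargs: Any) -> dict:
    kwargs["group_id"] = group_id  # stash the explicit value before forwarding
    return parent_init(*args, **kwargs)


assert child_init("dbt_group")["kwargs"] == {"group_id": "dbt_group"}
```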
7 changes: 5 additions & 2 deletions cosmos/dataset.py
@@ -15,8 +15,11 @@ def __init__(self, id: str, *args: Tuple[Any], **kwargs: str):
         self.id = id
         logger.warning("Datasets are not supported in Airflow < 2.5.0")

-    def __eq__(self, other: "Dataset") -> bool:
-        return bool(self.id == other.id)
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, Dataset):
+            return bool(self.id == other.id)  # type: ignore[attr-defined]
+        else:
+            return NotImplemented


 def get_dbt_dataset(connection_id: str, project_name: str, model_name: str) -> Dataset:
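Typing `other` as `object` matches the signature of `object.__eq__`, which MyPy requires of an override, and returning `NotImplemented` for foreign types lets Python fall back to the reflected comparison instead of raising. Behavior sketch with a self-contained stand-in:

```python
class FallbackDataset:
    """Stand-in for the fallback Dataset above (illustrative only)."""

    def __init__(self, id: str) -> None:
        self.id = id

    def __eq__(self, other: object) -> bool:
        if isinstance(other, FallbackDataset):
            return bool(self.id == other.id)
        return NotImplemented  # delegate the comparison to the other operand


assert FallbackDataset("a") == FallbackDataset("a")
assert (FallbackDataset("a") == "a") is False  # degrades to False, no AttributeError
```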
4 changes: 2 additions & 2 deletions cosmos/dbt/parser/output.py
@@ -2,10 +2,10 @@
 import re
 from typing import List, Tuple

-from airflow.hooks.subprocess import SubprocessResult
+from cosmos.hooks.subprocess import FullOutputSubprocessResult


-def parse_output(result: SubprocessResult, keyword: str) -> int:
+def parse_output(result: FullOutputSubprocessResult, keyword: str) -> int:
     """
     Parses the dbt test output message and returns the number of errors or warnings.
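`parse_output` scans every line of the command's output, which Airflow's `SubprocessResult` (exit code plus final line only) cannot supply; the project's `FullOutputSubprocessResult` adds a `full_output` list for exactly this. A rough sketch of the shape and the scan (`exit_code`/`output` fields are assumptions mirroring Airflow's `SubprocessResult`; `full_output` is visible in the `cosmos/hooks/subprocess.py` hunk below):

```python
from typing import List, NamedTuple


class FullOutputSubprocessResult(NamedTuple):
    exit_code: int
    output: str             # final line, as in Airflow's SubprocessResult (assumed)
    full_output: List[str]  # every captured line


def count_keyword(result: FullOutputSubprocessResult, keyword: str) -> int:
    # Simplified stand-in for parse_output's scan over the dbt log.
    return sum(1 for line in result.full_output if keyword in line)


res = FullOutputSubprocessResult(0, "Done.", ["Completed with 1 warning", "Done."])
assert count_keyword(res, "warning") == 1
```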
4 changes: 2 additions & 2 deletions cosmos/hooks/subprocess.py
@@ -20,12 +20,12 @@ class FullOutputSubprocessResult(NamedTuple):
     full_output: list[str]


-class FullOutputSubprocessHook(BaseHook):  # type: ignore[misc] # ignores subclass MyPy error
+class FullOutputSubprocessHook(BaseHook):
     """Hook for running processes with the ``subprocess`` module."""

     def __init__(self) -> None:
         self.sub_process: Popen[bytes] | None = None
-        super().__init__()
+        super().__init__()  # type: ignore[no-untyped-call]

     def run_command(
         self,
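With Airflow typed, the blanket `misc` ignore on the class is unnecessary, but `BaseHook.__init__` itself is evidently still unannotated, so the one call that strict MyPy flags gets a narrowly scoped `no-untyped-call` ignore instead. The same technique in isolation:

```python
class UntypedBase:
    def __init__(self):  # unannotated: calls to it are "untyped" under strict MyPy
        self.connected = False


class Hook(UntypedBase):
    def __init__(self) -> None:
        self.sub_process = None
        # Scoping the ignore to one error code keeps everything else checked.
        super().__init__()  # type: ignore[no-untyped-call]
```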
4 changes: 2 additions & 2 deletions cosmos/operators/base.py
@@ -13,7 +13,7 @@
 logger = logging.getLogger(__name__)


-class DbtBaseOperator(BaseOperator):  # type: ignore[misc] # ignores subclass MyPy error
+class DbtBaseOperator(BaseOperator):
     """
     Executes a dbt core cli command.
@@ -98,7 +98,7 @@ def __init__(
         cancel_query_on_kill: bool = True,
         dbt_executable_path: str = "dbt",
         dbt_cmd_flags: list[str] | None = None,
-        **kwargs: str,
+        **kwargs: Any,
     ) -> None:
         self.project_dir = project_dir
         self.conn_id = conn_id
2 changes: 1 addition & 1 deletion cosmos/operators/docker.py
@@ -38,7 +38,7 @@ def __init__(
         super().__init__(image=image, **kwargs)

     def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any:
-        self.build_command(cmd_flags, context)
+        self.build_command(context, cmd_flags)
         self.log.info(f"Running command: {self.command}")  # type: ignore[has-type]
         return super().execute(context)
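The call previously passed `(cmd_flags, context)` in the reverse of the callee's `(context, cmd_flags)` order; once Airflow's types were visible, MyPy surfaced the swapped positional arguments (the same fix appears in `kubernetes.py` below). The check in isolation (illustrative stand-in):

```python
from typing import List, Optional


def build_command(context: dict, cmd_flags: Optional[List[str]] = None) -> None:
    """Stand-in with the same parameter order as the real callee."""


ctx: dict = {}
build_command(ctx, ["--full-refresh"])    # matches the declared order
# build_command(["--full-refresh"], ctx)  # MyPy: incompatible argument types
```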
2 changes: 1 addition & 1 deletion cosmos/operators/kubernetes.py
@@ -46,7 +46,7 @@ def build_env_args(self, env: dict[str, str | bytes | PathLike[Any]]) -> None:
         self.env_vars = convert_env_vars({**env, **env_vars_dict})

     def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any:
-        self.build_kube_args(cmd_flags, context)
+        self.build_kube_args(context, cmd_flags)
         self.log.info(f"Running command: {self.arguments}")
         return super().execute(context)
35 changes: 20 additions & 15 deletions cosmos/operators/local.py
@@ -6,11 +6,12 @@
 import signal
 import tempfile
 from pathlib import Path
-from typing import Any, Callable, Sequence, Tuple
+from typing import Any, Callable, Sequence

 import yaml
 from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.models.taskinstance import TaskInstance
 from airflow.utils.context import Context
 from airflow.utils.session import NEW_SESSION, provide_session
 from sqlalchemy.orm import Session
@@ -64,7 +65,7 @@ def __init__(
         self.should_store_compiled_sql = should_store_compiled_sql
         super().__init__(**kwargs)

-    @cached_property  # type: ignore[misc] # ignores internal untyped decorator
+    @cached_property
     def subprocess_hook(self) -> FullOutputSubprocessHook:
         """Returns hook for running the bash command."""
         return FullOutputSubprocessHook()
@@ -78,7 +79,7 @@ def exception_handling(self, result: FullOutputSubprocessResult) -> None:
             *result.full_output,
         )

-    @provide_session  # type: ignore[misc] # ignores internal untyped decorator
+    @provide_session
     def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
         """
         Takes the compiled SQL files from the dbt run and stores them in the compiled_sql rendered template.
@@ -110,18 +111,22 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
         from airflow.models.renderedtifields import RenderedTaskInstanceFields

         ti = context["ti"]
-        ti.task.template_fields = self.template_fields
-        rtif = RenderedTaskInstanceFields(ti, render_templates=False)
-
-        # delete the old records
-        session.query(RenderedTaskInstanceFields).filter(
-            RenderedTaskInstanceFields.dag_id == self.dag_id,
-            RenderedTaskInstanceFields.task_id == self.task_id,
-            RenderedTaskInstanceFields.run_id == ti.run_id,
-        ).delete()
-        session.add(rtif)
-
-    def run_subprocess(self, *args: Tuple[Any], **kwargs: Any) -> FullOutputSubprocessResult:
+
+        if isinstance(ti, TaskInstance):  # verifies ti is a TaskInstance in order to access and use the "task" field
+            ti.task.template_fields = self.template_fields
+            rtif = RenderedTaskInstanceFields(ti, render_templates=False)
+
+            # delete the old records
+            session.query(RenderedTaskInstanceFields).filter(
+                RenderedTaskInstanceFields.dag_id == self.dag_id,
+                RenderedTaskInstanceFields.task_id == self.task_id,
+                RenderedTaskInstanceFields.run_id == ti.run_id,
+            ).delete()
+            session.add(rtif)
+        else:
+            logger.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.")
+
+    def run_subprocess(self, *args: Any, **kwargs: Any) -> FullOutputSubprocessResult:
         subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(*args, **kwargs)
         return subprocess_result
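In newer Airflow versions `context["ti"]` may be a serialized `TaskInstancePydantic` rather than a live `TaskInstance`, and only the latter carries the `task` attribute this method mutates; the `isinstance` check guards that at runtime and simultaneously narrows the type for MyPy. The narrowing pattern in isolation (stand-in classes):

```python
from typing import Union


class TaskInstance:
    def __init__(self) -> None:
        self.task: object = None


class TaskInstancePydantic:
    """Serialized stand-in without a usable ``task`` attribute."""


def store_fields(ti: Union[TaskInstance, TaskInstancePydantic]) -> None:
    if isinstance(ti, TaskInstance):
        ti.task = ("compiled_sql",)  # MyPy has narrowed ti, so .task is known
    else:
        print("Cannot update template_fields on a serialized task instance")
```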
8 changes: 3 additions & 5 deletions cosmos/operators/virtualenv.py
@@ -3,7 +3,7 @@
 import logging
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Tuple
+from typing import TYPE_CHECKING, Any


 from airflow.compat.functools import cached_property
@@ -53,7 +53,7 @@ def __init__(
         super().__init__(**kwargs)
         self._venv_tmp_dir = TemporaryDirectory()

-    @cached_property  # type: ignore[misc] # ignores internal untyped decorator
+    @cached_property
     def venv_dbt_path(
         self,
     ) -> str:
@@ -85,9 +85,7 @@ def venv_dbt_path(
         self.log.info("Using dbt version %s available at %s", dbt_version, dbt_binary)
         return str(dbt_binary)

-    def run_subprocess(  # type: ignore[override]
-        self, *args: Tuple[Any], command: list[str], **kwargs: Any
-    ) -> FullOutputSubprocessResult:
+    def run_subprocess(self, *args: Any, command: list[str], **kwargs: Any) -> FullOutputSubprocessResult:
         if self.py_requirements:
             command[0] = self.venv_dbt_path
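Because the base `run_subprocess` in `local.py` now takes `(*args: Any, **kwargs: Any)`, MyPy treats it as compatible with any override signature, so the `ignore[override]` comment (and the `Tuple[Any]` annotation) can go. A sketch of that rule, with the extra keyword made optional in the stand-in:

```python
from __future__ import annotations

from typing import Any


class LocalOp:
    def run_subprocess(self, *args: Any, **kwargs: Any) -> int:
        return 0


class VenvOp(LocalOp):
    # A base signature of exactly (*args: Any, **kwargs: Any) is accepted by
    # MyPy against any override, so no "# type: ignore[override]" is needed.
    def run_subprocess(self, *args: Any, command: list[str] | None = None, **kwargs: Any) -> int:
        return super().run_subprocess(*args, **kwargs)
```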
