From ef84e430d77fd749d8d10cd7171a9037d9abb3fe Mon Sep 17 00:00:00 2001 From: Abhishek Mohan Date: Wed, 2 Aug 2023 17:29:18 -0400 Subject: [PATCH] Resolved MyPy errors when adding Airflow pre-commit dependency. (#434) ## Description Provided different solutions for new MyPy errors that were the result of adding "apache-airflow" as a dependency within the pre-commit config. ## Checklist - [ ] I have made corresponding changes to the documentation (if required) - [ ] I have added tests that prove my fix is effective or that my feature works --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- cosmos/airflow/dag.py | 2 +- cosmos/airflow/graph.py | 2 ++ cosmos/airflow/task_group.py | 6 +++--- cosmos/dataset.py | 7 +++++-- cosmos/dbt/parser/output.py | 4 ++-- cosmos/hooks/subprocess.py | 4 ++-- cosmos/operators/base.py | 4 ++-- cosmos/operators/docker.py | 2 +- cosmos/operators/kubernetes.py | 2 +- cosmos/operators/local.py | 35 +++++++++++++++++++--------------- cosmos/operators/virtualenv.py | 8 +++----- 12 files changed, 43 insertions(+), 35 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e57183e70..e8b39ffca 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -71,7 +71,7 @@ repos: hooks: - id: mypy name: mypy-python-sdk - additional_dependencies: [types-PyYAML, types-attrs, attrs, types-requests, types-python-dateutil] + additional_dependencies: [types-PyYAML, types-attrs, attrs, types-requests, types-python-dateutil, apache-airflow] files: ^cosmos ci: diff --git a/cosmos/airflow/dag.py b/cosmos/airflow/dag.py index 948a3558b..d5465ac81 100644 --- a/cosmos/airflow/dag.py +++ b/cosmos/airflow/dag.py @@ -10,7 +10,7 @@ from cosmos.converter import airflow_kwargs, specific_kwargs, DbtToAirflowConverter -class DbtDag(DAG, DbtToAirflowConverter): # type: ignore[misc] # ignores subclass MyPy error +class DbtDag(DAG, DbtToAirflowConverter): """ Render a dbt project as an Airflow DAG. """ diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index d6a67c663..8e6cfd213 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -11,6 +11,7 @@ from cosmos.core.graph.entities import Task as TaskMetadata from cosmos.dataset import get_dbt_dataset from cosmos.dbt.graph import DbtNode +from airflow.models import BaseOperator logger = logging.getLogger(__name__) @@ -159,6 +160,7 @@ def build_airflow_graph( :param emit_datasets: Decides if Cosmos should add outlets to model classes or not. """ tasks_map = {} + task_or_group: TaskGroup | BaseOperator # In most cases, we'll map one DBT node to one Airflow task # The exception are the test nodes, since it would be too slow to run test tasks individually. diff --git a/cosmos/airflow/task_group.py b/cosmos/airflow/task_group.py index 1d75b48e6..dcdedb685 100644 --- a/cosmos/airflow/task_group.py +++ b/cosmos/airflow/task_group.py @@ -9,7 +9,7 @@ from cosmos.converter import airflow_kwargs, specific_kwargs, DbtToAirflowConverter -class DbtTaskGroup(TaskGroup, DbtToAirflowConverter): # type: ignore[misc] # ignores subclass MyPy error +class DbtTaskGroup(TaskGroup, DbtToAirflowConverter): """ Render a dbt project as an Airflow Task Group. """ @@ -20,7 +20,7 @@ def __init__( *args: Any, **kwargs: Any, ) -> None: - group_id = group_id - TaskGroup.__init__(self, group_id=group_id, *args, **airflow_kwargs(**kwargs)) + kwargs["group_id"] = group_id + TaskGroup.__init__(self, *args, **airflow_kwargs(**kwargs)) kwargs["task_group"] = self DbtToAirflowConverter.__init__(self, *args, **specific_kwargs(**kwargs)) diff --git a/cosmos/dataset.py b/cosmos/dataset.py index 5927de319..e174a5199 100644 --- a/cosmos/dataset.py +++ b/cosmos/dataset.py @@ -15,8 +15,11 @@ def __init__(self, id: str, *args: Tuple[Any], **kwargs: str): self.id = id logger.warning("Datasets are not supported in Airflow < 2.5.0") - def __eq__(self, other: "Dataset") -> bool: - return bool(self.id == other.id) + def __eq__(self, other: object) -> bool: + if isinstance(other, Dataset): + return bool(self.id == other.id) # type: ignore[attr-defined] + else: + return NotImplemented def get_dbt_dataset(connection_id: str, project_name: str, model_name: str) -> Dataset: diff --git a/cosmos/dbt/parser/output.py b/cosmos/dbt/parser/output.py index deb8d9804..791c4b605 100644 --- a/cosmos/dbt/parser/output.py +++ b/cosmos/dbt/parser/output.py @@ -2,10 +2,10 @@ import re from typing import List, Tuple -from airflow.hooks.subprocess import SubprocessResult +from cosmos.hooks.subprocess import FullOutputSubprocessResult -def parse_output(result: SubprocessResult, keyword: str) -> int: +def parse_output(result: FullOutputSubprocessResult, keyword: str) -> int: """ Parses the dbt test output message and returns the number of errors or warnings. diff --git a/cosmos/hooks/subprocess.py b/cosmos/hooks/subprocess.py index 2f8dbbfe7..9951cb5b8 100644 --- a/cosmos/hooks/subprocess.py +++ b/cosmos/hooks/subprocess.py @@ -20,12 +20,12 @@ class FullOutputSubprocessResult(NamedTuple): full_output: list[str] -class FullOutputSubprocessHook(BaseHook): # type: ignore[misc] # ignores subclass MyPy error +class FullOutputSubprocessHook(BaseHook): """Hook for running processes with the ``subprocess`` module.""" def __init__(self) -> None: self.sub_process: Popen[bytes] | None = None - super().__init__() + super().__init__() # type: ignore[no-untyped-call] def run_command( self, diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index 764062a32..275f83e4f 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) -class DbtBaseOperator(BaseOperator): # type: ignore[misc] # ignores subclass MyPy error +class DbtBaseOperator(BaseOperator): """ Executes a dbt core cli command. @@ -98,7 +98,7 @@ def __init__( cancel_query_on_kill: bool = True, dbt_executable_path: str = "dbt", dbt_cmd_flags: list[str] | None = None, - **kwargs: str, + **kwargs: Any, ) -> None: self.project_dir = project_dir self.conn_id = conn_id diff --git a/cosmos/operators/docker.py b/cosmos/operators/docker.py index a5839c715..d019d7b6b 100644 --- a/cosmos/operators/docker.py +++ b/cosmos/operators/docker.py @@ -38,7 +38,7 @@ def __init__( super().__init__(image=image, **kwargs) def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any: - self.build_command(cmd_flags, context) + self.build_command(context, cmd_flags) self.log.info(f"Running command: {self.command}") # type: ignore[has-type] return super().execute(context) diff --git a/cosmos/operators/kubernetes.py b/cosmos/operators/kubernetes.py index 995aa70a7..2833bd550 100644 --- a/cosmos/operators/kubernetes.py +++ b/cosmos/operators/kubernetes.py @@ -46,7 +46,7 @@ def build_env_args(self, env: dict[str, str | bytes | PathLike[Any]]) -> None: self.env_vars = convert_env_vars({**env, **env_vars_dict}) def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> Any: - self.build_kube_args(cmd_flags, context) + self.build_kube_args(context, cmd_flags) self.log.info(f"Running command: {self.arguments}") return super().execute(context) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 3d0f3a97b..3fbc248c2 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -6,11 +6,12 @@ import signal import tempfile from pathlib import Path -from typing import Any, Callable, Sequence, Tuple +from typing import Any, Callable, Sequence import yaml from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.models.taskinstance import TaskInstance from airflow.utils.context import Context from airflow.utils.session import NEW_SESSION, provide_session from sqlalchemy.orm import Session @@ -64,7 +65,7 @@ def __init__( self.should_store_compiled_sql = should_store_compiled_sql super().__init__(**kwargs) - @cached_property # type: ignore[misc] # ignores internal untyped decorator + @cached_property def subprocess_hook(self) -> FullOutputSubprocessHook: """Returns hook for running the bash command.""" return FullOutputSubprocessHook() @@ -78,7 +79,7 @@ def exception_handling(self, result: FullOutputSubprocessResult) -> None: *result.full_output, ) - @provide_session # type: ignore[misc] # ignores internal untyped decorator + @provide_session def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None: """ Takes the compiled SQL files from the dbt run and stores them in the compiled_sql rendered template. @@ -110,18 +111,22 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se from airflow.models.renderedtifields import RenderedTaskInstanceFields ti = context["ti"] - ti.task.template_fields = self.template_fields - rtif = RenderedTaskInstanceFields(ti, render_templates=False) - - # delete the old records - session.query(RenderedTaskInstanceFields).filter( - RenderedTaskInstanceFields.dag_id == self.dag_id, - RenderedTaskInstanceFields.task_id == self.task_id, - RenderedTaskInstanceFields.run_id == ti.run_id, - ).delete() - session.add(rtif) - - def run_subprocess(self, *args: Tuple[Any], **kwargs: Any) -> FullOutputSubprocessResult: + + if isinstance(ti, TaskInstance): # verifies ti is a TaskInstance in order to access and use the "task" field + ti.task.template_fields = self.template_fields + rtif = RenderedTaskInstanceFields(ti, render_templates=False) + + # delete the old records + session.query(RenderedTaskInstanceFields).filter( + RenderedTaskInstanceFields.dag_id == self.dag_id, + RenderedTaskInstanceFields.task_id == self.task_id, + RenderedTaskInstanceFields.run_id == ti.run_id, + ).delete() + session.add(rtif) + else: + logger.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.") + + def run_subprocess(self, *args: Any, **kwargs: Any) -> FullOutputSubprocessResult: subprocess_result: FullOutputSubprocessResult = self.subprocess_hook.run_command(*args, **kwargs) return subprocess_result diff --git a/cosmos/operators/virtualenv.py b/cosmos/operators/virtualenv.py index 2566097d7..6e2c986d4 100644 --- a/cosmos/operators/virtualenv.py +++ b/cosmos/operators/virtualenv.py @@ -3,7 +3,7 @@ import logging from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Tuple +from typing import TYPE_CHECKING, Any from airflow.compat.functools import cached_property @@ -53,7 +53,7 @@ def __init__( super().__init__(**kwargs) self._venv_tmp_dir = TemporaryDirectory() - @cached_property # type: ignore[misc] # ignores internal untyped decorator + @cached_property def venv_dbt_path( self, ) -> str: @@ -85,9 +85,7 @@ def venv_dbt_path( self.log.info("Using dbt version %s available at %s", dbt_version, dbt_binary) return str(dbt_binary) - def run_subprocess( # type: ignore[override] - self, *args: Tuple[Any], command: list[str], **kwargs: Any - ) -> FullOutputSubprocessResult: + def run_subprocess(self, *args: Any, command: list[str], **kwargs: Any) -> FullOutputSubprocessResult: if self.py_requirements: command[0] = self.venv_dbt_path