class KubernetesHookAsync(KubernetesHook):  # noqa: D101
    """
    This class is deprecated and will be removed in 2.0.0.

    Use :class:`~airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook` instead.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Fixed: the two adjacent string literals previously concatenated without a
        # separating space ("...2.0.0.Please use..."), and the text said "module"
        # although a class is being deprecated. stacklevel=2 points the warning at
        # the caller that instantiated the hook, not at this __init__.
        warnings.warn(
            (
                "This class is deprecated and will be removed in 2.0.0. "
                "Please use `airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook`"
            ),
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
class GKEStartPodOperatorAsync(GKEStartPodOperator):
    """
    This class is deprecated.

    Please use :class:`~airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator`
    and set the ``deferrable`` param to ``True`` instead.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Fixed: the adjacent string literals previously joined without separating
        # spaces, producing "...deprecated.Please use `...GKEStartPodOperator`and set...".
        # stacklevel=2 attributes the warning to the DAG code constructing the operator.
        warnings.warn(
            (
                "This class is deprecated. "
                "Please use `airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator` "
                "and set `deferrable` param to `True` instead."
            ),
            DeprecationWarning,
            stacklevel=2,
        )
        # Delegate everything to the upstream operator, forcing deferrable mode on
        # so existing callers keep the async (triggerer-based) behaviour.
        super().__init__(*args, deferrable=True, **kwargs)
GKEStartPodTrigger(WaitContainerTrigger): """ Fetch GKE cluster config and wait for pod to start up. + This class is deprecated and will be removed in 2.0.0. + Use :class: `~airflow.providers.google.cloud.triggers.kubernetes_engine.GKEStartPodTrigger` instead + :param location: The name of the Google Kubernetes Engine zone or region in which the cluster resides :param cluster_name: The name of the Google Kubernetes Engine cluster the pod should be spawned in @@ -53,6 +57,15 @@ def __init__( pending_phase_timeout: float = 120.0, logging_interval: int | None = None, ): + warnings.warn( + ( + "This module is deprecated and will be removed in 2.0.0." + "Please use `airflow.providers.google.cloud.triggers.kubernetes_engine.GKEStartPodTrigger`" + ), + DeprecationWarning, + stacklevel=2, + ) + super().__init__( container_name=self.BASE_CONTAINER_NAME, pod_name=name, diff --git a/setup.cfg b/setup.cfg index efd1a4565..cfdbe3155 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,7 +54,8 @@ apache.livy = apache-airflow-providers-apache-livy>=3.7.1 paramiko cncf.kubernetes = - apache-airflow-providers-cncf-kubernetes>=4 + # TODO: Update version when below RC is released + apache-airflow-providers-cncf-kubernetes @https://files.pythonhosted.org/packages/c9/82/c7422f848a249e437b2c5810b9ccc266b80ffd767322722b16ac97cc3013/apache_airflow_providers_cncf_kubernetes-8.0.0rc2.tar.gz kubernetes_asyncio databricks = apache-airflow-providers-databricks>=6.1.0 @@ -62,7 +63,7 @@ databricks = dbt.cloud = apache-airflow-providers-dbt-cloud>=3.5.1 google = - apache-airflow-providers-google>=10.14.0 + apache-airflow-providers-google>=10.15.0 gcloud-aio-storage gcloud-aio-bigquery http = @@ -123,9 +124,10 @@ all = apache-airflow-providers-amazon>=8.18.0rc2 apache-airflow-providers-apache-hive>=6.1.5 apache-airflow-providers-apache-livy>=3.7.1 - apache-airflow-providers-cncf-kubernetes>=4 + # TODO: Update version when below RC is released + apache-airflow-providers-cncf-kubernetes 
@https://files.pythonhosted.org/packages/c9/82/c7422f848a249e437b2c5810b9ccc266b80ffd767322722b16ac97cc3013/apache_airflow_providers_cncf_kubernetes-8.0.0rc2.tar.gz apache-airflow-providers-databricks>=6.1.0 - apache-airflow-providers-google>=10.14.0 + apache-airflow-providers-google>=10.15.0 apache-airflow-providers-http>=4.9.0 apache-airflow-providers-snowflake>=5.3.0 apache-airflow-providers-sftp>=4.9.0 diff --git a/tests/google/cloud/operators/test_kubernetes_engine.py b/tests/google/cloud/operators/test_kubernetes_engine.py index 6af1ee0c9..1e7cacefe 100644 --- a/tests/google/cloud/operators/test_kubernetes_engine.py +++ b/tests/google/cloud/operators/test_kubernetes_engine.py @@ -1,28 +1,6 @@ -from unittest import mock -from unittest.mock import MagicMock, PropertyMock +from airflow.providers.google.cloud.operators.kubernetes_engine import GKEStartPodOperator -import pytest -from airflow.exceptions import AirflowException, TaskDeferred -from airflow.providers.cncf.kubernetes.utils.pod_manager import ( - PodLoggingStatus, - PodPhase, -) -from kubernetes.client import models as k8s -from kubernetes.client.models.v1_object_meta import V1ObjectMeta - -from astronomer.providers.cncf.kubernetes.operators.kubernetes_pod import ( - PodNotFoundException, -) -from astronomer.providers.cncf.kubernetes.triggers.wait_container import ( - PodLaunchTimeoutException, -) -from astronomer.providers.google.cloud.operators.kubernetes_engine import ( - GKEStartPodOperatorAsync, -) -from astronomer.providers.google.cloud.triggers.kubernetes_engine import ( - GKEStartPodTrigger, -) -from tests.utils.airflow_util import create_context +from astronomer.providers.google.cloud.operators.kubernetes_engine import GKEStartPodOperatorAsync PROJECT_ID = "astronomer-***-providers" LOCATION = "us-west1" @@ -32,245 +10,18 @@ GCP_CONN_ID = "google_cloud_default" -# TODO: Improve test class TestGKEStartPodOperatorAsync: - OPERATOR = GKEStartPodOperatorAsync( - task_id="start_pod", - 
class TestGKEStartPodOperatorAsync:
    def test_init(self):
        """The async operator is now a thin shim: constructing it must yield an
        upstream ``GKEStartPodOperator`` instance with deferrable mode forced on."""
        operator = GKEStartPodOperatorAsync(
            task_id="start_pod",
            project_id=PROJECT_ID,
            location=LOCATION,
            cluster_name=GKE_CLUSTER_NAME,
            name="astro_k8s_gke_test_pod",
            namespace=NAMESPACE,
            image="ubuntu",
            gcp_conn_id=GCP_CONN_ID,
            logging_interval=1,
        )
        assert isinstance(operator, GKEStartPodOperator)
        assert operator.deferrable is True