Skip to content

Commit

Permalink
deprecate GKEStartPodOperatorAsync (#1464)
Browse files Browse the repository at this point in the history
Deprecate GKEStartPodOperatorAsync and proxy it to its Airflow OSS provider counterpart.

---------

Co-authored-by: Pankaj Koti <[email protected]>
Co-authored-by: Pankaj Koti <[email protected]>
  • Loading branch information
3 people authored Feb 18, 2024
1 parent 2cd1e45 commit 671aa93
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 507 deletions.
23 changes: 21 additions & 2 deletions astronomer/providers/cncf/kubernetes/hooks/kubernetes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from typing import Any, Dict
from __future__ import annotations

import warnings
from typing import Any

import aiofiles
from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from kubernetes_asyncio import client, config


def get_field(extras: Dict[str, Any], field_name: str, strict: bool = False) -> Any:
def get_field(extras: dict[str, Any], field_name: str, strict: bool = False) -> Any:
"""Get field from extra, first checking short name, then for backward compatibility we check for prefixed name."""
backward_compatibility_prefix = "extra__kubernetes__"
if field_name.startswith("extra__"):
Expand All @@ -24,6 +27,22 @@ def get_field(extras: Dict[str, Any], field_name: str, strict: bool = False) ->


class KubernetesHookAsync(KubernetesHook): # noqa: D101
"""
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook` instead
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook`"
),
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)

async def _load_config(self) -> client.ApiClient:
"""
cluster_context: Optional[str] = None,
Expand Down
253 changes: 16 additions & 237 deletions astronomer/providers/google/cloud/operators/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,249 +2,28 @@

from __future__ import annotations

from typing import Any, Sequence
import warnings
from typing import Any

from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import (
KubernetesPodOperator,
)
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodPhase
from kubernetes.client import models as k8s
from airflow.providers.google.cloud.operators.kubernetes_engine import GKEStartPodOperator

from astronomer.providers.cncf.kubernetes.operators.kubernetes_pod import (
PodNotFoundException,
)
from astronomer.providers.cncf.kubernetes.triggers.wait_container import (
PodLaunchTimeoutException,
)
from astronomer.providers.google.cloud.gke_utils import _get_gke_config_file
from astronomer.providers.google.cloud.triggers.kubernetes_engine import (
GKEStartPodTrigger,
)
from astronomer.providers.utils.typing_compat import Context


class GKEStartPodOperatorAsync(KubernetesPodOperator):
class GKEStartPodOperatorAsync(GKEStartPodOperator):
"""
Executes a task in a Kubernetes pod in the specified Google Kubernetes
Engine cluster
This Operator assumes that the system has gcloud installed and has configured a
connection id with a service account.
The **minimum** required to define a cluster to create are the variables
``task_id``, ``project_id``, ``location``, ``cluster_name``, ``name``,
``namespace``, and ``image``
.. seealso::
For more detail about Kubernetes Engine authentication have a look at the reference:
https://cloud.google.com/kubernetes-engine/docs/how-to/cluster-access-for-kubectl#internal_ip
:param location: The name of the Google Kubernetes Engine zone or region in which the
cluster resides, e.g. 'us-central1-a'
:param cluster_name: The name of the Google Kubernetes Engine cluster the pod
should be spawned in
:param use_internal_ip: Use the internal IP address as the endpoint
:param project_id: The Google Developers Console project ID
:param gcp_conn_id: The google cloud connection ID to use. This allows for
users to specify a service account.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param regional: The location param is whether a region or a zone
:param is_delete_operator_pod: What to do when the pod reaches its final
state, or the execution is interrupted. If True, delete the
pod; if False, leave the pod. Current default is False, but this will be
changed in the next major release of this provider.
This class is deprecated.
Please use :class: `~airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator`
and set `deferrable` param to `True` instead.
"""

def __init__(
self,
*,
location: str,
cluster_name: str,
use_internal_ip: bool = False,
project_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
regional: bool = False,
poll_interval: float = 5,
logging_interval: int | None = None,
do_xcom_push: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.project_id = project_id
self.location = location
self.cluster_name = cluster_name
self.gcp_conn_id = gcp_conn_id
self.use_internal_ip = use_internal_ip
self.impersonation_chain = impersonation_chain
self.regional = regional
self.poll_interval = poll_interval
self.logging_interval = logging_interval
self.do_xcom_push = do_xcom_push

self.pod_name: str = ""
self.pod_namespace: str = ""

def _get_or_create_pod(self, context: Context) -> None:
"""A wrapper to fetch GKE config and get or create a pod"""
with _get_gke_config_file(
gcp_conn_id=self.gcp_conn_id,
project_id=self.project_id,
cluster_name=self.cluster_name,
impersonation_chain=self.impersonation_chain,
regional=self.regional,
location=self.location,
use_internal_ip=self.use_internal_ip,
cluster_context=self.cluster_context,
) as config_file:
self.config_file = config_file
self.pod_request_obj = self.build_pod_request_obj(context)
self.pod: k8s.V1Pod = self.get_or_create_pod(self.pod_request_obj, context)
self.pod_name = self.pod.metadata.name
self.pod_namespace = self.pod.metadata.namespace

def execute(self, context: Context) -> Any:
"""Look for a pod, if not found then create one and defer"""
self._get_or_create_pod(context)
self.log.info("Created pod=%s in namespace=%s", self.pod_name, self.pod_namespace)

event = None
try:
with _get_gke_config_file(
gcp_conn_id=self.gcp_conn_id,
project_id=self.project_id,
cluster_name=self.cluster_name,
impersonation_chain=self.impersonation_chain,
regional=self.regional,
location=self.location,
use_internal_ip=self.use_internal_ip,
cluster_context=self.cluster_context,
) as config_file:
hook_params: dict[str, Any] = {
"cluster_context": self.cluster_context,
"config_file": config_file,
"in_cluster": self.in_cluster,
}
hook = KubernetesHook(conn_id=None, **hook_params)
client = hook.core_v1_client
pod = client.read_namespaced_pod(self.pod_name, self.pod_namespace)
phase = pod.status.phase
if phase == PodPhase.SUCCEEDED:
event = {"status": "done", "namespace": self.namespace, "pod_name": self.name}

elif phase == PodPhase.FAILED:
event = {
"status": "failed",
"namespace": self.namespace,
"pod_name": self.name,
"description": "Failed to start pod operator",
}
except Exception as e:
event = {"status": "error", "message": str(e)}

if event:
return self.trigger_reentry(context, event)

self.defer(
trigger=GKEStartPodTrigger(
namespace=self.pod_namespace,
name=self.pod_name,
in_cluster=self.in_cluster,
cluster_context=self.cluster_context,
location=self.location,
cluster_name=self.cluster_name,
use_internal_ip=self.use_internal_ip,
project_id=self.project_id,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
regional=self.regional,
poll_interval=self.poll_interval,
pending_phase_timeout=self.startup_timeout_seconds,
logging_interval=self.logging_interval,
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This class is deprecated."
"Please use `airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator`"
"and set `deferrable` param to `True` instead."
),
method_name=self.trigger_reentry.__name__,
DeprecationWarning,
stacklevel=2,
)

def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: # type: ignore[override]
"""Callback for trigger once task reach terminal state"""
self.trigger_reentry(context=context, event=event)

@staticmethod
def raise_for_trigger_status(event: dict[str, Any]) -> None:
"""Raise exception if pod is not in expected state."""
if event["status"] == "error":
description = event["description"]
if "error_type" in event and event["error_type"] == "PodLaunchTimeoutException":
raise PodLaunchTimeoutException(description)
else:
raise AirflowException(description)

def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
"""
Point of re-entry from trigger.
If ``logging_interval`` is None, then at this point the pod should be done and we'll just fetch
the logs and exit.
If ``logging_interval`` is not None, it could be that the pod is still running and we'll just
grab the latest logs and defer back to the trigger again.
"""
remote_pod = None
self.raise_for_trigger_status(event)
try:
with _get_gke_config_file(
gcp_conn_id=self.gcp_conn_id,
project_id=self.project_id,
cluster_name=self.cluster_name,
impersonation_chain=self.impersonation_chain,
regional=self.regional,
location=self.location,
use_internal_ip=self.use_internal_ip,
) as config_file:
self.config_file = config_file
self.pod = self.find_pod(
namespace=event["namespace"],
context=context,
)

if not self.pod:
raise PodNotFoundException("Could not find pod after resuming from deferral")

if self.get_logs:
last_log_time = event and event.get("last_log_time")
if last_log_time:
self.log.info("Resuming logs read from time %r", last_log_time) # pragma: no cover
self.pod_manager.fetch_container_logs(
pod=self.pod,
container_name=self.BASE_CONTAINER_NAME,
follow=self.logging_interval is None,
since_time=last_log_time,
)

if self.do_xcom_push:
result = self.extract_xcom(pod=self.pod)
remote_pod = self.pod_manager.await_pod_completion(self.pod)
except Exception:
self.cleanup(
pod=self.pod,
remote_pod=remote_pod,
)
raise
self.cleanup(
pod=self.pod,
remote_pod=remote_pod,
)
if self.do_xcom_push:
ti = context["ti"]
ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
return result
super().__init__(*args, deferrable=True, **kwargs)
13 changes: 13 additions & 0 deletions astronomer/providers/google/cloud/triggers/kubernetes_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import Any, AsyncIterator, Sequence

from airflow.providers.cncf.kubernetes.utils.pod_manager import PodPhase
Expand All @@ -17,6 +18,9 @@ class GKEStartPodTrigger(WaitContainerTrigger):
"""
Fetch GKE cluster config and wait for pod to start up.
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.google.cloud.triggers.kubernetes_engine.GKEStartPodTrigger` instead
:param location: The name of the Google Kubernetes Engine zone or region in which the
cluster resides
:param cluster_name: The name of the Google Kubernetes Engine cluster the pod should be spawned in
Expand Down Expand Up @@ -53,6 +57,15 @@ def __init__(
pending_phase_timeout: float = 120.0,
logging_interval: int | None = None,
):
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.google.cloud.triggers.kubernetes_engine.GKEStartPodTrigger`"
),
DeprecationWarning,
stacklevel=2,
)

super().__init__(
container_name=self.BASE_CONTAINER_NAME,
pod_name=name,
Expand Down
10 changes: 6 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,16 @@ apache.livy =
apache-airflow-providers-apache-livy>=3.7.1
paramiko
cncf.kubernetes =
apache-airflow-providers-cncf-kubernetes>=4
# TODO: Update version when below RC is released
apache-airflow-providers-cncf-kubernetes @https://files.pythonhosted.org/packages/c9/82/c7422f848a249e437b2c5810b9ccc266b80ffd767322722b16ac97cc3013/apache_airflow_providers_cncf_kubernetes-8.0.0rc2.tar.gz
kubernetes_asyncio
databricks =
apache-airflow-providers-databricks>=6.1.0
databricks-sql-connector>=2.0.4;python_version>='3.10'
dbt.cloud =
apache-airflow-providers-dbt-cloud>=3.5.1
google =
apache-airflow-providers-google>=10.14.0
apache-airflow-providers-google>=10.15.0
gcloud-aio-storage
gcloud-aio-bigquery
http =
Expand Down Expand Up @@ -123,9 +124,10 @@ all =
apache-airflow-providers-amazon>=8.18.0rc2
apache-airflow-providers-apache-hive>=6.1.5
apache-airflow-providers-apache-livy>=3.7.1
apache-airflow-providers-cncf-kubernetes>=4
# TODO: Update version when below RC is released
apache-airflow-providers-cncf-kubernetes @https://files.pythonhosted.org/packages/c9/82/c7422f848a249e437b2c5810b9ccc266b80ffd767322722b16ac97cc3013/apache_airflow_providers_cncf_kubernetes-8.0.0rc2.tar.gz
apache-airflow-providers-databricks>=6.1.0
apache-airflow-providers-google>=10.14.0
apache-airflow-providers-google>=10.15.0
apache-airflow-providers-http>=4.9.0
apache-airflow-providers-snowflake>=5.3.0
apache-airflow-providers-sftp>=4.9.0
Expand Down
Loading

0 comments on commit 671aa93

Please sign in to comment.