Skip to content

Commit

Permalink
add functions for creating ray with oauth proxy in front of the dashboard (#298)
Browse files Browse the repository at this point in the history

* add functions for creating ray with oauth proxy in front of the dashboard

Signed-off-by: Kevin <[email protected]>

* add unit test for OAuth create

Signed-off-by: Kevin <[email protected]>

* add tests for replace and generate sidecar

Signed-off-by: Kevin <[email protected]>

---------

Signed-off-by: Kevin <[email protected]>
  • Loading branch information
KPostOffice authored Oct 20, 2023
1 parent 06a3a59 commit e44dd8e
Show file tree
Hide file tree
Showing 8 changed files with 667 additions and 127 deletions.
4 changes: 3 additions & 1 deletion src/codeflare_sdk/cluster/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import urllib3
from ..utils.kube_api_helpers import _kube_api_error_handling

from typing import Optional

global api_client
api_client = None
global config_path
Expand Down Expand Up @@ -188,7 +190,7 @@ def config_check() -> str:
return config_path


def api_config_handler() -> str:
def api_config_handler() -> Optional[client.ApiClient]:
"""
This function is used to load the api client if the user has logged in
"""
Expand Down
96 changes: 82 additions & 14 deletions src/codeflare_sdk/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,19 @@
from time import sleep
from typing import List, Optional, Tuple, Dict

import openshift as oc
from kubernetes import config
from ray.job_submission import JobSubmissionClient
import urllib3

from .auth import config_check, api_config_handler
from ..utils import pretty_print
from ..utils.generate_yaml import generate_appwrapper
from ..utils.kube_api_helpers import _kube_api_error_handling
from ..utils.openshift_oauth import (
create_openshift_oauth_objects,
delete_openshift_oauth_objects,
)
from .config import ClusterConfiguration
from .model import (
AppWrapper,
Expand All @@ -40,6 +47,8 @@
import os
import requests

from kubernetes import config


class Cluster:
"""
Expand All @@ -61,6 +70,39 @@ def __init__(self, config: ClusterConfiguration):
self.config = config
self.app_wrapper_yaml = self.create_app_wrapper()
self.app_wrapper_name = self.app_wrapper_yaml.split(".")[0]
self._client = None

@property
def _client_headers(self):
    """
    Build the HTTP headers used when talking to the Ray dashboard.

    The API client returned by api_config_handler() is used when it is
    truthy; otherwise a default kubernetes ApiClient is created. The
    "Authorization" value is whatever that client's configuration returns
    from get_api_key_with_prefix("authorization").
    """
    k8_client = api_config_handler()
    if not k8_client:
        # No client from the auth module: fall back to a default ApiClient.
        k8_client = client.ApiClient()
    auth_value = k8_client.configuration.get_api_key_with_prefix("authorization")
    return {"Authorization": auth_value}

@property
def _client_verify_tls(self):
    """
    Tell HTTP clients whether to verify TLS certificates.

    Returns the negation of config.openshift_oauth: verification is turned
    off when the OpenShift OAuth option is enabled and left on otherwise.
    """
    oauth_enabled = self.config.openshift_oauth
    return not oauth_enabled

@property
def client(self):
    """
    Return the Ray JobSubmissionClient for this cluster, creating it on
    first access and reusing the cached instance afterwards.

    When config.openshift_oauth is enabled, the client is built with the
    headers from _client_headers and the verify setting from
    _client_verify_tls. Otherwise it is built from the dashboard URI alone.
    """
    if self._client:
        return self._client
    if self.config.openshift_oauth:
        # The authorization token is intentionally NOT printed here: it is a
        # credential and must not end up in stdout, notebook output, or logs.
        # Printing it also crashed when api_config_handler() returned None.
        self._client = JobSubmissionClient(
            self.cluster_dashboard_uri(),
            headers=self._client_headers,
            verify=self._client_verify_tls,
        )
    else:
        self._client = JobSubmissionClient(self.cluster_dashboard_uri())
    return self._client

def evaluate_dispatch_priority(self):
priority_class = self.config.dispatch_priority
Expand Down Expand Up @@ -147,6 +189,7 @@ def create_app_wrapper(self):
image_pull_secrets=image_pull_secrets,
dispatch_priority=dispatch_priority,
priority_val=priority_val,
openshift_oauth=self.config.openshift_oauth,
)

# creates a new cluster with the provided or default spec
Expand All @@ -156,6 +199,11 @@ def up(self):
the MCAD queue.
"""
namespace = self.config.namespace
if self.config.openshift_oauth:
create_openshift_oauth_objects(
cluster_name=self.config.name, namespace=namespace
)

try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
Expand Down Expand Up @@ -190,6 +238,11 @@ def down(self):
except Exception as e: # pragma: no cover
return _kube_api_error_handling(e)

if self.config.openshift_oauth:
delete_openshift_oauth_objects(
cluster_name=self.config.name, namespace=namespace
)

def status(
self, print_to_console: bool = True
) -> Tuple[CodeFlareClusterStatus, bool]:
Expand Down Expand Up @@ -258,7 +311,16 @@ def status(
return status, ready

def is_dashboard_ready(self) -> bool:
response = requests.get(self.cluster_dashboard_uri(), timeout=5)
try:
response = requests.get(
self.cluster_dashboard_uri(),
headers=self._client_headers,
timeout=5,
verify=self._client_verify_tls,
)
except requests.exceptions.SSLError:
# SSL exception occurs when oauth ingress has been created but cluster is not up
return False
if response.status_code == 200:
return True
else:
Expand Down Expand Up @@ -330,7 +392,13 @@ def cluster_dashboard_uri(self) -> str:
return _kube_api_error_handling(e)

for route in routes["items"]:
if route["metadata"]["name"] == f"ray-dashboard-{self.config.name}":
if route["metadata"][
"name"
] == f"ray-dashboard-{self.config.name}" or route["metadata"][
"name"
].startswith(
f"{self.config.name}-ingress"
):
protocol = "https" if route["spec"].get("tls") else "http"
return f"{protocol}://{route['spec']['host']}"
return "Dashboard route not available yet, have you run cluster.up()?"
Expand All @@ -339,30 +407,24 @@ def list_jobs(self) -> List:
"""
This method accesses the head ray node in your cluster and lists the running jobs.
"""
dashboard_route = self.cluster_dashboard_uri()
client = JobSubmissionClient(dashboard_route)
return client.list_jobs()
return self.client.list_jobs()

def job_status(self, job_id: str) -> str:
"""
This method accesses the head ray node in your cluster and returns the job status for the provided job id.
"""
dashboard_route = self.cluster_dashboard_uri()
client = JobSubmissionClient(dashboard_route)
return client.get_job_status(job_id)
return self.client.get_job_status(job_id)

def job_logs(self, job_id: str) -> str:
"""
This method accesses the head ray node in your cluster and returns the logs for the provided job id.
"""
dashboard_route = self.cluster_dashboard_uri()
client = JobSubmissionClient(dashboard_route)
return client.get_job_logs(job_id)
return self.client.get_job_logs(job_id)

def torchx_config(
self, working_dir: str = None, requirements: str = None
) -> Dict[str, str]:
dashboard_address = f"{self.cluster_dashboard_uri().lstrip('http://')}"
dashboard_address = urllib3.util.parse_url(self.cluster_dashboard_uri()).host
to_return = {
"cluster_name": self.config.name,
"dashboard_address": dashboard_address,
Expand Down Expand Up @@ -591,7 +653,7 @@ def _get_app_wrappers(


def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
if "status" in rc and "state" in rc["status"]:
if "state" in rc["status"]:
status = RayClusterStatus(rc["status"]["state"].lower())
else:
status = RayClusterStatus.UNKNOWN
Expand All @@ -606,7 +668,13 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
)
ray_route = None
for route in routes["items"]:
if route["metadata"]["name"] == f"ray-dashboard-{rc['metadata']['name']}":
if route["metadata"][
"name"
] == f"ray-dashboard-{rc['metadata']['name']}" or route["metadata"][
"name"
].startswith(
f"{rc['metadata']['name']}-ingress"
):
protocol = "https" if route["spec"].get("tls") else "http"
ray_route = f"{protocol}://{route['spec']['host']}"

Expand Down
1 change: 1 addition & 0 deletions src/codeflare_sdk/cluster/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,4 @@ class ClusterConfiguration:
local_interactive: bool = False
image_pull_secrets: list = field(default_factory=list)
dispatch_priority: str = None
openshift_oauth: bool = False # NOTE: to use this option, the user must have permission to create a RoleBinding for system:auth-delegator
Loading

0 comments on commit e44dd8e

Please sign in to comment.