Skip to content

Commit

Permalink
docs: enhance common module code documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Ygnas authored and openshift-merge-bot[bot] committed Nov 5, 2024
1 parent 1f026b8 commit ac1a1dc
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 19 deletions.
27 changes: 24 additions & 3 deletions src/codeflare_sdk/common/kubernetes_cluster/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,20 @@ def load_kube_config(self):

def config_check() -> str:
"""
Function for loading the config file at the default config location ~/.kube/config if the user has not
specified their own config file or has logged in with their token and server.
Check and load the Kubernetes config from the default location.
This function checks if a Kubernetes config file exists at the default path
(`~/.kube/config`). If none is provided, it tries to load in-cluster config.
If the `config_path` global variable is set by an external module (e.g., `auth.py`),
this path will be used directly.
Returns:
str:
The loaded config path if successful.
Raises:
PermissionError:
If no valid credentials or config file is found.
"""
global config_path
global api_client
Expand Down Expand Up @@ -215,7 +227,16 @@ def _gen_ca_cert_path(ca_cert_path: Optional[str]):


def get_api_client() -> client.ApiClient:
"This function should load the api client with defaults"
"""
Retrieve the Kubernetes API client with the default configuration.
This function returns the current API client instance if already loaded,
or creates a new API client with the default configuration.
Returns:
client.ApiClient:
The Kubernetes API client object.
"""
if api_client != None:
return api_client
to_return = client.ApiClient()
Expand Down
65 changes: 57 additions & 8 deletions src/codeflare_sdk/common/kueue/kueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,23 @@
from kubernetes.client.exceptions import ApiException


def get_default_kueue_name(namespace: str):
# If the local queue is set, use it. Otherwise, try to use the default queue.
def get_default_kueue_name(namespace: str) -> Optional[str]:
"""
Retrieves the default Kueue name from the provided namespace.
This function attempts to fetch the local queues in the given namespace and checks if any of them is annotated
as the default queue. If found, the name of the default queue is returned.
The default queue is marked with the annotation "kueue.x-k8s.io/default-queue" set to "true."
Args:
namespace (str):
The Kubernetes namespace where the local queues are located.
Returns:
Optional[str]:
The name of the default queue if it exists, otherwise None.
"""
try:
config_check()
api_instance = client.CustomObjectsApi(get_api_client())
Expand Down Expand Up @@ -58,12 +73,14 @@ def list_local_queues(
Depending on the version of the local queue API, the available flavors may not be present in the response.
Args:
namespace (str, optional): The namespace to list local queues from. Defaults to None.
flavors (List[str], optional): The flavors to filter local queues by. Defaults to None.
namespace (str, optional):
The namespace to list local queues from. Defaults to None.
flavors (List[str], optional):
The flavors to filter local queues by. Defaults to None.
Returns:
List[dict]: A list of dictionaries containing the name of the local queue and the available flavors
List[dict]:
A list of dictionaries containing the name of the local queue and the available flavors
"""

from ...ray.cluster.cluster import get_current_namespace

if namespace is None: # pragma: no cover
Expand Down Expand Up @@ -92,8 +109,22 @@ def list_local_queues(
return to_return


def local_queue_exists(namespace: str, local_queue_name: str):
# get all local queues in the namespace
def local_queue_exists(namespace: str, local_queue_name: str) -> bool:
"""
Checks if a local queue with the provided name exists in the given namespace.
This function queries the local queues in the specified namespace and verifies if any queue matches the given name.
Args:
namespace (str):
The namespace where the local queues are located.
local_queue_name (str):
The name of the local queue to check for existence.
Returns:
bool:
True if the local queue exists, False otherwise.
"""
try:
config_check()
api_instance = client.CustomObjectsApi(get_api_client())
Expand All @@ -113,6 +144,24 @@ def local_queue_exists(namespace: str, local_queue_name: str):


def add_queue_label(item: dict, namespace: str, local_queue: Optional[str]):
"""
Adds a local queue name label to the provided item.
If the local queue is not provided, the default local queue for the namespace is used. The function validates if the
local queue exists, and if it does, the local queue name label is added to the resource metadata.
Args:
item (dict):
The resource where the label will be added.
namespace (str):
The namespace of the local queue.
local_queue (str, optional):
The name of the local queue to use. Defaults to None.
Raises:
ValueError:
If the provided or default local queue does not exist in the namespace.
"""
lq_name = local_queue or get_default_kueue_name(namespace)
if lq_name == None:
return
Expand Down
10 changes: 7 additions & 3 deletions src/codeflare_sdk/common/utils/demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ def copy_demo_nbs(dir: str = "./demo-notebooks", overwrite: bool = False):
Any files that exist in the directory that don't match these values will remain untouched.
Args:
dir (str): The directory to copy the demo notebooks to. Defaults to "./demo-notebooks". overwrite (bool):
overwrite (bool): Whether to overwrite files in the directory if it already exists. Defaults to False.
dir (str):
The directory to copy the demo notebooks to. Defaults to "./demo-notebooks".
overwrite (bool):
Whether to overwrite files in the directory if it already exists. Defaults to False.
Raises:
FileExistsError: If the directory already exists.
FileExistsError:
If the directory already exists.
"""
# does dir exist already?
if overwrite is False and pathlib.Path(dir).exists():
Expand Down
73 changes: 68 additions & 5 deletions src/codeflare_sdk/common/utils/generate_cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,20 @@


def generate_ca_cert(days: int = 30):
# Generate base64 encoded ca.key and ca.cert
# Similar to:
# openssl req -x509 -nodes -newkey rsa:2048 -keyout ca.key -days 1826 -out ca.crt -subj '/CN=root-ca'
# base64 -i ca.crt -i ca.key
"""
Generates a self-signed CA certificate and private key, encoded in base64 format.
Similar to:
openssl req -x509 -nodes -newkey rsa:2048 -keyout ca.key -days 1826 -out ca.crt -subj '/CN=root-ca'
Args:
days (int):
The number of days for which the CA certificate will be valid. Default is 30.
Returns:
Tuple[str, str]:
A tuple containing the base64-encoded private key and CA certificate.
"""

private_key = rsa.generate_private_key(
public_exponent=65537,
Expand Down Expand Up @@ -79,6 +89,25 @@ def generate_ca_cert(days: int = 30):


def get_secret_name(cluster_name, namespace, api_instance):
"""
Retrieves the name of the Kubernetes secret containing the CA certificate for the given Ray cluster.
Args:
cluster_name (str):
The name of the Ray cluster.
namespace (str):
The Kubernetes namespace where the Ray cluster is located.
api_instance (client.CoreV1Api):
An instance of the Kubernetes CoreV1Api.
Returns:
str:
The name of the Kubernetes secret containing the CA certificate.
Raises:
KeyError:
If no secret matching the cluster name is found.
"""
label_selector = f"ray.openshift.ai/cluster-name={cluster_name}"
try:
secrets = api_instance.list_namespaced_secret(
Expand All @@ -97,7 +126,26 @@ def get_secret_name(cluster_name, namespace, api_instance):


def generate_tls_cert(cluster_name, namespace, days=30):
# Create a folder tls-<cluster>-<namespace> and store three files: ca.crt, tls.crt, and tls.key
"""
Generates a TLS certificate and key for a Ray cluster, saving them locally along with the CA certificate.
Args:
cluster_name (str):
The name of the Ray cluster.
namespace (str):
The Kubernetes namespace where the Ray cluster is located.
days (int):
The number of days for which the TLS certificate will be valid. Default is 30.
Files Created:
- ca.crt: The CA certificate.
- tls.crt: The TLS certificate signed by the CA.
- tls.key: The private key for the TLS certificate.
Raises:
Exception:
If an error occurs while retrieving the CA secret.
"""
tls_dir = os.path.join(os.getcwd(), f"tls-{cluster_name}-{namespace}")
if not os.path.exists(tls_dir):
os.makedirs(tls_dir)
Expand Down Expand Up @@ -181,6 +229,21 @@ def generate_tls_cert(cluster_name, namespace, days=30):


def export_env(cluster_name, namespace):
"""
Sets environment variables to configure TLS for a Ray cluster.
Args:
cluster_name (str):
The name of the Ray cluster.
namespace (str):
The Kubernetes namespace where the Ray cluster is located.
Environment Variables Set:
- RAY_USE_TLS: Enables TLS for Ray.
- RAY_TLS_SERVER_CERT: Path to the TLS server certificate.
- RAY_TLS_SERVER_KEY: Path to the TLS server private key.
- RAY_TLS_CA_CERT: Path to the CA certificate.
"""
tls_dir = os.path.join(os.getcwd(), f"tls-{cluster_name}-{namespace}")
os.environ["RAY_USE_TLS"] = "1"
os.environ["RAY_TLS_SERVER_CERT"] = os.path.join(tls_dir, "tls.crt")
Expand Down

0 comments on commit ac1a1dc

Please sign in to comment.