diff --git a/pkg/controller.v1/jax/jaxjob_controller_suite_test.go b/pkg/controller.v1/jax/jaxjob_controller_suite_test.go index 01ce56b28a..a9471d9c83 100644 --- a/pkg/controller.v1/jax/jaxjob_controller_suite_test.go +++ b/pkg/controller.v1/jax/jaxjob_controller_suite_test.go @@ -41,9 +41,6 @@ import ( //+kubebuilder:scaffold:imports ) -// These tests use Ginkgo (BDD-style Go testing framework). Refer to -// http://onsi.github.io/ginkgo/ to learn more about Ginkgo. - var ( testK8sClient client.Client testEnv *envtest.Environment diff --git a/sdk/python/kubeflow/training/api/training_client.py b/sdk/python/kubeflow/training/api/training_client.py index 78bf0df7f1..60c37cd261 100644 --- a/sdk/python/kubeflow/training/api/training_client.py +++ b/sdk/python/kubeflow/training/api/training_client.py @@ -965,6 +965,8 @@ def get_job_pods( For PaddleJob one of `master` or `worker`. + For JAXJob `worker`. + replica_index: Index for the Job replica. timeout: Kubernetes API server timeout in seconds to execute the request. @@ -986,6 +988,7 @@ def get_job_pods( and replica_type not in constants.XGBOOSTJOB_REPLICA_TYPES and replica_type not in constants.MPIJOB_REPLICA_TYPES and replica_type not in constants.PADDLEJOB_REPLICA_TYPES + and replica_type not in constants.JAXJOB_REPLICA_TYPES ): raise ValueError( f"TFJob replica type must be one of {constants.TFJOB_REPLICA_TYPES}\n" @@ -993,6 +996,7 @@ def get_job_pods( f"XGBoostJob replica type must be one of {constants.XGBOOSTJOB_REPLICA_TYPES}\n" f"MPIJob replica type must be one of {constants.MPIJOB_REPLICA_TYPES}\n" f"PaddleJob replica type must be one of {constants.PADDLEJOB_REPLICA_TYPES}" + f"\nJAXJob replica type must be one of {constants.JAXJOB_REPLICA_TYPES}" ) label_selector = f"{constants.JOB_NAME_LABEL}={name}" @@ -1052,6 +1056,8 @@ def get_job_pod_names( For PaddleJob one of `master` or `worker`. + For JAXJob `worker`. + replica_index: Index for the Job replica. 
timeout: Kubernetes API server timeout in seconds to execute the request. @@ -1112,6 +1118,8 @@ def get_job_logs( For MPIJob one of `launcher` or `worker`. For PaddleJob one of `master` or `worker`. + + For JAXJob `worker`. replica_index: Optional, index for the Job replica. container: Pod container to get the logs. follow: Whether to follow the log stream of the pod and print logs to StdOut. diff --git a/sdk/python/kubeflow/training/constants/constants.py b/sdk/python/kubeflow/training/constants/constants.py index c355529e9a..08e2205dba 100644 --- a/sdk/python/kubeflow/training/constants/constants.py +++ b/sdk/python/kubeflow/training/constants/constants.py @@ -138,6 +138,12 @@ "docker.io/paddlepaddle/paddle:2.4.0rc0-gpu-cuda11.2-cudnn8.1-trt8.0" ) +# JAXJob constants +JAXJOB_KIND = "JAXJob" +JAXJOB_MODEL = "KubeflowOrgV1JAXJob" +JAXJOB_PLURAL = "jaxjobs" +JAXJOB_CONTAINER = "jax" +JAXJOB_REPLICA_TYPES = (REPLICA_TYPE_WORKER.lower(),) # Dictionary to get plural, model, and container for each Job kind. JOB_PARAMETERS = { @@ -171,6 +177,12 @@ "container": PADDLEJOB_CONTAINER, "base_image": PADDLEJOB_BASE_IMAGE, }, + JAXJOB_KIND: { + "model": JAXJOB_MODEL, + "plural": JAXJOB_PLURAL, + "container": JAXJOB_CONTAINER, + "base_image": "TODO", + }, } # Tuple of all Job models. 
@@ -183,4 +195,5 @@ models.KubeflowOrgV1XGBoostJob, models.KubeflowOrgV1MPIJob, models.KubeflowOrgV1PaddleJob, + models.KubeflowOrgV1JAXJob, ] diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py index 483ac011f4..9fd0d938f2 100644 --- a/sdk/python/test/e2e/test_e2e_jaxjob.py +++ b/sdk/python/test/e2e/test_e2e_jaxjob.py @@ -155,8 +155,7 @@ def generate_jaxjob( # def generate_container() -> V1Container: # return V1Container( # name=CONTAINER_NAME, -# image="docker.io/sandipanify/jaxgloo", -# command=["python"], -# args=["-m", "", ""], +# image="docker.io/kubeflow/jaxgloo:latest", +# args=[], # resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}), # )