Skip to content

Commit

Permalink
Fix python tests
Browse files Browse the repository at this point in the history
Signed-off-by: Sandipan Panda <[email protected]>
  • Loading branch information
sandipanpanda committed Aug 28, 2024
1 parent 061a4ef commit df7fd90
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 6 deletions.
3 changes: 0 additions & 3 deletions pkg/controller.v1/jax/jaxjob_controller_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ import (
//+kubebuilder:scaffold:imports
)

// These tests use Ginkgo (BDD-style Go testing framework). Refer to
// http://onsi.github.io/ginkgo/ to learn more about Ginkgo.

var (
testK8sClient client.Client
testEnv *envtest.Environment
Expand Down
8 changes: 8 additions & 0 deletions sdk/python/kubeflow/training/api/training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,8 @@ def get_job_pods(
For PaddleJob one of `master` or `worker`.
For JAXJob `worker`.
replica_index: Index for the Job replica.
timeout: Kubernetes API server timeout in seconds to execute the request.
Expand All @@ -986,13 +988,15 @@ def get_job_pods(
and replica_type not in constants.XGBOOSTJOB_REPLICA_TYPES
and replica_type not in constants.MPIJOB_REPLICA_TYPES
and replica_type not in constants.PADDLEJOB_REPLICA_TYPES
and replica_type not in constants.JAXJOB_REPLICA_TYPES
):
raise ValueError(
f"TFJob replica type must be one of {constants.TFJOB_REPLICA_TYPES}\n"
f"PyTorchJob replica type must be one of {constants.PYTORCHJOB_REPLICA_TYPES}\n"
f"XGBoostJob replica type must be one of {constants.XGBOOSTJOB_REPLICA_TYPES}\n"
f"MPIJob replica type must be one of {constants.MPIJOB_REPLICA_TYPES}\n"
f"PaddleJob replica type must be one of {constants.PADDLEJOB_REPLICA_TYPES}"
f"JAXJob replica type must be one of {constants.PADDLEJOB_REPLICA_TYPES}"
)

label_selector = f"{constants.JOB_NAME_LABEL}={name}"
Expand Down Expand Up @@ -1052,6 +1056,8 @@ def get_job_pod_names(
For PaddleJob one of `master` or `worker`.
For JAXJob `worker`.
replica_index: Index for the Job replica.
timeout: Kubernetes API server timeout in seconds to execute the request.
Expand Down Expand Up @@ -1112,6 +1118,8 @@ def get_job_logs(
For MPIJob one of `launcher` or `worker`.
For PaddleJob one of `master` or `worker`.
For JAXJob `worker`.
replica_index: Optional, index for the Job replica.
container: Pod container to get the logs.
follow: Whether to follow the log stream of the pod and print logs to StdOut.
Expand Down
13 changes: 13 additions & 0 deletions sdk/python/kubeflow/training/constants/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@
"docker.io/paddlepaddle/paddle:2.4.0rc0-gpu-cuda11.2-cudnn8.1-trt8.0"
)

# JAXJob constants
JAXJOB_KIND = "JAXJob"
JAXJOB_MODEL = "KubeflowOrgV1JAXJob"
JAXJOB_PLURAL = "jaxjobs"
JAXJOB_CONTAINER = "jax"
JAXJOB_REPLICA_TYPES = (REPLICA_TYPE_WORKER.lower())

# Dictionary to get plural, model, and container for each Job kind.
JOB_PARAMETERS = {
Expand Down Expand Up @@ -171,6 +177,12 @@
"container": PADDLEJOB_CONTAINER,
"base_image": PADDLEJOB_BASE_IMAGE,
},
JAXJOB_KIND: {
"model": JAXJOB_MODEL,
"plural": JAXJOB_PLURAL,
"container": JAXJOB_CONTAINER,
"base_image": "TODO",
},
}

# Tuple of all Job models.
Expand All @@ -183,4 +195,5 @@
models.KubeflowOrgV1XGBoostJob,
models.KubeflowOrgV1MPIJob,
models.KubeflowOrgV1PaddleJob,
models.KubeflowOrgV1JAXJob,
]
5 changes: 2 additions & 3 deletions sdk/python/test/e2e/test_e2e_jaxjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,7 @@ def generate_jaxjob(
# def generate_container() -> V1Container:
# return V1Container(
# name=CONTAINER_NAME,
# image="docker.io/sandipanify/jaxgloo",
# command=["python"],
# args=["-m", "", ""],
# image="docker.io/kubeflow/jaxgloo:latest",
# args=[],
# resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}),
# )

0 comments on commit df7fd90

Please sign in to comment.