From da11d1116c29322c481d0b8f174df8d6f05004aa Mon Sep 17 00:00:00 2001 From: Sandipan Panda <87253083+sandipanpanda@users.noreply.github.com> Date: Fri, 20 Sep 2024 23:03:29 +0530 Subject: [PATCH] Update JAX image to use image published by Kubeflow (#2264) * Use JAX image published by Kubeflow Signed-off-by: Sandipan Panda * Update README to include JAX API Definition Signed-off-by: Sandipan Panda --------- Signed-off-by: Sandipan Panda --- README.md | 3 ++- examples/jax/cpu-demo/demo.yaml | 2 +- pkg/controller.v1/jax/envvar_test.go | 2 +- pkg/webhooks/jax/jaxjob_webhook_test.go | 2 +- sdk/python/kubeflow/training/constants/constants.py | 2 +- sdk/python/test/e2e/test_e2e_jaxjob.py | 2 +- 6 files changed, 7 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index d85947ecc1..88513bca5f 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Kubeflow Training Operator is a Kubernetes-native project for fine-tuning and scalable distributed training of machine learning (ML) models created with various ML frameworks -such as PyTorch, TensorFlow, HuggingFace, Jax, DeepSpeed, XGBoost, PaddlePaddle and others. +such as PyTorch, TensorFlow, HuggingFace, [JAX](https://jax.readthedocs.io/en/latest/), DeepSpeed, XGBoost, PaddlePaddle and others. You can run high-performance computing (HPC) tasks with the Training Operator and `MPIJob` since it supports running Message Passing Interface (MPI) on Kubernetes which is heavily used for HPC. @@ -103,6 +103,7 @@ For a complete reference of the custom resource definitions, please refer to the - [XGBoost API Definition](pkg/apis/kubeflow.org/v1/xgboost_types.go) - [MPI API Definition](pkg/apis/kubeflow.org/v1/mpi_types.go) - [PaddlePaddle API Definition](pkg/apis/kubeflow.org/v1/paddlepaddle_types.go) +- [JAX API Definition](pkg/apis/kubeflow.org/v1/jax_types.go) For details on the Training Operator custom resources APIs, refer to [the following API documentation](docs/api/kubeflow.org_v1_generated.asciidoc) diff --git a/examples/jax/cpu-demo/demo.yaml b/examples/jax/cpu-demo/demo.yaml index 85c99c9b18..bffd3cc16f 100644 --- a/examples/jax/cpu-demo/demo.yaml +++ b/examples/jax/cpu-demo/demo.yaml @@ -12,7 +12,7 @@ spec: spec: containers: - name: jax - image: docker.io/sandipanify/jaxgoogle:latest + image: docker.io/kubeflow/jaxjob-simple:latest command: - "python3" - "train.py" diff --git a/pkg/controller.v1/jax/envvar_test.go b/pkg/controller.v1/jax/envvar_test.go index e8155c5274..9920e89bbb 100644 --- a/pkg/controller.v1/jax/envvar_test.go +++ b/pkg/controller.v1/jax/envvar_test.go @@ -30,7 +30,7 @@ func TestSetPodEnv(t *testing.T) { Spec: corev1.PodSpec{ Containers: []corev1.Container{{ Name: "jax", - Image: "docker.io/sandipanify/jaxgoogle:latest", + Image: "docker.io/kubeflow/jaxjob-simple:latest", Ports: []corev1.ContainerPort{{ Name: kubeflowv1.JAXJobDefaultPortName, ContainerPort: validPort, diff --git a/pkg/webhooks/jax/jaxjob_webhook_test.go b/pkg/webhooks/jax/jaxjob_webhook_test.go index ed3bd2bd1b..bfbc0eb29c 100644 --- a/pkg/webhooks/jax/jaxjob_webhook_test.go +++ b/pkg/webhooks/jax/jaxjob_webhook_test.go @@ -38,7 +38,7 @@ func TestValidateV1JAXJob(t *testing.T) { Spec: corev1.PodSpec{ Containers: []corev1.Container{{ Name: "jax", - Image: "docker.io/sandipanify/jaxgoogle:latest", + Image: "docker.io/kubeflow/jaxjob-simple:latest", Ports: []corev1.ContainerPort{{ Name: "jaxjob-port", ContainerPort: 6666, diff --git a/sdk/python/kubeflow/training/constants/constants.py b/sdk/python/kubeflow/training/constants/constants.py index 0bb6fe495e..e745468736 100644 --- a/sdk/python/kubeflow/training/constants/constants.py +++ b/sdk/python/kubeflow/training/constants/constants.py @@ -144,7 +144,7 @@ JAXJOB_PLURAL = "jaxjobs" JAXJOB_CONTAINER = "jax" JAXJOB_REPLICA_TYPES = REPLICA_TYPE_WORKER.lower() -JAXJOB_BASE_IMAGE = "kubeflow/jaxjob-simple:latest" +JAXJOB_BASE_IMAGE = "docker.io/kubeflow/jaxjob-simple:latest" # Dictionary to get plural, model, and container for each Job kind. JOB_PARAMETERS = { diff --git a/sdk/python/test/e2e/test_e2e_jaxjob.py b/sdk/python/test/e2e/test_e2e_jaxjob.py index cf350f1c11..6223c8a988 100644 --- a/sdk/python/test/e2e/test_e2e_jaxjob.py +++ b/sdk/python/test/e2e/test_e2e_jaxjob.py @@ -155,7 +155,7 @@ def generate_jaxjob( def generate_container() -> V1Container: return V1Container( name=CONTAINER_NAME, - image="docker.io/sandipanify/jaxgoogle:latest", + image="docker.io/kubeflow/jaxjob-simple:latest", command=["python", "train.py"], resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}), )