Skip to content

Commit

Permalink
Add jax webhook test and examples
Browse files Browse the repository at this point in the history
Signed-off-by: Sandipan Panda <[email protected]>
  • Loading branch information
sandipanpanda committed Aug 28, 2024
1 parent df7fd90 commit 212f8bf
Show file tree
Hide file tree
Showing 6 changed files with 329 additions and 8 deletions.
25 changes: 25 additions & 0 deletions examples/jax/cpu-demo/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
FROM python:3.12

RUN pip install jax absl-py kubernetes

RUN apt-get update && apt-get install -y \
build-essential \
cmake \
git \
libgoogle-glog-dev \
libgflags-dev \
libprotobuf-dev \
protobuf-compiler \
&& rm -rf /var/lib/apt/lists/*

RUN git clone https://github.com/facebookincubator/gloo.git \
&& cd gloo \
&& mkdir build \
&& cd build \
&& cmake ../ \
&& make \
&& make install

WORKDIR /app

ADD train.py /app
23 changes: 23 additions & 0 deletions examples/jax/cpu-demo/demo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
apiVersion: "kubeflow.org/v1"
kind: JAXJob
metadata:
name: jaxjob-simple
spec:
jaxReplicaSpecs:
Worker:
replicas: 2
restartPolicy: Never
template:
spec:
containers:
- name: jax-worker
image: sandipanify/jaxgoogle
command: ["python", "train.py"]
args:
- --num_processes="2"
- --job_name=jaxjob-simple
- --sub_domain=training-operator
- --coordinator_port="6666"
ports:
- containerPort: 6666
imagePullPolicy: Always
69 changes: 69 additions & 0 deletions examples/jax/cpu-demo/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
import socket
import time

import jax
from absl import app, flags

jax.config.update("jax_cpu_collectives_implementation", "gloo")

flags.DEFINE_integer("num_processes", 1, "Number of processes")
flags.DEFINE_string("job_name", None, "Job name")
flags.DEFINE_string("sub_domain", None, "Service sub domain")
flags.DEFINE_string("coordinator_port", None, "Port the coordinator listens on")
flags.mark_flag_as_required("job_name")
flags.mark_flag_as_required("sub_domain")
flags.mark_flag_as_required("coordinator_port")

FLAGS = flags.FLAGS


def _get_coordinator_ip_address(job_name, sub_domain):
coordinator_fqdn = f"{FLAGS.job_name}-0.{FLAGS.sub_domain}"
print(f"Coordinator host name: {coordinator_fqdn}")

for retry_attempt in range(120):
try:
time.sleep(1)
coordinator_ipaddress = socket.gethostbyname(coordinator_fqdn)
except socket.gaierror:
print(
f"Failed to resolve: {coordinator_fqdn}. Trying again in a second ..."
)
else:
break

print(f"Coordinator IP address: {coordinator_ipaddress}")

return coordinator_ipaddress


def _main(argv):

process_id = int(os.getenv("JOB_COMPLETION_INDEX"))
num_processes = FLAGS.num_processes
coordinator_address = _get_coordinator_ip_address(FLAGS.job_name, FLAGS.sub_domain)
coordinator_address = f"{coordinator_address}:{FLAGS.coordinator_port}"

jax.distributed.initialize(
coordinator_address=coordinator_address,
num_processes=num_processes,
process_id=process_id,
)

print(
f"JAX process {jax.process_index()}/{jax.process_count()} initialized on "
f"{socket.gethostname()}"
)
print(f"JAX global devices:{jax.devices()}")
print(f"JAX local devices:{jax.local_devices()}")

print(jax.device_count())
print(jax.local_device_count())

xs = jax.numpy.ones(jax.local_device_count())
print(jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs))


if __name__ == "__main__":
app.run(_main)
198 changes: 198 additions & 0 deletions pkg/webhooks/jax/jaxjob_webhook_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/*
Copyright 2024 The Kubeflow Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package jax

import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/validation/field"
"k8s.io/utils/ptr"

trainingoperator "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
)

func TestValidateV1JAXJob(t *testing.T) {
validJAXReplicaSpecs := map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
trainingoperator.JAXJobReplicaTypeWorker: {
Replicas: ptr.To[int32](1),
RestartPolicy: trainingoperator.RestartPolicyOnFailure,
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{
Name: "jax",
Image: "docker.io/sandipanify/jaxgoogle:latest",
Ports: []corev1.ContainerPort{{
Name: "jaxjob-port",
ContainerPort: 6666,
}},
ImagePullPolicy: corev1.PullAlways,
Command: []string{
"python",
"train.py",
},
}},
},
},
},
}

testCases := map[string]struct {
jaxJob *trainingoperator.JAXJob
wantErr field.ErrorList
}{
"valid JAXJob": {
jaxJob: &trainingoperator.JAXJob{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
},
Spec: trainingoperator.JAXJobSpec{
JAXReplicaSpecs: validJAXReplicaSpecs,
},
},
},
"jaxJob name does not meet DNS1035": {
jaxJob: &trainingoperator.JAXJob{
ObjectMeta: metav1.ObjectMeta{
Name: "0-test",
},
Spec: trainingoperator.JAXJobSpec{
JAXReplicaSpecs: validJAXReplicaSpecs,
},
},
wantErr: field.ErrorList{
field.Invalid(field.NewPath("metadata").Child("name"), "", ""),
},
},
"no containers": {
jaxJob: &trainingoperator.JAXJob{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
},
Spec: trainingoperator.JAXJobSpec{
JAXReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
trainingoperator.JAXJobReplicaTypeWorker: {
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{},
},
},
},
},
},
},
wantErr: field.ErrorList{
field.Required(jaxReplicaSpecPath.
Key(string(trainingoperator.JAXJobReplicaTypeWorker)).
Child("template").
Child("spec").
Child("containers"), ""),
field.Required(jaxReplicaSpecPath.
Key(string(trainingoperator.JAXJobReplicaTypeWorker)).
Child("template").
Child("spec").
Child("containers"), ""),
},
},
"image is empty": {
jaxJob: &trainingoperator.JAXJob{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
},
Spec: trainingoperator.JAXJobSpec{
JAXReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
trainingoperator.JAXJobReplicaTypeWorker: {
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "jax",
Image: "",
},
},
},
},
},
},
},
},
wantErr: field.ErrorList{
field.Required(jaxReplicaSpecPath.
Key(string(trainingoperator.JAXJobReplicaTypeWorker)).
Child("template").
Child("spec").
Child("containers").
Index(0).
Child("image"), ""),
},
},
"jaxJob default container name doesn't present": {
jaxJob: &trainingoperator.JAXJob{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
},
Spec: trainingoperator.JAXJobSpec{
JAXReplicaSpecs: map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec{
trainingoperator.JAXJobReplicaTypeWorker: {
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "",
Image: "",
},
},
},
},
},
},
},
},
wantErr: field.ErrorList{
field.Required(jaxReplicaSpecPath.
Key(string(trainingoperator.JAXJobReplicaTypeWorker)).
Child("template").
Child("spec").
Child("containers"), ""),
},
},
"replicaSpec is nil": {
jaxJob: &trainingoperator.JAXJob{
ObjectMeta: metav1.ObjectMeta{
Name: "test",
},
Spec: trainingoperator.JAXJobSpec{
JAXReplicaSpecs: nil,
},
},
wantErr: field.ErrorList{
field.Required(jaxReplicaSpecPath, ""),
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
got := validateJAXJob(tc.jaxJob)
if diff := cmp.Diff(tc.wantErr, got, cmpopts.IgnoreFields(field.Error{}, "Detail", "BadValue")); len(diff) != 0 {
t.Errorf("Unexpected error (-want,+got):\n%s", diff)
}
})
}
}
2 changes: 1 addition & 1 deletion sdk/python/kubeflow/training/constants/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@
JAXJOB_MODEL = "KubeflowOrgV1JAXJob"
JAXJOB_PLURAL = "jaxjobs"
JAXJOB_CONTAINER = "jax"
JAXJOB_REPLICA_TYPES = (REPLICA_TYPE_WORKER.lower())
JAXJOB_REPLICA_TYPES = REPLICA_TYPE_WORKER.lower()

# Dictionary to get plural, model, and container for each Job kind.
JOB_PARAMETERS = {
Expand Down
20 changes: 13 additions & 7 deletions sdk/python/test/e2e/test_e2e_jaxjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,16 @@ def generate_jaxjob(
)


# def generate_container() -> V1Container:
# return V1Container(
# name=CONTAINER_NAME,
# image="docker.io/kubeflow/jaxgloo:latest",
# args=[],
# resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}),
# )
def generate_container() -> V1Container:
return V1Container(
name=CONTAINER_NAME,
image="docker.io/sandipanify/jaxgoogle:latest",
command=["python", "train.py"],
args=[
"--num_processes=2",
"--job_name=example-job",
"--sub_domain=training-operator",
"--cooordinator_port=6666",
],
resources=V1ResourceRequirements(limits={"memory": "2Gi", "cpu": "0.8"}),
)

0 comments on commit 212f8bf

Please sign in to comment.