Skip to content

Commit

Permalink
Add envvar tests
Browse files Browse the repository at this point in the history
Signed-off-by: Sandipan Panda <[email protected]>
  • Loading branch information
sandipanpanda committed Sep 19, 2024
1 parent a86c927 commit 9219e94
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/jax/cpu-demo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ def _main(argv):
)

print(
f"JAX process {jax.process_index()}/{jax.process_count()} initialized on "
f"JAX process {jax.process_index()}/{jax.process_count() - 1} initialized on "
f"{socket.gethostname()}"
)
print(f"JAX global devices:{jax.devices()}")
print(f"JAX local devices:{jax.local_devices()}")

print(jax.device_count())
print(jax.local_device_count())
print(f"JAX device count:{jax.device_count()}")
print(f"JAX local device count:{jax.local_device_count()}")

xs = jax.numpy.ones(jax.local_device_count())
print(jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs))
Expand Down
142 changes: 142 additions & 0 deletions pkg/controller.v1/jax/envvar_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package jax

import (
"strconv"
"testing"

"github.com/google/go-cmp/cmp"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/utils/ptr"

kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
)

func TestSetPodEnv(t *testing.T) {
// Define some helper variables/constants for the test cases
validPort := int32(6666)
validIndex := "0"
invalidIndex := "invalid"

// Define a valid JAXJob structure
validJAXJob := &kubeflowv1.JAXJob{
ObjectMeta: metav1.ObjectMeta{Name: "test-jaxjob"},
Spec: kubeflowv1.JAXJobSpec{
JAXReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
kubeflowv1.JAXJobReplicaTypeWorker: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{
Name: "jax",
Image: "docker.io/sandipanify/jaxgoogle:latest",
Ports: []corev1.ContainerPort{{
Name: kubeflowv1.JAXJobDefaultPortName,
ContainerPort: validPort,
}},
ImagePullPolicy: corev1.PullAlways,
Command: []string{
"python",
"train.py",
},
}},
},
},
},
},
},
}

// Define the expected environment variables to be set
expectedEnvVars := []corev1.EnvVar{
{Name: "PYTHONUNBUFFERED", Value: "1"},
{Name: "COORDINATOR_PORT", Value: strconv.Itoa(int(validPort))},
{Name: "COORDINATOR_ADDRESS", Value: "test-jaxjob-worker-0"},
{Name: "NUM_PROCESSES", Value: "1"},
{Name: "PROCESS_ID", Value: validIndex},
}

// Define the test cases
cases := map[string]struct {
jaxJob *kubeflowv1.JAXJob
podTemplate *corev1.PodTemplateSpec
rtype kubeflowv1.ReplicaType
index string
wantPodEnvVars []corev1.EnvVar
wantErr bool
}{
"successful environment variable setup": {
jaxJob: validJAXJob,
podTemplate: &corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{}},
},
},
rtype: kubeflowv1.JAXJobReplicaTypeWorker,
index: validIndex,
wantPodEnvVars: expectedEnvVars,
wantErr: false,
},
"invalid index for PROCESS_ID": {
jaxJob: validJAXJob,
podTemplate: &corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{}},
},
},
rtype: kubeflowv1.JAXJobReplicaTypeWorker,
index: invalidIndex,
wantErr: true,
},
"missing container port in JAXJob": {
jaxJob: &kubeflowv1.JAXJob{
Spec: kubeflowv1.JAXJobSpec{
JAXReplicaSpecs: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
kubeflowv1.JAXJobReplicaTypeWorker: {
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{
Name: "jax",
Ports: []corev1.ContainerPort{
{Name: "wrong-port", ContainerPort: 0},
},
}},
},
},
},
},
},
},
podTemplate: &corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{}},
},
},
rtype: kubeflowv1.JAXJobReplicaTypeWorker,
index: validIndex,
wantErr: true,
},
}

// Execute the test cases
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
err := setPodEnv(tc.jaxJob, tc.podTemplate, string(tc.rtype), tc.index)

// Check if an error was expected
if (err != nil) != tc.wantErr {
t.Errorf("setPodEnv() error = %v, wantErr %v", err, tc.wantErr)
}

// If no error was expected, verify the environment variables
if !tc.wantErr {
for i, container := range tc.podTemplate.Spec.Containers {
if diff := cmp.Diff(tc.wantPodEnvVars, container.Env); diff != "" {
t.Errorf("Unexpected env vars for container %d (-want,+got):\n%s", i, diff)
}
}
}
})
}
}

0 comments on commit 9219e94

Please sign in to comment.