Skip to content

Commit

Permalink
address reviews 3
Browse files Browse the repository at this point in the history
Signed-off-by: Sandipan Panda <[email protected]>
  • Loading branch information
sandipanpanda committed Sep 10, 2024
1 parent c75c9be commit a3303b7
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 58 deletions.
40 changes: 5 additions & 35 deletions examples/jax/cpu-demo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,20 @@

import os
import socket
import time

import jax
from absl import app, flags
from absl import app

jax.config.update("jax_cpu_collectives_implementation", "gloo")

flags.DEFINE_integer("num_processes", 1, "Number of processes")
flags.DEFINE_string("job_name", None, "Job name")
flags.DEFINE_string("sub_domain", None, "Service sub domain")
flags.DEFINE_string("coordinator_port", None, "Port the coordinator listens on")
flags.mark_flag_as_required("job_name")
flags.mark_flag_as_required("sub_domain")
flags.mark_flag_as_required("coordinator_port")

FLAGS = flags.FLAGS


def _get_coordinator_ip_address(job_name, sub_domain):
    """Resolve the coordinator's host name to an IP address, with retries.

    The coordinator host name is built as ``<job_name>-0.<sub_domain>`` and
    looked up with ``socket.gethostbyname``. A failed lookup is retried after
    a one-second pause, for at most 120 attempts in total.

    Args:
        job_name: Job name used as the host-name prefix.
        sub_domain: Service sub domain appended after ``<job_name>-0.``.

    Returns:
        The resolved IP address as a string.

    Raises:
        socket.gaierror: If the name still cannot be resolved on the final
            attempt.
    """
    # Build the name from the arguments (not the global FLAGS) so the helper
    # honours its own signature; the existing caller passes the same values.
    coordinator_fqdn = f"{job_name}-0.{sub_domain}"
    print(f"Coordinator host name: {coordinator_fqdn}")

    max_attempts = 120
    for attempt in range(1, max_attempts + 1):
        try:
            coordinator_ipaddress = socket.gethostbyname(coordinator_fqdn)
        except socket.gaierror:
            if attempt == max_attempts:
                # Out of retries: surface the resolver error instead of
                # falling through to an unbound local variable below.
                raise
            print(
                f"Failed to resolve: {coordinator_fqdn}. Trying again in a second ..."
            )
            # Pause only between attempts; the first lookup runs immediately.
            time.sleep(1)
        else:
            break

    print(f"Coordinator IP address: {coordinator_ipaddress}")

    return coordinator_ipaddress


def _main(argv):

process_id = int(os.getenv("PROCESS_ID"))
num_processes = FLAGS.num_processes
coordinator_address = _get_coordinator_ip_address(FLAGS.job_name, FLAGS.sub_domain)
coordinator_address = f"{coordinator_address}:{FLAGS.coordinator_port}"
num_processes = int(os.getenv("NUM_PROCESSES"))
coordinator_address = os.getenv("COORDINATOR_ADDRESS")
coordinator_port = int(os.getenv("COORDINATOR_PORT"))
coordinator_address = f"{coordinator_address}:{coordinator_port}"

jax.distributed.initialize(
coordinator_address=coordinator_address,
Expand Down
20 changes: 0 additions & 20 deletions manifests/base/webhook/manifests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -104,23 +104,3 @@ webhooks:
resources:
- xgboostjobs
sideEffects: None
- admissionReviewVersions:
- v1
clientConfig:
service:
name: webhook-service
namespace: system
path: /validate-kubeflow-org-v1-jaxjob
failurePolicy: Fail
name: validator.jaxjob.training-operator.kubeflow.org
rules:
- apiGroups:
- kubeflow.org
apiVersions:
- v1
operations:
- CREATE
- UPDATE
resources:
- jaxjobs
sideEffects: None
6 changes: 3 additions & 3 deletions pkg/controller.v1/jax/jaxjob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,13 @@ var _ = Describe("JAXJob controller", func() {
Type: kubeflowv1.JobCreated,
Status: corev1.ConditionTrue,
Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobCreatedReason),
Message: fmt.Sprintf("JAXJob %s is created.", ns.Name+"/"+name),
Message: fmt.Sprintf("JAXJob %s is created.", name),
},
{
Type: kubeflowv1.JobRunning,
Status: corev1.ConditionTrue,
Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobRunningReason),
Message: fmt.Sprintf("JAXJob %s is running.", name),
Message: fmt.Sprintf("JAXJob %s/%s is running.", ns.Name, name),
},
}, testutil.IgnoreJobConditionsTimes))

Expand Down Expand Up @@ -300,7 +300,7 @@ var _ = Describe("JAXJob controller", func() {
Type: kubeflowv1.JobRunning,
Status: corev1.ConditionTrue,
Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobRunningReason),
Message: fmt.Sprintf("JAXJob %s is running.", name),
Message: fmt.Sprintf("JAXJob %s/%s is running.", ns.Name, name),
},
}, testutil.IgnoreJobConditionsTimes))

Expand Down

0 comments on commit a3303b7

Please sign in to comment.