diff --git a/examples/jax/cpu-demo/train.py b/examples/jax/cpu-demo/train.py index 5dd7c9da1d..0383374555 100644 --- a/examples/jax/cpu-demo/train.py +++ b/examples/jax/cpu-demo/train.py @@ -4,50 +4,20 @@ import os import socket -import time import jax -from absl import app, flags +from absl import app jax.config.update("jax_cpu_collectives_implementation", "gloo") -flags.DEFINE_integer("num_processes", 1, "Number of processes") -flags.DEFINE_string("job_name", None, "Job name") -flags.DEFINE_string("sub_domain", None, "Service sub domain") -flags.DEFINE_string("coordinator_port", None, "Port the coordinator listens on") -flags.mark_flag_as_required("job_name") -flags.mark_flag_as_required("sub_domain") -flags.mark_flag_as_required("coordinator_port") - -FLAGS = flags.FLAGS - - -def _get_coordinator_ip_address(job_name, sub_domain): - coordinator_fqdn = f"{FLAGS.job_name}-0.{FLAGS.sub_domain}" - print(f"Coordinator host name: {coordinator_fqdn}") - - for retry_attempt in range(120): - try: - time.sleep(1) - coordinator_ipaddress = socket.gethostbyname(coordinator_fqdn) - except socket.gaierror: - print( - f"Failed to resolve: {coordinator_fqdn}. Trying again in a second ..." - ) - else: - break - - print(f"Coordinator IP address: {coordinator_ipaddress}") - - return coordinator_ipaddress - def _main(argv): process_id = int(os.getenv("PROCESS_ID")) - num_processes = FLAGS.num_processes - coordinator_address = _get_coordinator_ip_address(FLAGS.job_name, FLAGS.sub_domain) - coordinator_address = f"{coordinator_address}:{FLAGS.coordinator_port}" + num_processes = int(os.getenv("NUM_PROCESSES")) + coordinator_address = os.getenv("COORDINATOR_ADDRESS") + coordinator_port = int(os.getenv("COORDINATOR_PORT")) + coordinator_address = f"{coordinator_address}:{coordinator_port}" jax.distributed.initialize( coordinator_address=coordinator_address, diff --git a/manifests/base/webhook/manifests.yaml b/manifests/base/webhook/manifests.yaml index 2e5d5fb7dd..2c381d0cd1 100644 --- a/manifests/base/webhook/manifests.yaml +++ b/manifests/base/webhook/manifests.yaml @@ -104,23 +104,3 @@ webhooks: resources: - xgboostjobs sideEffects: None -- admissionReviewVersions: - - v1 - clientConfig: - service: - name: webhook-service - namespace: system - path: /validate-kubeflow-org-v1-jaxjob - failurePolicy: Fail - name: validator.jaxjob.training-operator.kubeflow.org - rules: - - apiGroups: - - kubeflow.org - apiVersions: - - v1 - operations: - - CREATE - - UPDATE - resources: - - jaxjobs - sideEffects: None diff --git a/pkg/controller.v1/jax/jaxjob_controller_test.go b/pkg/controller.v1/jax/jaxjob_controller_test.go index 3de5c00c85..7e7b8280ad 100644 --- a/pkg/controller.v1/jax/jaxjob_controller_test.go +++ b/pkg/controller.v1/jax/jaxjob_controller_test.go @@ -188,13 +188,13 @@ var _ = Describe("JAXJob controller", func() { Type: kubeflowv1.JobCreated, Status: corev1.ConditionTrue, Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobCreatedReason), - Message: fmt.Sprintf("JAXJob %s is created.", ns.Name+"/"+name), + Message: fmt.Sprintf("JAXJob %s is created.", name), }, { Type: kubeflowv1.JobRunning, Status: corev1.ConditionTrue, Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobRunningReason), - Message: fmt.Sprintf("JAXJob %s is running.", name), + Message: fmt.Sprintf("JAXJob %s/%s is running.", ns.Name, name), }, }, testutil.IgnoreJobConditionsTimes)) @@ -300,7 +300,7 @@ var _ = Describe("JAXJob controller", func() { Type: kubeflowv1.JobRunning, Status: corev1.ConditionTrue, Reason: commonutil.NewReason(kubeflowv1.JAXJobKind, commonutil.JobRunningReason), - Message: fmt.Sprintf("JAXJob %s is running.", name), + Message: fmt.Sprintf("JAXJob %s/%s is running.", ns.Name, name), }, }, testutil.IgnoreJobConditionsTimes))