diff --git a/docs/proposals/train_api_proposal.md b/docs/proposals/2003-train-api/README.md
similarity index 99%
rename from docs/proposals/train_api_proposal.md
rename to docs/proposals/2003-train-api/README.md
index a2b8b22a63..dd55ca6c02 100644
--- a/docs/proposals/train_api_proposal.md
+++ b/docs/proposals/2003-train-api/README.md
@@ -1,4 +1,4 @@
-**Train/Fine-tune API Proposal for LLMs**
+**KEP-2003: Train/Fine-tune API Proposal for LLMs**
 
 **Authors:**
 
diff --git a/docs/proposals/jax-integration.md b/docs/proposals/2145-jax-integration/README.md
similarity index 74%
rename from docs/proposals/jax-integration.md
rename to docs/proposals/2145-jax-integration/README.md
index 3fb169dc54..fffde73e33 100644
--- a/docs/proposals/jax-integration.md
+++ b/docs/proposals/2145-jax-integration/README.md
@@ -1,7 +1,9 @@
-# Kubeflow Enhancement Proposal: Integrate JAX with Kubeflow Training Operator for Distributed Training on Kubernetes
+# KEP-2145: Integrate JAX with Kubeflow Training Operator for Distributed Training on Kubernetes
 
+
 ## Table of Contents
 
+
 - [Summary](#summary)
 - [Motivation](#motivation)
 - [Goals](#goals)
@@ -62,15 +64,14 @@ As a DevOps engineer, I want to manage JAX distributed training jobs using the K
 - Extend the Training Operator Python SDK to simplify the creation and management of `JaxJob` resources.
 - Configure JAX to use the Gloo backend for CPU-based distributed training.
 
-| Environment Variable | JAX Parameter | Description | How to Obtain/Configure |
-|----------------------------|------------------------|-----------------------------------------------------------------------------------------------------------|-----------------------------------------------|
-| `JAX_COORDINATOR_ADDRESS` | `coordinator_address (str)` | the IP address of process 0 in your cluster, together with a port available on that process. Process 0 will start a JAX service exposed via that IP address and port, to which the other processes in the cluster will connect. | Set this in the coordinator pod spec and ensure it's the same for all worker pods. Example: `localhost:1234`. |
-| `JAX_NUM_PROCESSES` | `num_processes (int) ` | The number of processes in the cluster. | Define in both the coordinator and worker pod specs. Example: `2`. |
-| `JAX_PROCESS_ID` | `process_id (int)` | The ID number of the current process. Each process should have a unique ID, , in the range `[0 .. num_processes)`. | Set this in each pod spec. The coordinator is usually `0`, workers are `1, 2, ...`. |
-| `JAX_LOCAL_DEVICE_IDS` | `local_device_ids (int)` | Restricts the visible devices of the current process to `local_device_ids`. | Optional. Set in the pod spec if device visibility needs to be restricted. |
-| `JAX_INITIALIZATION_TIMEOUT`| `initialization_timeout (int)` | Time period (in seconds) for which connection will be retried. If the initialization takes more than the timeout specified, the initialization will error. Defaults to 300 secs i.e. 5 mins. | Optional. Can be set in the pod spec if a different timeout is needed. |
-| `JAX_COORDINATOR_BIND_ADDRESS` | `coordinator_bind_address (str)` | The IP address and port to which the JAX service on process 0 in your cluster will bind. By default, it will bind to all available interfaces using the same port as `coordinator_address`. | Optional. Can be set in the coordinator pod spec. Default binds to all available addresses. |
-
+| Environment Variable | JAX Parameter | Description | How to Obtain/Configure |
+| ------------------------------ | -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------- |
+| `JAX_COORDINATOR_ADDRESS` | `coordinator_address (str)` | the IP address of process 0 in your cluster, together with a port available on that process. Process 0 will start a JAX service exposed via that IP address and port, to which the other processes in the cluster will connect. | Set this in the coordinator pod spec and ensure it's the same for all worker pods. Example: `localhost:1234`. |
+| `JAX_NUM_PROCESSES` | `num_processes (int) ` | The number of processes in the cluster. | Define in both the coordinator and worker pod specs. Example: `2`. |
+| `JAX_PROCESS_ID` | `process_id (int)` | The ID number of the current process. Each process should have a unique ID, , in the range `[0 .. num_processes)`. | Set this in each pod spec. The coordinator is usually `0`, workers are `1, 2, ...`. |
+| `JAX_LOCAL_DEVICE_IDS` | `local_device_ids (int)` | Restricts the visible devices of the current process to `local_device_ids`. | Optional. Set in the pod spec if device visibility needs to be restricted. |
+| `JAX_INITIALIZATION_TIMEOUT` | `initialization_timeout (int)` | Time period (in seconds) for which connection will be retried. If the initialization takes more than the timeout specified, the initialization will error. Defaults to 300 secs i.e. 5 mins. | Optional. Can be set in the pod spec if a different timeout is needed. |
+| `JAX_COORDINATOR_BIND_ADDRESS` | `coordinator_bind_address (str)` | The IP address and port to which the JAX service on process 0 in your cluster will bind. By default, it will bind to all available interfaces using the same port as `coordinator_address`. | Optional. Can be set in the coordinator pod spec. Default binds to all available addresses. |
 
 #### Validations for JaxJob
 
@@ -230,19 +231,18 @@ metadata:
   name: jaxjob-worker-${job_id}
 spec:
   containers:
-  - image: ghcr.io/kubeflow/jax:latest
-    imagePullPolicy: IfNotPresent
-    name: worker
-    env:
-    - name: JAX_COORDINATOR_ADDRESS
-      value: '127.0.0.1:6666'
-    - name: JAX_NUM_PROCESSES
-      value: 1
-    - name: JAX_PROCESS_ID
-      value: 0
-    # process 0 is coordinator
+    - image: ghcr.io/kubeflow/jax:latest
+      imagePullPolicy: IfNotPresent
+      name: worker
+      env:
+        - name: JAX_COORDINATOR_ADDRESS
+          value: "127.0.0.1:6666"
+        - name: JAX_NUM_PROCESSES
+          value: 1
+        - name: JAX_PROCESS_ID
+          value: 0
+      # process 0 is coordinator
   restartPolicy: OnFailure
-
 ```
 
 ## Alternatives