diff --git a/docs/proposals/train_api_proposal.md b/docs/proposals/2003-train-api/README.md
similarity index 99%
rename from docs/proposals/train_api_proposal.md
rename to docs/proposals/2003-train-api/README.md
index a2b8b22a63..dd55ca6c02 100644
--- a/docs/proposals/train_api_proposal.md
+++ b/docs/proposals/2003-train-api/README.md
@@ -1,4 +1,4 @@
-**Train/Fine-tune API Proposal for LLMs**
+**KEP-2003: Train/Fine-tune API Proposal for LLMs**
**Authors:**
diff --git a/docs/proposals/jax-integration.md b/docs/proposals/2145-jax-integration/README.md
similarity index 74%
rename from docs/proposals/jax-integration.md
rename to docs/proposals/2145-jax-integration/README.md
index 3fb169dc54..fffde73e33 100644
--- a/docs/proposals/jax-integration.md
+++ b/docs/proposals/2145-jax-integration/README.md
@@ -1,7 +1,9 @@
-# Kubeflow Enhancement Proposal: Integrate JAX with Kubeflow Training Operator for Distributed Training on Kubernetes
+# KEP-2145: Integrate JAX with Kubeflow Training Operator for Distributed Training on Kubernetes
+
## Table of Contents
+
- [Summary](#summary)
- [Motivation](#motivation)
- [Goals](#goals)
@@ -62,15 +64,14 @@ As a DevOps engineer, I want to manage JAX distributed training jobs using the K
- Extend the Training Operator Python SDK to simplify the creation and management of `JaxJob` resources.
- Configure JAX to use the Gloo backend for CPU-based distributed training.
-| Environment Variable | JAX Parameter | Description | How to Obtain/Configure |
-|----------------------------|------------------------|-----------------------------------------------------------------------------------------------------------|-----------------------------------------------|
-| `JAX_COORDINATOR_ADDRESS` | `coordinator_address (str)` | the IP address of process 0 in your cluster, together with a port available on that process. Process 0 will start a JAX service exposed via that IP address and port, to which the other processes in the cluster will connect. | Set this in the coordinator pod spec and ensure it's the same for all worker pods. Example: `localhost:1234`. |
-| `JAX_NUM_PROCESSES` | `num_processes (int) ` | The number of processes in the cluster. | Define in both the coordinator and worker pod specs. Example: `2`. |
-| `JAX_PROCESS_ID` | `process_id (int)` | The ID number of the current process. Each process should have a unique ID, , in the range `[0 .. num_processes)`. | Set this in each pod spec. The coordinator is usually `0`, workers are `1, 2, ...`. |
-| `JAX_LOCAL_DEVICE_IDS` | `local_device_ids (int)` | Restricts the visible devices of the current process to `local_device_ids`. | Optional. Set in the pod spec if device visibility needs to be restricted. |
-| `JAX_INITIALIZATION_TIMEOUT`| `initialization_timeout (int)` | Time period (in seconds) for which connection will be retried. If the initialization takes more than the timeout specified, the initialization will error. Defaults to 300 secs i.e. 5 mins. | Optional. Can be set in the pod spec if a different timeout is needed. |
-| `JAX_COORDINATOR_BIND_ADDRESS` | `coordinator_bind_address (str)` | The IP address and port to which the JAX service on process 0 in your cluster will bind. By default, it will bind to all available interfaces using the same port as `coordinator_address`. | Optional. Can be set in the coordinator pod spec. Default binds to all available addresses. |
-
+| Environment Variable | JAX Parameter | Description | How to Obtain/Configure |
+| ------------------------------ | -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------- |
+| `JAX_COORDINATOR_ADDRESS` | `coordinator_address (str)` | The IP address of process 0 in your cluster, together with a port available on that process. Process 0 will start a JAX service exposed via that IP address and port, to which the other processes in the cluster will connect. | Set this in the coordinator pod spec and ensure it's the same for all worker pods. Example: `localhost:1234`. |
+| `JAX_NUM_PROCESSES` | `num_processes (int)` | The number of processes in the cluster. | Define in both the coordinator and worker pod specs. Example: `2`. |
+| `JAX_PROCESS_ID` | `process_id (int)` | The ID number of the current process. Each process should have a unique ID in the range `[0 .. num_processes)`. | Set this in each pod spec. The coordinator is usually `0`, workers are `1, 2, ...`. |
+| `JAX_LOCAL_DEVICE_IDS` | `local_device_ids (int)` | Restricts the visible devices of the current process to `local_device_ids`. | Optional. Set in the pod spec if device visibility needs to be restricted. |
+| `JAX_INITIALIZATION_TIMEOUT` | `initialization_timeout (int)` | Time period (in seconds) for which the connection will be retried. If initialization takes longer than the specified timeout, it fails with an error. Defaults to 300 seconds (5 minutes). | Optional. Can be set in the pod spec if a different timeout is needed. |
+| `JAX_COORDINATOR_BIND_ADDRESS` | `coordinator_bind_address (str)` | The IP address and port to which the JAX service on process 0 in your cluster will bind. By default, it will bind to all available interfaces using the same port as `coordinator_address`. | Optional. Can be set in the coordinator pod spec. Default binds to all available addresses. |
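
To illustrate the mapping in the table above, the following is a minimal sketch of how a training container might translate these environment variables into arguments for `jax.distributed.initialize()`. It is not part of the Training Operator: the helpers `jax_init_kwargs_from_env` and `initialize_from_env` are hypothetical names, and the comma-separated format for `JAX_LOCAL_DEVICE_IDS` is an assumption made for this example.

```python
import os
from typing import Any, Dict, Mapping, Optional


def jax_init_kwargs_from_env(
    env: Optional[Mapping[str, str]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: build keyword arguments for
    jax.distributed.initialize() from the JAX_* variables in the table."""
    env = os.environ if env is None else env

    # Required variables; a missing one raises KeyError.
    kwargs: Dict[str, Any] = {
        "coordinator_address": env["JAX_COORDINATOR_ADDRESS"],
        "num_processes": int(env["JAX_NUM_PROCESSES"]),
        "process_id": int(env["JAX_PROCESS_ID"]),
    }

    # Each process ID must fall in the range [0, num_processes).
    if not 0 <= kwargs["process_id"] < kwargs["num_processes"]:
        raise ValueError(
            f"JAX_PROCESS_ID={kwargs['process_id']} is outside "
            f"[0, {kwargs['num_processes']})"
        )

    # Optional variables are passed through only when they are set.
    if "JAX_LOCAL_DEVICE_IDS" in env:
        # Assumption: device IDs are given as a comma-separated list, e.g. "0,1".
        kwargs["local_device_ids"] = [
            int(part) for part in env["JAX_LOCAL_DEVICE_IDS"].split(",")
        ]
    if "JAX_INITIALIZATION_TIMEOUT" in env:
        kwargs["initialization_timeout"] = int(env["JAX_INITIALIZATION_TIMEOUT"])
    if "JAX_COORDINATOR_BIND_ADDRESS" in env:
        kwargs["coordinator_bind_address"] = env["JAX_COORDINATOR_BIND_ADDRESS"]

    return kwargs


def initialize_from_env() -> None:
    """Call this at the start of the training script inside the container."""
    import jax  # imported lazily so the helper above works without JAX installed

    jax.distributed.initialize(**jax_init_kwargs_from_env())
```

Keeping the parsing separate from the `jax.distributed.initialize()` call makes the range check on `JAX_PROCESS_ID` easy to test without a running cluster.
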
#### Validations for JaxJob
@@ -230,19 +231,18 @@ metadata:
name: jaxjob-worker-${job_id}
spec:
containers:
- - image: ghcr.io/kubeflow/jax:latest
- imagePullPolicy: IfNotPresent
- name: worker
- env:
- - name: JAX_COORDINATOR_ADDRESS
- value: '127.0.0.1:6666'
- - name: JAX_NUM_PROCESSES
- value: 1
- - name: JAX_PROCESS_ID
- value: 0
- # process 0 is coordinator
+ - image: ghcr.io/kubeflow/jax:latest
+ imagePullPolicy: IfNotPresent
+ name: worker
+ env:
+ - name: JAX_COORDINATOR_ADDRESS
+ value: "127.0.0.1:6666"
+ - name: JAX_NUM_PROCESSES
+ value: "1"
+ - name: JAX_PROCESS_ID
+ value: "0"
+ # process 0 is coordinator
restartPolicy: OnFailure
-
```
## Alternatives