-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #143 from ZaberKo/dev-fix2
Bug Fix & Distributed Training Improvement
- Loading branch information
Showing
99 changed files
with
1,121 additions
and
1,302 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,3 +24,4 @@ src/evox/algorithms/mo/_rm_meda.py | |
tests/test_.py | ||
tests/log.txt | ||
.ipynb_checkpoints/ | ||
/*.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,5 +6,5 @@ Workflows | |
:maxdepth: 1 | ||
|
||
standard | ||
distributed | ||
non_jit | ||
.. distributed | ||
.. non_jit |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Distribute the workflow | ||
|
||
EvoX provides two distributed workflow implementation, one is based on Ray, and the other one is based on jax.distribute. | ||
|
||
## RayDistributedWorkflow | ||
|
||
RayDistributedWorkflow is built upon Ray. It can be used on any ray cluster. The Ray cluster should be setup before running the EvoX program. | ||
|
||
### Setup Ray cluster | ||
|
||
Please refer to [Ray's official documentation](https://docs.ray.io/en/latest/cluster/getting-started.html) for guide on setting up an Ray cluster. | ||
|
||
Here is a simple way to setup the cluster locally. | ||
|
||
- On the head node | ||
```bash | ||
ray start --head | ||
``` | ||
- On worker nodes | ||
```bash | ||
ray start --address="<your head node's ip>:6379" | ||
``` | ||
|
||
If you only have 1 machine, but multiple devices, then there is nothing needs to be done. Ray will setup itself in this case. | ||
|
||
### Setup EvoX | ||
|
||
To scale the workflow using multiple machines through Ray, use the {class}`RayDistributedWorkflow <evox.workflows.RayDistributedWorkflow>` instead of StdWorkflow. | ||
|
||
First, import `workflows` from evox | ||
|
||
```python | ||
from evox import workflows | ||
``` | ||
|
||
then create your algorithm, problem, monitor object as usual. | ||
|
||
```python | ||
algorithm = ... | ||
problem = ... | ||
monitor = ... | ||
``` | ||
|
||
Now use `RayDistributedWorkflow` | ||
```python | ||
workflow = workflows.RayDistributedWorkflow( | ||
algorithm=algorithm, | ||
problem=problem, | ||
monitors=[monitor], | ||
num_workers=4, # the number of machines | ||
options={ # the options that passes to ray | ||
"num_gpus": 1 | ||
} | ||
) | ||
``` | ||
|
||
The `RayDistributedWorkflow` also uses the `workflow.step` function to execute iterations. However, under the hood, it employs a distinct approach that allows for the utilization of multiple devices across different machines. | ||
|
||
```{tip} | ||
It is recommanded that one set the environment variable `XLA_PYTHON_CLIENT_PREALLOCATE=false`. | ||
By default JAX will pre-allocate 80% of the device's memory. | ||
This variable disables the GPU memory preallocation, otherwise running multiple JAX processes may cause OOM. | ||
For more information, please refer to [JAX's documentation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html) on this matter. | ||
``` | ||
|
||
## StdWorkflow | ||
|
||
StdWorkflow is short for "Universal Workflow", | ||
which aims to use pure JAX to build a workflow that fits any requirement. | ||
Since `StdWorkflow` is written in pure JAX, it has less overhead and don't need any additional dependencies. | ||
### Setup EvoX | ||
Use `StdWorkflow` to create an workflow, | ||
and use `enable_distributed` and pass in the state to enable this feature. | ||
```python | ||
key = jax.random.PRNGKey(0) # a PRNGKey | ||
workflow = workflows.StdWorkflow( | ||
algorithm, | ||
problem, | ||
monitors=[monitor], | ||
) | ||
state = workflow.init(key) # init as usual | ||
# important: enable this feature | ||
state = workflow.enable_distributed(state) | ||
``` | ||
Then, at the start of your program, before any JAX function is called, do this: | ||
```python | ||
jax.distributed.initialize(coordinator_address=..., num_process=...,process_id=...) | ||
``` | ||
In this system, the `coordinator` serves as the primary or head node. The total number of participating processes is indicated by `num_process`. The process with `process_id=0` acts as the coordinator. | ||
From more information, please refer to [jax.distributed.initialize](https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html) and [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html). | ||
### Run in a cluster | ||
Unlike Ray, JAX's doesn't have the concept of cluster or scheduler. | ||
Instead, it offers tools for enabling distributed interactions among multiple JAX instances. JAX follows the SPMD (single program multiple data) paradigm. To initiate a distributed program in JAX, you simply need to run the same script on different machines. For instance, if your program is named `main.py`, you should execute the following command on all participating machines with different `process_id` argument in `jax.distributed.initialize`: | ||
```bash | ||
python main.py | ||
``` | ||
```{tip} | ||
To have `process_id` in the argument, one can use `argparse` to parse the argument from the commandline. | ||
For example: | ||
```python | ||
import argparse | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('process_id', type=int) | ||
args = parser.parse_args() | ||
jax.distributed.initialize( | ||
coordinator_address=..., | ||
num_processes=..., | ||
process_id=args.process_id, | ||
) | ||
``` | ||
Then call `python main.py 0` on the first machine, `python main 1` on the second machine and so on. | ||
``` | ||
### Run on a single machine | ||
In addition to distributed execution across multiple machines, `StdWorkflow` also supports running on a single machine with multiple GPUs. In this scenario, communication between different devices is facilitated by `nccl`, which is considerably more efficient than cross-machine communication. | ||
The setup process remains unchanged from the previous instructions mentioned above. However, since you are working with only a single machine, the subsequent step for multiple machines is no longer necessary: | ||
```python | ||
jax.distributed.initialize(coordinator_address=..., num_process=...,process_id=...) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,137 +1,91 @@ | ||
# Distribute the workflow | ||
# Distributed Training | ||
|
||
EvoX provides two distributed workflow implementation, one is based on Ray, and the other one is based on jax.distribute. | ||
## Parallel Model | ||
|
||
## RayDistributedWorkflow | ||
All states are replicated across all devices including population. Then, on every device, a sharded candidates are passed to `problem.evaluate()`, and the fitnesses are shared across all device (by `all_gather`). This ensures all devices share the same state data without explicit synchronization. In other word, this parallel model only accelerate the problem's evaluation part, and cannot reduce the memory consumption. We use it as our default distributed strategy, as it offers EC algorithms maximum flexibility. | ||
|
||
RayDistributedWorkflow is built upon Ray. It can be used on any ray cluster. The Ray cluster should be setup before running the EvoX program. | ||
## Multiple devices on a single node | ||
|
||
### Setup Ray cluster | ||
|
||
Please refer to [Ray's official documentation](https://docs.ray.io/en/latest/cluster/getting-started.html) for guide on setting up an Ray cluster. | ||
|
||
Here is a simple way to setup the cluster locally. | ||
|
||
- On the head node | ||
```bash | ||
ray start --head | ||
``` | ||
- On worker nodes | ||
```bash | ||
ray start --address="<your head node's ip>:6379" | ||
``` | ||
|
||
If you only have 1 machine, but multiple devices, then there is nothing needs to be done. Ray will setup itself in this case. | ||
|
||
### Setup EvoX | ||
|
||
To scale the workflow using multiple machines through Ray, use the {class}`RayDistributedWorkflow <evox.workflows.RayDistributedWorkflow>` instead of StdWorkflow. | ||
|
||
First, import `workflows` from evox | ||
Example: | ||
|
||
```python | ||
from evox import workflows | ||
``` | ||
|
||
then create your algorithm, problem, monitor object as usual. | ||
|
||
```python | ||
algorithm = ... | ||
problem = ... | ||
monitor = ... | ||
``` | ||
|
||
Now use `RayDistributedWorkflow` | ||
```python | ||
workflow = workflows.RayDistributedWorkflow( | ||
algorithm=algorithm, | ||
problem=problem, | ||
monitors=[monitor], | ||
num_workers=4, # the number of machines | ||
options={ # the options that passes to ray | ||
"num_gpus": 1 | ||
} | ||
import jax | ||
import jax.tree_util as jtu | ||
from evox import algorithms, problems, workflows | ||
from evox.core.distributed import tree_unpmap | ||
|
||
cso = algorithms.CSO( | ||
lb=jnp.full(shape=(2,), fill_value=-32), | ||
ub=jnp.full(shape=(2,), fill_value=32), | ||
pop_size=16*4, | ||
) | ||
``` | ||
|
||
The `RayDistributedWorkflow` also uses the `workflow.step` function to execute iterations. However, under the hood, it employs a distinct approach that allows for the utilization of multiple devices across different machines. | ||
|
||
```{tip} | ||
It is recommanded that one set the environment variable `XLA_PYTHON_CLIENT_PREALLOCATE=false`. | ||
By default JAX will pre-allocate 80% of the device's memory. | ||
This variable disables the GPU memory preallocation, otherwise running multiple JAX processes may cause OOM. | ||
For more information, please refer to [JAX's documentation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html) on this matter. | ||
``` | ||
|
||
## StdWorkflow | ||
ackley = problems.numerical.Ackley() | ||
workflow = workflows.StdWorkflow(cso, ackley) | ||
|
||
StdWorkflow is short for "Universal Workflow", | ||
which aims to use pure JAX to build a workflow that fits any requirement. | ||
Since `StdWorkflow` is written in pure JAX, it has less overhead and don't need any additional dependencies. | ||
key = random.PRNGKey(42) | ||
with jax.default_device(devices[0]): | ||
state = workflow.init(key) | ||
|
||
### Setup EvoX | ||
state = workflow.enable_multi_devices(state, devices) | ||
|
||
Use `StdWorkflow` to create an workflow, | ||
and use `enable_distributed` and pass in the state to enable this feature. | ||
```python | ||
key = jax.random.PRNGKey(0) # a PRNGKey | ||
workflow = workflows.StdWorkflow( | ||
algorithm, | ||
problem, | ||
monitors=[monitor], | ||
) | ||
state = workflow.init(key) # init as usual | ||
# important: enable this feature | ||
state = workflow.enable_distributed(state) | ||
for i in range(100): | ||
train_info, state = workflow.step(state) | ||
train_info = tree_unpmap(train_info, workflow.pmap_axis_name) | ||
print(train_info['transformed_fitness']) | ||
``` | ||
|
||
Then, at the start of your program, before any JAX function is called, do this: | ||
## Multiple devices on multiple nodes | ||
|
||
Example of script `dist_train.py` | ||
|
||
```python | ||
jax.distributed.initialize(coordinator_address=..., num_process=...,process_id=...) | ||
``` | ||
|
||
In this system, the `coordinator` serves as the primary or head node. The total number of participating processes is indicated by `num_process`. The process with `process_id=0` acts as the coordinator. | ||
import argparse | ||
import jax | ||
|
||
From more information, please refer to [jax.distributed.initialize](https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html) and [Using JAX in multi-host and multi-process environments](https://jax.readthedocs.io/en/latest/multi_process.html). | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--addr', type=str, default='127.0.0.1:37233') | ||
parser.add_argument('-n', type=int, required=True) | ||
parser.add_argument('-i', type=int, required=True) | ||
args = parser.parse_args() | ||
|
||
### Run in a cluster | ||
jax.distributed.initialize(coordinator_address=args.addr, num_processes=args.n, process_id=args.i, initialization_timeout=30) | ||
|
||
Unlike Ray, JAX's doesn't have the concept of cluster or scheduler. | ||
Instead, it offers tools for enabling distributed interactions among multiple JAX instances. JAX follows the SPMD (single program multiple data) paradigm. To initiate a distributed program in JAX, you simply need to run the same script on different machines. For instance, if your program is named `main.py`, you should execute the following command on all participating machines with different `process_id` argument in `jax.distributed.initialize`: | ||
total_devices = jax.devices() | ||
devices = jax.local_devices() | ||
|
||
```bash | ||
python main.py | ||
``` | ||
print(f'total_devices: {total_devices}') | ||
print(f'devices: {devices}') | ||
|
||
```{tip} | ||
To have `process_id` in the argument, one can use `argparse` to parse the argument from the commandline. | ||
For example: | ||
from evox import algorithms, problems, workflows | ||
from evox.core.distributed import tree_unpmap | ||
|
||
```python | ||
import argparse | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('process_id', type=int) | ||
args = parser.parse_args() | ||
jax.distributed.initialize( | ||
coordinator_address=..., | ||
num_processes=..., | ||
process_id=args.process_id, | ||
cso = algorithms.CSO( | ||
lb=jnp.full(shape=(2,), fill_value=-32), | ||
ub=jnp.full(shape=(2,), fill_value=32), | ||
pop_size=16*30, | ||
) | ||
``` | ||
ackley = problems.numerical.Ackley() | ||
workflow = workflows.StdWorkflow(cso, ackley) | ||
|
||
Then call `python main.py 0` on the first machine, `python main 1` on the second machine and so on. | ||
key = jax.random.PRNGKey(42) | ||
state = workflow.init(key) | ||
state = workflow.enable_multi_devices(state, devices) | ||
|
||
``` | ||
for i in range(10): | ||
train_info, state = workflow.step(state) | ||
train_info = tree_unpmap(train_info, workflow.pmap_axis_name) | ||
print(train_info['transformed_fitness']) | ||
|
||
### Run on a single machine | ||
jax.distributed.shutdown() | ||
``` | ||
|
||
In addition to distributed execution across multiple machines, `StdWorkflow` also supports running on a single machine with multiple GPUs. In this scenario, communication between different devices is facilitated by `nccl`, which is considerably more efficient than cross-machine communication. | ||
Run script on each node: | ||
|
||
The setup process remains unchanged from the previous instructions mentioned above. However, since you are working with only a single machine, the subsequent step for multiple machines is no longer necessary: | ||
```shell | ||
# node1 with ip 10.233.96.181 | ||
python dist_train.py --addr 10.233.96.181:35429 -n 2 -i 0 | ||
|
||
```python | ||
jax.distributed.initialize(coordinator_address=..., num_process=...,process_id=...) | ||
``` | ||
# node2 | ||
python dist_train.py --addr 10.233.96.181:35429 -n 2 -i 1 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,11 @@ | ||
from .core.workflow import Workflow | ||
from .core.algorithm import Algorithm | ||
from .core.module import * | ||
from .core.algorithm import Algorithm, has_init_ask, has_init_tell | ||
from .core.module import use_state, jit_class, jit_method, Stateful | ||
from .core.problem import Problem | ||
from .core.state import State | ||
from .core.state import State, get_state_sharding | ||
from .core.monitor import Monitor | ||
from .core.pytree_dataclass import dataclass, pytree_field, PyTreeNode | ||
|
||
from . import algorithms, monitors, operators, workflows, problems, utils | ||
# from .core.distributed import POP_AXIS_NAME, ShardingType | ||
|
||
# from . import algorithms, monitors, operators, workflows, problems, utils |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from .clustered_algorithm import ClusterdAlgorithm, RandomMaskAlgorithm | ||
from .tree_algorithm import TreeAlgorithm | ||
from .coevolution import VectorizedCoevolution, Coevolution | ||
from .coevolution import VectorizedCoevolution, Coevolution |
Oops, something went wrong.