diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 9b38d936..558b2537 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: ["3.10", "3.11"] steps: - uses: actions/checkout@v3 @@ -30,20 +30,11 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install build wheel setuptools pytest - - name: Build and install package Python 3.9 (no GPJax) - if: matrix.python-version == '3.9' - run: | - output=$(python -m build --wheel) - pip install dist/${output##* }[gymnasium,envpool,neuroevolution,distributed,test] - name: Build and install package Python 3.10 and above if: matrix.python-version == '3.10' || matrix.python-version == '3.11' run: | output=$(python -m build --wheel) pip install dist/${output##* }[full,test] - - name: Test with pytest Python 3.9 - if: matrix.python-version == '3.9' - run: | - pytest -k 'not test_im_moea and not test_gp' - name: Test with pytest Python 3.10 and above if: matrix.python-version == '3.10' || matrix.python-version == '3.11' run: | diff --git a/docs/source/api/problems/neuroevolution/supervised_learning/index.rst b/docs/source/api/problems/neuroevolution/supervised_learning/index.rst index 760e7193..5eaae71a 100644 --- a/docs/source/api/problems/neuroevolution/supervised_learning/index.rst +++ b/docs/source/api/problems/neuroevolution/supervised_learning/index.rst @@ -5,4 +5,4 @@ Supervised Learning .. toctree:: :maxdepth: 1 - torchvision + tfds diff --git a/docs/source/api/problems/neuroevolution/supervised_learning/tfds.rst b/docs/source/api/problems/neuroevolution/supervised_learning/tfds.rst new file mode 100644 index 00000000..9cfff5f4 --- /dev/null +++ b/docs/source/api/problems/neuroevolution/supervised_learning/tfds.rst @@ -0,0 +1,6 @@ +================== +Tensorflow Dataset +================== + +.. autoclass:: evox.problems.neuroevolution.TensorflowDataset + :members: diff --git a/docs/source/api/problems/neuroevolution/supervised_learning/torchvision.rst b/docs/source/api/problems/neuroevolution/supervised_learning/torchvision.rst deleted file mode 100644 index c95c5b50..00000000 --- a/docs/source/api/problems/neuroevolution/supervised_learning/torchvision.rst +++ /dev/null @@ -1,6 +0,0 @@ -=================== -Torchvision Dataset -=================== - -.. 
autoclass:: evox.problems.neuroevolution.supervised_learning.TorchvisionDataset - :members: diff --git a/docs/source/conf.py b/docs/source/conf.py index 3e23dd98..0fd0c3e3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -57,8 +57,8 @@ "envpool", "gymnasium", "ray", - "torch", - "torchvision", + "tensorflow_datasets", + "grain", "gpjax", ] diff --git a/pyproject.toml b/pyproject.toml index bef78f8d..75d9f256 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ test = [ "chex >= 0.1.0", "flax >= 0.5.0", "pytest >= 6.0.0", + "tensorflow >= 2.12.0", ] vis = [ @@ -55,8 +56,9 @@ gymnasium = ["gymnasium >= 0.29.0"] envpool = ["envpool >= 0.8.0"] neuroevolution = [ - "torch >= 1.0.0", - "torchvision >= 0.1.0", + "tensorflow-datasets >= 4.0.0", + "grain >= 0.1.0", + "brax >= 0.1.0", ] distributed = ["ray >= 2.0.0"] @@ -64,12 +66,13 @@ distributed = ["ray >= 2.0.0"] full = [ "gymnasium >= 0.29.0", "ray >= 2.0.0", - "torch >= 1.0.0", - "torchvision >= 0.1.0", "envpool >= 0.8.0", "gpjax >= 0.8.0", "plotly >= 5.0.0", "pandas >= 2.0.0", + "tensorflow-datasets >= 4.0.0", + "grain >= 0.1.0", + "brax >= 0.1.0", ] gp = ["gpjax >= 0.8.0"] diff --git a/requirements/test-requirements.txt b/requirements/test-requirements.txt index 00bbedf6..c3ca4c99 100644 --- a/requirements/test-requirements.txt +++ b/requirements/test-requirements.txt @@ -5,7 +5,7 @@ flax >= 0.5.0 pytest >= 6.0.0 gymnasium >= 0.29.0 ray >= 2.0.0 -torch >= 1.0.0 -torchvision >= 0.1.0 +tensorflow-datasets >= 4.0.0 +grain >= 0.1.0 envpool >= 0.8.0 gpjax >= 0.8.0 diff --git a/src/evox/problems/neuroevolution/reinforcement_learning/env_pool.py b/src/evox/problems/neuroevolution/reinforcement_learning/env_pool.py index 5dd0518c..99d220f7 100644 --- a/src/evox/problems/neuroevolution/reinforcement_learning/env_pool.py +++ b/src/evox/problems/neuroevolution/reinforcement_learning/env_pool.py @@ -5,32 +5,10 @@ import jax.numpy as jnp import numpy as np from jax import jit, random, vmap, lax -from jax.tree_util import tree_map from jax.experimental import io_callback from evox import Problem, State, jit_class - - -def _x32_func_call(func): - def inner_func(*args, **kwargs): - return _to_x32_if_needed(func(*args, **kwargs)) - - return inner_func - - -def _to_x32_if_needed(values): - if jax.config.jax_enable_x64: - # we have 64-bit enabled, so nothing to do - return values - - def to_x32(value): - if value.dtype == np.float64: - return value.astype(np.float32) - elif value.dtype == np.int64: - return value.astype(np.int32) - else: - return value - return tree_map(to_x32, values) +from evox.utils.io import to_x32_if_needed, x32_func_call @jit_class @@ -60,8 +38,8 @@ def evaluate(self, state, pop): key, subkey = random.split(state.key) seed = random.randint(subkey, (1,), 0, jnp.iinfo(jnp.int32).max) io_callback(self.env.seed, None, seed) - obs, info = _to_x32_if_needed(self.env.reset(None)) - obs, info = io_callback(_x32_func_call(self.env.reset), (obs, info), None) + obs, info = to_x32_if_needed(self.env.reset(None)) + obs, info = io_callback(x32_func_call(self.env.reset), (obs, info), None) total_reward = 0 i = 0 @@ -75,11 +53,11 @@ def cond_func(loop_state): def step(loop_state): i, done, total_reward, obs = loop_state action = self.batch_policy(pop, obs) - obs, reward, terminated, truncated, info = _to_x32_if_needed( + obs, reward, terminated, truncated, info = to_x32_if_needed( self.env.step(np.zeros(action.shape)) ) obs, reward, terminated, truncated, info = io_callback( -
_x32_func_call(lambda action: self.env.step(np.copy(action))), + x32_func_call(lambda action: self.env.step(np.copy(action))), (obs, reward, terminated, truncated, info), action, ) diff --git a/src/evox/problems/neuroevolution/supervised_learning/__init__.py b/src/evox/problems/neuroevolution/supervised_learning/__init__.py index 1e7ee974..ec58c1c1 100644 --- a/src/evox/problems/neuroevolution/supervised_learning/__init__.py +++ b/src/evox/problems/neuroevolution/supervised_learning/__init__.py @@ -1,10 +1,10 @@ try: - # optional dependency: torchvision, optax - from .torchvision_dataset import TorchvisionDataset + # optional dependency: tensorflow-datasets, grain + from .tfds import TensorflowDataset except ImportError as e: original_error_msg = str(e) - def TorchvisionDataset(*args, **kwargs): + def TensorflowDataset(*args, **kwargs): raise ImportError( - f'TorchvisionDataset requires torchvision, optax but got "{original_error_msg}" when importing' + f'TensorflowDataset requires tensorflow-datasets, grain but got "{original_error_msg}" when importing' ) diff --git a/src/evox/problems/neuroevolution/supervised_learning/tfds.py b/src/evox/problems/neuroevolution/supervised_learning/tfds.py new file mode 100644 index 00000000..dab3eec9 --- /dev/null +++ b/src/evox/problems/neuroevolution/supervised_learning/tfds.py @@ -0,0 +1,135 @@ +from dataclasses import field +from typing import Any, Callable, List, Optional + +import grain.python as pygrain +import jax +import jax.numpy as jnp +import tensorflow_datasets as tfds +from jax.tree_util import tree_map + +from evox import Problem, Static, dataclass, jit_class +from evox.utils.io import x32_func_call + + +def get_dtype_shape(data): + def to_dtype_struct(x): + if hasattr(x, "shape") and hasattr(x, "dtype"): + return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype) + elif isinstance(x, int): + return jax.ShapeDtypeStruct(shape=(), dtype=jnp.int32) + elif isinstance(x, float): + return jax.ShapeDtypeStruct(shape=(), dtype=jnp.float32) + + return tree_map(to_dtype_struct, data) + + +@jit_class +@dataclass +class TensorflowDataset(Problem): + """Wrap a TensorFlow dataset as a problem. + + This problem directly depends on the packages `tensorflow-datasets` and `grain`. + Additionally, downloading a dataset for the first time requires `tensorflow` to be installed and an active internet connection. + If you want to avoid installing `tensorflow`, you can prepare the dataset beforehand in another environment that has `tensorflow` installed, + by running: + + .. code-block:: python + + import tensorflow_datasets as tfds + tfds.data_source("<dataset name>") + + and then copying the dataset to the target machine. + The default location is `~/tensorflow_datasets`, where `~/` is the user's home directory. + + Please note that the data is loaded under JAX's jit context, so the data must be a valid JAX data type, + namely JAX or NumPy arrays, or Python's int, float, list, and dict. + If the data contains other types like strings, you should convert them into arrays using the `operations` parameter. + + You can also download the dataset through a proxy server by setting the environment variables `TFDS_HTTP_PROXY` and `TFDS_HTTPS_PROXY`, + for the HTTP and HTTPS proxy respectively. + + The details of the available datasets can be found at https://www.tensorflow.org/datasets/catalog/overview + The details about operations/transformations can be found at https://github.com/google/grain/blob/main/docs/transformations.md + + Parameters + ---------- + dataset + The dataset name. + batch_size + The batch size.
+ loss_func + The loss function. + The function signature is loss(weights, data) -> loss_value, and it should be jittable. + The `weights` are the weights of the neural network, and `data` is one batch of data from TFDS, which is a dictionary. + split + Which split of the dataset to use. + Defaults to "train". + operations + The list of transformations to apply to the data. + Defaults to []. + After the transformations, we will always apply a batch operation to create a batch of data. + datadir + The directory to store the dataset. + Defaults to None, which means tensorflow-datasets will automatically determine the directory. + seed + The random seed used to seed the dataloader. + Given the same seed, the dataloader yields data in the same order. + Defaults to 0. + try_gcs + Whether to try to download the dataset from Google Cloud Storage. + Usually Google's storage server is faster than the dataset's original server. + """ + + dataset: Static[str] + batch_size: Static[int] + loss_func: Static[Callable] + split: Static[str] = field(default="train") + operations: Static[List[Any]] = field(default_factory=list) + datadir: Static[Optional[str]] = field(default=None) + seed: Static[int] = field(default=0) + try_gcs: Static[bool] = field(default=True) + iterator: Static[pygrain.PyGrainDatasetIterator] = field(init=False) + data_shape_dtypes: Static[Any] = field(init=False) + + def __post_init__(self): + if self.datadir is None: + data_source = tfds.data_source( + self.dataset, try_gcs=self.try_gcs, split=self.split + ) + else: + data_source = tfds.data_source( + self.dataset, + data_dir=self.datadir, + try_gcs=self.try_gcs, + split=self.split, + ) + + sampler = pygrain.IndexSampler( + num_records=len(data_source), + shard_options=pygrain.NoSharding(), + shuffle=True, + seed=self.seed, + ) + + operations = self.operations + [ + pygrain.Batch(batch_size=self.batch_size, drop_remainder=True) + ] + + loader = pygrain.DataLoader( + data_source=data_source, + operations=operations, + sampler=sampler, + worker_count=0, + ) + object.__setattr__(self, "iterator", iter(loader)) + data_shape_dtypes = get_dtype_shape(self._next_data()) + object.__setattr__(self, "data_shape_dtypes", data_shape_dtypes) + + @x32_func_call + def _next_data(self): + return next(self.iterator) + + def evaluate(self, state, pop): + data = jax.experimental.io_callback(self._next_data, self.data_shape_dtypes) + loss = jax.vmap(self.loss_func, in_axes=(0, None))(pop, data) + return loss, state diff --git a/src/evox/problems/neuroevolution/supervised_learning/torchvision_dataset.py b/src/evox/problems/neuroevolution/supervised_learning/torchvision_dataset.py deleted file mode 100644 index 0417b171..00000000 --- a/src/evox/problems/neuroevolution/supervised_learning/torchvision_dataset.py +++ /dev/null @@ -1,312 +0,0 @@ -import warnings -from functools import partial -from typing import Callable, NamedTuple, Optional, Union - -import evox -import jax -import jax.numpy as jnp -import numpy as np -import optax -from evox import Problem, State, Stateful, jit_class, jit_method -from jax import jit, lax, vmap -from jax.tree_util import tree_leaves -from torch.utils.data import DataLoader, Dataset, Sampler, Subset, random_split -from torchvision import datasets - - -def np_collate_fn(batch: list): - data, labels = list(zip(*batch)) - return np.stack(data), np.array(labels) - - -def jnp_collate_fn(batch: list): - data, labels = list(zip(*batch)) - return jnp.stack(data), jnp.array(labels) - - -class InMemoryDataset: - def __init__(self, data,
labels): - self.data = data - self.labels = labels - - def __len__(self): - return self.labels.shape[0] - - def __getitem__(self, indices): - return self.data[indices], self.labels[indices] - - @classmethod - def from_pytorch_dataset(cls, dataset): - dataset = [dataset[i] for i in range(len(dataset))] - data, labels = list(zip(*dataset)) - data, labels = jnp.array(data), jnp.array(labels) - return cls(data, labels) - - -class DeterministicRandomSampler(Sampler): - def __init__(self, key, max_len): - self.key = key - self.max_len = max_len - - def reset(self, indices): - self.indices = indices - - def __iter__(self): - self.key, subkey = jax.random.split(self.key) - return iter(jax.random.permutation(subkey, self.max_len).tolist()) - - def __len__(self): - return self.max_len - - -@jit_class -class TorchvisionDataset(Problem): - def __init__( - self, - root: str, - forward_func: Callable, - batch_size: int, - num_passes: int = 1, - dataset_name: Optional[str] = None, - train_dataset: Optional[Dataset] = None, - test_dataset: Optional[Dataset] = None, - valid_percent: float = 0.2, - num_workers: int = 0, - in_memory: bool = False, - loss_func: Callable = optax.softmax_cross_entropy_with_integer_labels, - ): - self.num_passes = num_passes - - if batch_size % num_passes != 0: - self.batch_size = int(round(batch_size / num_passes)) - warn_msg = f"batch_size isn't evenly divisible by num_passes, the actual batch size will be rounded" - warnings.warn(warn_msg) - else: - self.batch_size = batch_size // num_passes - - self.forward_func = forward_func - self.loss_func = loss_func - self.valid_percent = valid_percent - self.num_workers = num_workers - self.in_memory = in_memory - self.collate_fn = np_collate_fn - - if train_dataset is not None and isinstance(train_dataset, Dataset): - if dataset_name is not None: - warnings.warn( - "When train_dataset and test_dataset are specified, dataset_name is ignored" - ) - - self.train_dataset = train_dataset - self.test_dataset = test_dataset - elif dataset_name == "mnist": - self.train_dataset = datasets.MNIST( - root, train=True, download=True, transform=np.array - ) - self.test_dataset = datasets.MNIST( - root, train=False, download=True, transform=np.array - ) - elif dataset_name == "cifar10": - self.train_dataset = datasets.CIFAR10( - root, train=True, download=True, transform=np.array - ) - self.test_dataset = datasets.CIFAR10( - root, train=False, download=True, transform=np.array - ) - elif dataset_name == "imagenet": - self.train_dataset = datasets.ImageNet( - root, train=True, download=True, transform=np.array - ) - self.test_dataset = datasets.ImageNet( - root, train=False, download=True, transform=np.array - ) - else: - raise ValueError(f"Not supported dataset: {dataset_name}") - - if in_memory: - if self.num_workers != 0: - warnings.warn("When in_memory is True, num_workers is ignored") - self.train_dataset = InMemoryDataset.from_pytorch_dataset( - self.train_dataset - ) - self.test_dataset = InMemoryDataset.from_pytorch_dataset(self.test_dataset) - - @jit_method - def _new_permutation(self, key): - num_batches = len(self.train_dataset) // self.batch_size - permutation = jax.random.permutation(key, len(self.train_dataset)) - permutation = permutation[: num_batches * self.batch_size].reshape( - (-1, self.batch_size) - ) - return permutation - - def _in_memory_setup(self, key): - key, subset_key, perm_key = jax.random.split(key, num=3) - indices = jax.random.permutation(subset_key, len(self.train_dataset)) - valid_len = int(len(self.train_dataset) 
* self.valid_percent) - train_len = len(self.train_dataset) - valid_len - - self.valid_dataset = InMemoryDataset(*self.train_dataset[indices[train_len:]]) - self.train_dataset = InMemoryDataset(*self.train_dataset[indices[:train_len]]) - - permutation = self._new_permutation(perm_key) - - return State(key=key, permutation=permutation, iter=0, mode=0, metric_func=0) - - def _setup(self, key): - key, subset_key, sampler_key = jax.random.split(key, num=3) - indices = jax.random.permutation(subset_key, len(self.train_dataset)) - valid_len = int(len(self.train_dataset) * self.valid_percent) - train_len = len(self.train_dataset) - valid_len - - self.valid_dataset = Subset(self.train_dataset, indices[train_len:].tolist()) - self.train_dataset = Subset(self.train_dataset, indices[:train_len].tolist()) - self.sampler = DeterministicRandomSampler(sampler_key, len(self.train_dataset)) - - self.train_dataloader = DataLoader( - self.train_dataset, - sampler=self.sampler, - batch_size=self.batch_size, - collate_fn=self.collate_fn, - num_workers=self.num_workers, - drop_last=True, - ) - - self.train_iter = iter(self.train_dataloader) - # to work with jit, mode and metric are integer type - # 0 - train, 1 - valid, 2 - test - return State(mode=0, metric=0) - - def setup(self, key): - if self.in_memory: - return self._in_memory_setup(key) - else: - return self._setup(key) - - def train(self, state, metric=0): - return state.update(mode=0, metric_func=metric) - - def valid(self, state, metric=0): - return state.update(mode=1, metric_func=metric) - - def test(self, state, metric=0): - return state.update(mode=2, metric_func=metric) - - @jit_method - def _evaluate_in_memory_train(self, state, batch_params): - def new_epoch(state): - key, subkey = jax.random.split(state.key) - permutation = self._new_permutation(subkey) - return state.update(key=key, permutation=permutation, iter=0) - - def batch_evaluate(i, state_and_acc): - state, accumulator = state_and_acc - - state = lax.cond( - state.iter >= state.permutation.shape[0], - new_epoch, - lambda x: x, # identity - state, - ) - data, labels = self.train_dataset[state.permutation[state.iter]] - losses = self._metric_func(state, data, labels, batch_params) - return state.update(iter=state.iter + 1), accumulator + losses - - pop_size = tree_leaves(batch_params)[0].shape[0] - if self.num_passes > 1: - state, total_loss = lax.fori_loop( - 0, self.num_passes, batch_evaluate, (state, jnp.zeros((pop_size,))) - ) - else: - state, total_loss = batch_evaluate(0, (state, jnp.zeros((pop_size,)))) - - return total_loss / self.batch_size / self.num_passes, state - - def _evaluate_in_memory_valid(self, state, batch_params): - num_batches = len(self.valid_dataset) // self.batch_size - permutation = jnp.arange(num_batches * self.batch_size).reshape( - num_batches, self.batch_size - ) - - def batch_evaluate(i, accumulated_metric): - data, labels = self.valid_dataset[permutation[i]] - return accumulated_metric + self._metric_func( - state, data, labels, batch_params - ) - - pop_size = tree_leaves(batch_params)[0].shape[0] - metric = lax.fori_loop(0, num_batches, batch_evaluate, jnp.zeros((pop_size,))) - return metric / (num_batches * self.batch_size), state - - def _evaluate_train(self, state, batch_params): - try: - data, labels = next(self.train_iter) - # data, labels = jnp.asarray(data), jnp.asarray(labels) - except StopIteration: - self.train_iter = iter(self.train_dataloader) - data, labels = next(self.train_iter) - # data, labels = jnp.asarray(data), jnp.asarray(labels) - - 
pop_size = tree_leaves(batch_params)[0].shape[0] - total_loss = jnp.zeros((pop_size,)) - for _ in range(self.num_passes): - total_loss += self._calculate_loss(data, labels, batch_params) - return total_loss / self.batch_size / self.num_passes, state - - def _evaluate_valid(self, state, batch_params): - valid_dataloader = DataLoader( - self.valid_dataset, - batch_size=self.batch_size, - collate_fn=self.collate_fn, - num_workers=self.num_workers, - drop_last=True, - ) - - accumulated_metric = 0 - for data, labels in valid_dataloader: - accumulated_metric += self._metric_func(state, data, labels, batch_params) - return accumulated_metric / (len(valid_dataloader) * self.batch_size), state - - def evaluate(self, state, batch_params): - if self.in_memory: - return lax.switch( - state.mode, - [self._evaluate_in_memory_train, self._evaluate_in_memory_valid], - state, - batch_params, - ) - else: - return lax.switch( - state.mode, - [self._evaluate_train_mode, self._evaluate_valid_mode], - state, - batch_params, - ) - - def _metric_func(self, state, data, labels, batch_params): - return lax.switch( - state.metric_func, - [self._calculate_loss, self._calculate_accuracy], - data, - labels, - batch_params, - ) - - @jit_method - def _calculate_accuracy(self, data, labels, batch_params): - output = vmap(self.forward_func, in_axes=(0, None))( - batch_params, data - ) # (pop_size, batch_size, out_dim) - output = jnp.argmax(output, axis=2) # (pop_size, batch_size) - num_correct = jnp.sum((output == labels), axis=1) # don't reduce here - num_correct = num_correct.astype(jnp.float32) - - return num_correct - - @jit_method - def _calculate_loss(self, data, labels, batch_params): - output = vmap(self.forward_func, in_axes=(0, None))(batch_params, data) - loss = jnp.sum(vmap(self.loss_func, in_axes=(0, None))(output, labels), axis=1) - loss = loss.astype(jnp.float32) - - return loss diff --git a/src/evox/utils/__init__.py b/src/evox/utils/__init__.py index e791d9ec..620d8823 100644 --- a/src/evox/utils/__init__.py +++ b/src/evox/utils/__init__.py @@ -1 +1,2 @@ -from .common import * \ No newline at end of file +from .common import * +from . import io \ No newline at end of file diff --git a/src/evox/utils/io.py b/src/evox/utils/io.py new file mode 100644 index 00000000..43d8c4f0 --- /dev/null +++ b/src/evox/utils/io.py @@ -0,0 +1,26 @@ +import numpy as np +import jax +from jax.tree_util import tree_map + + +def x32_func_call(func): + def inner_func(*args, **kwargs): + return to_x32_if_needed(func(*args, **kwargs)) + + return inner_func + + +def to_x32_if_needed(values): + if jax.config.jax_enable_x64: + # we have 64-bit enabled, so nothing to do + return values + + def to_x32(value): + if value.dtype == np.float64: + return value.astype(np.float32) + elif value.dtype == np.int64: + return value.astype(np.int32) + else: + return value + + return tree_map(to_x32, values) diff --git a/tests/test_neuroevolution.py b/tests/test_neuroevolution.py index a077e953..accce567 100644 --- a/tests/test_neuroevolution.py +++ b/tests/test_neuroevolution.py @@ -1,115 +1,71 @@ -import time +import math import jax import jax.numpy as jnp -import pytest from flax import linen as nn -from evox import algorithms, workflows, problems, utils -from evox.monitors import StdSOMonitor +from evox import algorithms, problems, workflows +from evox.monitors import EvalMonitor +from evox.utils import TreeAndVector, rank_based_fitness -class PartialPGPE(algorithms.PGPE): - def __init__(self, center_init): - super().__init__( - 100, center_init, "adam",
center_learning_rate=0.01, stdev_init=0.01 - ) - -class SimpleCNN(nn.Module): - """A simple CNN model.""" +class MyNet(nn.Module): + """Smallest network possible. + Used to run the test. + """ @nn.compact def __call__(self, x): - x = nn.Conv(features=6, kernel_size=(3, 3), padding="SAME")(x) - x = nn.relu(x) - x = nn.Conv(features=16, kernel_size=(3, 3), padding="SAME")(x) - x = nn.relu(x) - x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = nn.Conv(features=16, kernel_size=(3, 3), padding="SAME")(x) + batch_size = x.shape[0] + # downsample the image to 7x7 to save some computation + x = x[:, ::4, ::4, 0] / 255.0 + x = x.reshape(batch_size, -1) + x = nn.Dense(16)(x) x = nn.relu(x) - x = nn.Conv(features=16, kernel_size=(3, 3), padding="SAME")(x) - x = nn.relu(x) - x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) - - x = x.reshape(x.shape[0], -1) - x = nn.Dense(120)(x) - x = nn.sigmoid(x) x = nn.Dense(10)(x) - return x - - -def init_problem_and_model(key): - model = SimpleCNN() - batch_size = 64 - initial_params = model.init(key, jnp.zeros((batch_size, 32, 32, 3))) - problem = problems.neuroevolution.TorchvisionDataset( - root="./datasets", - batch_size=batch_size, - forward_func=model.apply, - dataset_name="cifar10", - ) - return initial_params, problem + return jax.nn.softmax(x) -@pytest.mark.skip(reason="time consuming") -def test_neuroevolution_treemap(): +def test_tfds(): + BATCH_SIZE = 8 key = jax.random.PRNGKey(42) - workflow_key, model_init_key = jax.random.split(key) - - initial_params, problem = init_problem_and_model(model_init_key) - - start = time.perf_counter() - center_init = jax.tree_util.tree_map( - lambda x: x.reshape(-1), - initial_params, - ) - monitor = StdSOMonitor() - workflow = workflows.StdWorkflow( - algorithm=Algorithms.TreeAlgorithm(PartialPGPE, initial_params, center_init), - problem=problem, - monitors=[monitor], - ) - # init the workflow - state = workflow.init(workflow_key) - - # run the workflow for 100 steps - for i in range(100): - state = workflow.step(state) - - # the result should be close to 0 - min_fitness = monitor.get_best_fitness() - print(f"Treemap loss: {min_fitness} time: {time.perf_counter() - start}") + model_key, workflow_key = jax.random.split(key) + model = MyNet() + params = model.init(model_key, jnp.zeros((BATCH_SIZE, 28, 28, 1))) -@pytest.mark.skip(reason="time consuming") -def test_neuroevolution_adapter(): - key = jax.random.PRNGKey(42) - workflow_key, model_init_key = jax.random.split(key) - initial_params, problem = init_problem_and_model(model_init_key) + @jax.jit + def loss_func(weight, data): + # a very bad loss function + images, labels = data["image"], data["label"] + outputs = model.apply(weight, images) + labels = jax.nn.one_hot(labels, 10) + return jnp.mean((outputs - labels) ** 2) - start = time.perf_counter() - adapter = utils.TreeAndVector(initial_params) - monitor = StdSOMonitor() - algorithm = algorithms.PGPE( - 100, - adapter.to_vector(initial_params), - "adam", - center_learning_rate=0.01, - stdev_init=0.01, + problem = problems.neuroevolution.TensorflowDataset( + dataset="fashion_mnist", batch_size=BATCH_SIZE, loss_func=loss_func ) + adapter = TreeAndVector(params) + monitor = EvalMonitor() + + center = adapter.to_vector(params) + # create a workflow workflow = workflows.StdWorkflow( - algorithm=algorithm, + algorithm=algorithms.PGPE( + optimizer="adam", + center_init=center, + pop_size=64, + stdev_init=0.1, + ), problem=problem, sol_transforms=[adapter.batched_to_tree], + 
fit_transforms=[rank_based_fitness], monitors=[monitor], ) # init the workflow - state = workflow.init(key) - - # run the workflow for 100 steps - for i in range(100): + state = workflow.init(workflow_key) + for i in range(3): state = workflow.step(state) - # the result should be close to 0 - min_fitness = monitor.get_best_fitness() - print(f"Adapter loss: {min_fitness} time: {time.perf_counter() - start}") + best_fitness = monitor.get_best_fitness().item() + assert math.isclose(best_fitness, 0.07662, abs_tol=0.01)
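Usage sketch (not part of the patch): the new ``TensorflowDataset`` docstring says that string-valued fields must be converted to arrays, or dropped, through the ``operations`` parameter before the batches reach JAX. Below is one way that could look with a ``grain`` ``MapTransform``. The transform class name and the placeholder ``loss_func`` are illustrative assumptions, and ``cifar10`` is used only because its TFDS records carry a string ``id`` field alongside the image and label arrays.

.. code-block:: python

    import grain.python as pygrain
    import jax.numpy as jnp

    from evox import problems


    class DropStringFields(pygrain.MapTransform):
        """Keep only the array-typed fields so each batch stays a valid JAX pytree."""

        def map(self, element):
            return {"image": element["image"], "label": element["label"]}


    def loss_func(weights, data):
        # Placeholder: apply the network to data["image"], compare against data["label"],
        # and return a scalar loss (see tests/test_neuroevolution.py for a working example).
        return jnp.float32(0.0)


    problem = problems.neuroevolution.TensorflowDataset(
        dataset="cifar10",
        batch_size=64,
        loss_func=loss_func,
        operations=[DropStringFields()],
    )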