-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #134 from EMI-Group/tfds
Introduce Tensorflow Dataset
- Loading branch information
Showing
14 changed files
with
236 additions
and
456 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,4 +5,4 @@ Supervised Learning | |
.. toctree:: | ||
:maxdepth: 1 | ||
|
||
torchvision | ||
tfds |
6 changes: 6 additions & 0 deletions
6
docs/source/api/problems/neuroevolution/supervised_learning/tfds.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
================== | ||
Tensorflow Dataset | ||
================== | ||
|
||
.. autoclass:: evox.problems.neuroevolution.TensorflowDataset | ||
:members: |
6 changes: 0 additions & 6 deletions
6
docs/source/api/problems/neuroevolution/supervised_learning/torchvision.rst
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,8 +57,8 @@ | |
"envpool", | ||
"gymnasium", | ||
"ray", | ||
"torch", | ||
"torchvision", | ||
"tensorflow_datasets", | ||
"grain", | ||
"gpjax", | ||
] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 3 additions & 3 deletions
6
src/evox/problems/neuroevolution/supervised_learning/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,10 @@ | ||
try:
    # Optional dependencies: tensorflow-datasets and grain are only required
    # when TensorflowDataset is actually used.
    # (Comment previously said "torchvision, optax" — stale leftover from the
    # TorchvisionDataset this module replaced.)
    from .tfds import TensorflowDataset
except ImportError as e:
    original_error_msg = str(e)

    def TensorflowDataset(*args, **kwargs):
        """Placeholder that defers the import failure until first use.

        Raises
        ------
        ImportError
            Always; reports the original import error so the user knows
            which optional dependency is missing.
        """
        raise ImportError(
            f'TensorflowDataset requires tensorflow-datasets, grain but got "{original_error_msg}" when importing'
        )
135 changes: 135 additions & 0 deletions
135
src/evox/problems/neuroevolution/supervised_learning/tfds.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
from dataclasses import field | ||
from typing import Any, Callable, List, Optional | ||
|
||
import grain.python as pygrain | ||
import jax | ||
import jax.numpy as jnp | ||
import tensorflow_datasets as tfds | ||
from jax.tree_util import tree_map | ||
|
||
from evox import Problem, Static, dataclass, jit_class | ||
from evox.utils.io import x32_func_call | ||
|
||
|
||
def get_dtype_shape(data):
    """Map a pytree of data to a matching pytree of ``jax.ShapeDtypeStruct``.

    Used to describe the shapes/dtypes that a host callback should produce
    without materializing the data. Array-likes (anything exposing ``shape``
    and ``dtype``) keep their shape/dtype; Python ``int``/``float`` scalars
    become 0-d int32/float32 structs (matching JAX's default x32 convention).

    Parameters
    ----------
    data
        A pytree (dict/list/tuple of leaves) of arrays or Python scalars.

    Returns
    -------
    A pytree of the same structure whose leaves are ``jax.ShapeDtypeStruct``.

    Raises
    ------
    TypeError
        If a leaf is neither array-like nor an int/float. Previously such
        leaves were silently mapped to ``None``, which produced confusing
        errors later inside the jitted data-loading callback.
    """

    def to_dtype_struct(x):
        if hasattr(x, "shape") and hasattr(x, "dtype"):
            return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
        elif isinstance(x, int):
            return jax.ShapeDtypeStruct(shape=(), dtype=jnp.int32)
        elif isinstance(x, float):
            return jax.ShapeDtypeStruct(shape=(), dtype=jnp.float32)
        else:
            raise TypeError(
                f"Unsupported leaf type in dataset: {type(x).__name__}; "
                "convert it to a JAX/NumPy array (e.g. via the `operations` parameter)."
            )

    return tree_map(to_dtype_struct, data)
|
||
|
||
@jit_class
@dataclass
class TensorflowDataset(Problem):
    """Wrap a tensorflow dataset as a problem.

    TensorFlow Datasets (TFDS) directly depends on the package `tensorflow-datasets` and `grain`.
    Additionally, when downloading the dataset for the first time, it requires `tensorflow` to be installed and an active internet connection.
    If you want to avoid installing `tensorflow`, you can prepare the dataset beforehand in another environment with `tensorflow` installed,
    run:

    .. code-block:: python

        import tensorflow_datasets as tfds
        tfds.data_source("<dataset name>")

    and then copy the dataset to the target machine.
    The default location is `~/tensorflow_datasets`. `~/` means the home directory of the user.

    Please notice that the data is loaded under JAX's jit context, so the data should be valid JAX data type,
    namely JAX or Numpy arrays, or Python's int, float, list, and dict.
    If the data contains other types like strings, you should convert them into arrays using the `operations` parameter.

    You can also download the dataset through a proxy server by setting the environment variable `TFDS_HTTP_PROXY` and `TFDS_HTTPS_PROXY`,
    for http and https proxy respectively.

    The details of the dataset can be found at https://www.tensorflow.org/datasets/catalog/overview
    The details about operations/transformations can be found at https://github.com/google/grain/blob/main/docs/transformations.md

    Parameters
    ----------
    dataset
        The dataset name.
    batch_size
        The batch size.
    loss_func
        The loss function.
        The function signature is loss(weights, data) -> loss_value, and it should be jittable.
        The `weight` is the weight of the neural network, and the `data` is the data from TFDS, which is a dictionary.
    split
        Which split of the dataset to use.
        Default to "train".
    operations
        The list of transformations to apply to the data.
        Default to [].
        After the transformations, we will always apply a batch operation to create a batch of data.
    datadir
        The directory to store the dataset.
        Default to None, which means tensorflow-datasets will automatically determine the directory.
    seed
        The random seed used to seed the dataloader.
        Given the same seed, the dataloader should load data in the same order.
        Default to 0.
    try_gcs
        Whether to try to download the dataset from Google Cloud Storage.
        Usually Google's storage server is faster than the original server of the dataset.
    """

    # User-supplied configuration (Static: excluded from the jitted pytree state)
    dataset: Static[str]
    batch_size: Static[int]
    loss_func: Static[Callable]
    split: Static[str] = field(default="train")
    operations: Static[List[Any]] = field(default_factory=list)
    datadir: Static[Optional[str]] = field(default=None)
    seed: Static[int] = field(default=0)
    try_gcs: Static[bool] = field(default=True)
    # Derived in __post_init__, not settable by callers (init=False)
    iterator: Static[pygrain.PyGrainDatasetIterator] = field(init=False)
    data_shape_dtypes: Static[Any] = field(init=False)

    def __post_init__(self):
        # Build the TFDS data source; only pass data_dir when the user set it,
        # so TFDS can fall back to its own default location otherwise.
        if self.datadir is None:
            data_source = tfds.data_source(
                self.dataset, try_gcs=self.try_gcs, split=self.split
            )
        else:
            data_source = tfds.data_source(
                self.dataset,
                data_dir=self.datadir,
                try_gcs=self.try_gcs,
                split=self.split,
            )

        # Deterministic shuffled sampling over the whole split (no sharding).
        sampler = pygrain.IndexSampler(
            num_records=len(data_source),
            shard_options=pygrain.NoSharding(),
            shuffle=True,
            seed=self.seed,
        )

        # Always batch last; drop_remainder=True keeps every batch the same
        # shape, which is required for a fixed ShapeDtypeStruct spec under jit.
        operations = self.operations + [
            pygrain.Batch(batch_size=self.batch_size, drop_remainder=True)
        ]

        # worker_count=0: load in-process, avoiding multiprocessing workers.
        loader = pygrain.DataLoader(
            data_source=data_source,
            operations=operations,
            sampler=sampler,
            worker_count=0,
        )
        # object.__setattr__ is used because the dataclass is presumably
        # frozen (evox's @dataclass) — TODO confirm.
        object.__setattr__(self, "iterator", iter(loader))
        # Pull one batch eagerly to record the shapes/dtypes io_callback
        # must expect. NOTE(review): this consumes the first batch during
        # construction — confirm this is intended.
        data_shape_dtypes = get_dtype_shape(self._next_data())
        object.__setattr__(self, "data_shape_dtypes", data_shape_dtypes)

    @x32_func_call
    def _next_data(self):
        # Host-side fetch of the next batch; x32_func_call presumably
        # downcasts 64-bit values to 32-bit to match JAX's x32 mode.
        return next(self.iterator)

    def evaluate(self, state, pop):
        # Fetch a batch on the host from inside the jitted computation.
        # NOTE(review): relies on `jax.experimental.io_callback` being
        # reachable by attribute access; newer JAX versions may require
        # `from jax.experimental import io_callback` — confirm.
        data = jax.experimental.io_callback(self._next_data, self.data_shape_dtypes)
        # One loss value per individual: vmap over the population axis only,
        # sharing the same batch of data across the whole population.
        loss = jax.vmap(self.loss_func, in_axes=(0, None))(pop, data)
        return loss, state
Oops, something went wrong.