From c150c3118601817334b5c83e444136722f139aaf Mon Sep 17 00:00:00 2001
From: Bill Huang
Date: Tue, 23 Apr 2024 17:18:00 +0800
Subject: [PATCH] dev: add try_gcs, split, and datadir parameters

---
 .../supervised_learning/tfds.py               | 30 +++++++++++++++++--
 1 file changed, 28 insertions(+), 2 deletions(-)

diff --git a/src/evox/problems/neuroevolution/supervised_learning/tfds.py b/src/evox/problems/neuroevolution/supervised_learning/tfds.py
index 6da0eec6..dab3eec9 100644
--- a/src/evox/problems/neuroevolution/supervised_learning/tfds.py
+++ b/src/evox/problems/neuroevolution/supervised_learning/tfds.py
@@ -1,5 +1,5 @@
 from dataclasses import field
-from typing import Any, Callable, List
+from typing import Any, Callable, List, Optional
 
 import grain.python as pygrain
 import jax
@@ -45,6 +45,9 @@ class TensorflowDataset(Problem):
     namely JAX or Numpy arrays, or Python's int, float, list, and dict.
     If the data contains other types like strings,
     you should convert them into arrays using the `operations` parameter.
+    You can also download the dataset through a proxy server by setting the environment variables `TFDS_HTTP_PROXY` and `TFDS_HTTPS_PROXY`,
+    for HTTP and HTTPS proxies respectively.
+
     The details of the dataset can be found at https://www.tensorflow.org/datasets/catalog/overview
     The details about operations/transformations can be found at https://github.com/google/grain/blob/main/docs/transformations.md
 
@@ -58,26 +61,49 @@
         The loss function.
         The function signature is loss(weights, data) -> loss_value,
         and it should be jittable.
         The `weight` is the weight of the neural network,
         and the `data` is the data from TFDS, which is a dictionary.
+    split
+        Which split of the dataset to use.
+        Default to "train".
     operations
         The list of transformations to apply to the data.
         Default to [].
         After the transformations, we will always apply a batch operation to create a batch of data.
+    datadir
+        The directory to store the dataset.
+        Default to None, which means tensorflow-datasets will automatically determine the directory.
     seed
         The random seed used to seed the dataloader.
         Given the same seed, the dataloader should return data in the same order.
         Default to 0.
+    try_gcs
+        Whether to try to download the dataset from Google Cloud Storage.
+        Usually Google's storage server is faster than the original server of the dataset.
     """

     dataset: Static[str]
     batch_size: Static[int]
     loss_func: Static[Callable]
+    split: Static[str] = field(default="train")
     operations: Static[List[Any]] = field(default_factory=list)
+    datadir: Static[Optional[str]] = field(default=None)
     seed: Static[int] = field(default=0)
+    try_gcs: Static[bool] = field(default=True)
     iterator: Static[pygrain.PyGrainDatasetIterator] = field(init=False)
     data_shape_dtypes: Static[Any] = field(init=False)

     def __post_init__(self):
-        data_source = tfds.data_source(self.dataset, split="train")
+        if self.datadir is None:
+            data_source = tfds.data_source(
+                self.dataset, try_gcs=self.try_gcs, split=self.split
+            )
+        else:
+            data_source = tfds.data_source(
+                self.dataset,
+                data_dir=self.datadir,
+                try_gcs=self.try_gcs,
+                split=self.split,
+            )
+
         sampler = pygrain.IndexSampler(
             num_records=len(data_source),
             shard_options=pygrain.NoSharding(),
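
As a quick sanity check for reviewers, here is a minimal usage sketch of the three new parameters. It is illustrative only: it assumes `TensorflowDataset` is importable from `evox.problems.neuroevolution.supervised_learning`, and the `mnist` dataset, the toy loss, and the proxy address are placeholders rather than anything in this patch. Only `split`, `datadir`, and `try_gcs` come from the diff above.

    # Sketch only: the dataset, loss, and proxy values below are placeholders.
    import os

    import jax.numpy as jnp

    from evox.problems.neuroevolution.supervised_learning import TensorflowDataset

    # Optional: route TFDS downloads through a proxy, per the docstring added above.
    os.environ["TFDS_HTTPS_PROXY"] = "http://127.0.0.1:7890"

    def toy_loss(weights, data):
        # Jittable placeholder; a real loss would run a forward pass with
        # `weights` on data["image"] and compare against data["label"].
        images = data["image"].astype(jnp.float32) / 255.0
        return jnp.mean((images - weights) ** 2)

    problem = TensorflowDataset(
        dataset="mnist",
        batch_size=64,
        loss_func=toy_loss,
        split="test",                        # new: which split to load, "train" by default
        datadir="/tmp/tensorflow_datasets",  # new: explicit storage dir, None by default
        try_gcs=True,                        # new: try the Google Cloud Storage mirror first
        seed=0,
    )

Because `datadir` defaults to None and `split` defaults to "train", existing callers that only pass `dataset`, `batch_size`, and `loss_func` keep their previous behavior, except that `try_gcs` now defaults to True, so downloads will try the GCS mirror first.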