Skip to content

Commit

Permalink
make synthetic_data.AbstractRandomDataset return generic types (#700)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #700

Makes it possible to use AbstractRandomDataset in more test cases

Reviewed By: JKSenthil

Differential Revision: D53532958

fbshipit-source-id: d6e397795a201b402abacfb6fe81f8353bc7fda1
  • Loading branch information
gunchu authored and facebook-github-bot committed Feb 8, 2024
1 parent 5295f82 commit d3ef20a
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions torchtnt/utils/data/synthetic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import abc
import logging
from dataclasses import dataclass, field
from typing import Dict
from typing import Generic, TypeVar

import torch
from torch.utils.data import Dataset
Expand All @@ -19,15 +19,18 @@

logger: logging.Logger = logging.getLogger(__name__)

TItem = TypeVar("TItem")


@dataclass
class AbstractRandomDataset(Dataset, abc.ABC):
class AbstractRandomDataset(Dataset, abc.ABC, Generic[TItem]):
"""
An abstract base class for random datasets.
Intended for subclassing, this class provides the framework for implementing
custom random datasets. Each subclass should provide
a concrete implementation of the `_generate_random_item` method.
custom random datasets. Each subclass should provide a concrete implementation
of the `_generate_random_item` method that produces a single random dataset
item of type `TItem`.
Attributes:
size (int, default=100): The total number of items the dataset will contain.
Expand All @@ -49,16 +52,15 @@ def __len__(self) -> int:
"""
return self.size

def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
def __getitem__(self, idx: int) -> TItem:
"""
Fetch a dataset item by index.
Args:
idx (int): Index of the desired dataset item.
Returns:
Dict[str, torch.Tensor]: Dictionary containing dataset attributes, primarily tensors.
The exact keys depend on the implementation of `_generate_random_item`.
TItem: A single random item of type `TItem`.
Raises:
IndexError: If the provided index is out of valid range.
Expand All @@ -69,15 +71,14 @@ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
raise IndexError(f"Index {idx} out of range [0, {self.size-1}]")

@abc.abstractmethod
def _generate_random_item(self) -> Dict[str, torch.Tensor]:
def _generate_random_item(self) -> TItem:
"""
Abstract method to produce a random dataset item.
Subclasses must override this to define their specific random item generation.
Returns:
Dict[str, torch.Tensor]: Dictionary containing item attributes and tensors.
The exact keys depend on the implementation of `_generate_random_item`.
TItem: A single random item of type `TItem`.
"""
raise NotImplementedError(
"Subclasses of AbstractRandomDataset should implement _generate_random_item."
Expand Down

0 comments on commit d3ef20a

Please sign in to comment.