diff --git a/rl4co/models/rl/common/base.py b/rl4co/models/rl/common/base.py
index b025565b..9fa7bc03 100644
--- a/rl4co/models/rl/common/base.py
+++ b/rl4co/models/rl/common/base.py
@@ -36,7 +36,7 @@ class RL4COLitModule(LightningModule):
         lr_scheduler_interval: learning rate scheduler interval
         lr_scheduler_monitor: learning rate scheduler monitor
         generate_default_data: whether to generate default datasets, filling up the data directory
-        shuffle_train_dataloader: whether to shuffle training dataloader
+        shuffle_train_dataloader: whether to shuffle training dataloader. Default is False since we recreate the dataset every epoch
         dataloader_num_workers: number of workers for dataloader
         data_dir: data directory
         metrics: metrics
@@ -50,7 +50,7 @@ def __init__(
         batch_size: int = 512,
         val_batch_size: int = None,
         test_batch_size: int = None,
-        train_data_size: int = 1_280_000,
+        train_data_size: int = 100_000,
         val_data_size: int = 10_000,
         test_data_size: int = 10_000,
         optimizer: Union[str, torch.optim.Optimizer, partial] = "Adam",
@@ -63,7 +63,7 @@ def __init__(
         lr_scheduler_interval: str = "epoch",
         lr_scheduler_monitor: str = "val/reward",
         generate_default_data: bool = False,
-        shuffle_train_dataloader: bool = True,
+        shuffle_train_dataloader: bool = False,
         dataloader_num_workers: int = 0,
         data_dir: str = "data/",
         log_on_step: bool = True,
@@ -278,8 +278,12 @@ def on_train_epoch_end(self):
         """Called at the end of the training epoch. This can be used for instance to update the
         train dataset with new data (which is the case in RL).
         """
-        train_dataset = self.env.dataset(self.data_cfg["train_data_size"], "train")
-        self.train_dataset = self.wrap_dataset(train_dataset)
+        # Only update if not in the last epoch
+        # If last epoch, we don't need to update since we will not use the dataset anymore
+        if self.current_epoch < self.trainer.max_epochs - 1:
+            log.info("Generating training dataset for next epoch...")
+            train_dataset = self.env.dataset(self.data_cfg["train_data_size"], "train")
+            self.train_dataset = self.wrap_dataset(train_dataset)
 
     def wrap_dataset(self, dataset):
         """Wrap dataset with policy-specific wrapper. This is useful i.e. in REINFORCE where we need to
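
For context on why `shuffle_train_dataloader` now defaults to `False`: the training set is resampled from the environment at the end of every epoch (except the last), so an extra shuffle in the dataloader adds nothing. Below is a minimal, self-contained sketch of that pattern in plain PyTorch; `make_dataset`, the loop structure, and the sizes are illustrative stand-ins, not rl4co APIs.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_dataset(num_samples: int = 1_000) -> TensorDataset:
    # Stand-in for env.dataset(...): every call draws fresh random instances
    return TensorDataset(torch.rand(num_samples, 2))


max_epochs = 3
dataset = make_dataset()

for epoch in range(max_epochs):
    # shuffle=False: the data was just freshly sampled, so its order is already random
    loader = DataLoader(dataset, batch_size=128, shuffle=False)
    for (batch,) in loader:
        pass  # training step would go here

    # Mirror of on_train_epoch_end: resample for the next epoch,
    # skipping the regeneration after the final epoch
    if epoch < max_epochs - 1:
        dataset = make_dataset()
```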