diff --git a/.gitignore b/.gitignore index 73e97780..ab5a30fe 100644 --- a/.gitignore +++ b/.gitignore @@ -151,3 +151,4 @@ compare.py checkpoints/ docs/examples/basic/ examples/test_save/ +tests/.datasets/occupancy_data.zip diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 3c756ff2..abe36978 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -521,7 +521,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader: ) return DataLoader( dataset, - batch_size if batch_size is not None else self.batch_size, + batch_size or self.batch_size, shuffle=True if self.train_sampler is None else False, num_workers=self.config.num_workers, sampler=self.train_sampler, @@ -547,7 +547,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader: ) return DataLoader( dataset, - batch_size if batch_size is not None else self.batch_size, + batch_size or self.batch_size, shuffle=False, num_workers=self.config.num_workers, pin_memory=self.config.pin_memory, @@ -574,7 +574,7 @@ def test_dataloader(self, batch_size: Optional[int] = None) -> DataLoader: ) return DataLoader( dataset, - batch_size if batch_size is not None else self.batch_size, + batch_size or self.batch_size, shuffle=False, num_workers=self.config.num_workers, pin_memory=self.config.pin_memory, @@ -613,7 +613,7 @@ def prepare_inference_dataloader(self, df: pd.DataFrame, batch_size: Optional[in ) return DataLoader( dataset, - batch_size if batch_size is not None else self.batch_size, + batch_size or self.batch_size, shuffle=False, num_workers=self.config.num_workers, ) diff --git a/src/pytorch_tabular/tabular_model.py b/src/pytorch_tabular/tabular_model.py index ddbfa66f..6a98b8ef 100644 --- a/src/pytorch_tabular/tabular_model.py +++ b/src/pytorch_tabular/tabular_model.py @@ -667,8 +667,9 @@ def fit( assert len(metrics) == len( metrics_prob_inputs ), "The length of `metrics` and `metrics_prob_inputs` should be equal" - seed = seed if seed is not None else self.config.seed - seed_everything(seed) + seed = seed or self.config.seed + if seed: + seed_everything(seed) if datamodule is None: datamodule = self.prepare_dataloader(train, validation, test, train_sampler, target_transform, seed) else: @@ -738,8 +739,9 @@ def pretrain( assert ( self.config.task == "ssl" ), f"`pretrain` is not valid for {self.config.task} task. Please use `fit` instead." - seed = seed if seed is not None else self.config.seed - seed_everything(seed) + seed = seed or self.config.seed + if seed: + seed_everything(seed) if datamodule is None: datamodule = self.prepare_dataloader( train, @@ -973,8 +975,9 @@ def finetune( assert ( self._is_finetune_model ), "finetune() can only be called on a finetune model created using `TabularModel.create_finetune_model()`" - seed = seed if seed is not None else self.config.seed - seed_everything(seed) + seed = seed or self.config.seed + if seed: + seed_everything(seed) if datamodule is None: target_transform = self._check_and_set_target_transform(target_transform) self.datamodule._set_target_transform(target_transform) diff --git a/src/pytorch_tabular/utils/nn_utils.py b/src/pytorch_tabular/utils/nn_utils.py index 524a4937..5eecf0c1 100644 --- a/src/pytorch_tabular/utils/nn_utils.py +++ b/src/pytorch_tabular/utils/nn_utils.py @@ -82,7 +82,7 @@ def to_one_hot(y, depth=None): depth (int): the size of the one hot dimension """ y_flat = y.to(torch.int64).view(-1, 1) - depth = depth if depth is not None else int(torch.max(y_flat)) + 1 + depth = depth or int(torch.max(y_flat)) + 1 y_one_hot = torch.zeros(y_flat.size()[0], depth, device=y.device).scatter_(1, y_flat, 1) y_one_hot = y_one_hot.view(*(tuple(y.shape) + (-1,))) return y_one_hot diff --git a/tests/conftest.py b/tests/conftest.py index 69f0e07a..09167a0a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ import urllib.request urllib.request.urlretrieve( - "https://archive.ics.uci.edu/ml/machine-learning-databases/00357/occupancy_data.zip", DATASET_ZIP_OCCUPANCY + "http://archive.ics.uci.edu/ml/machine-learning-databases/00357/occupancy_data.zip", DATASET_ZIP_OCCUPANCY )