Skip to content

Commit

Permalink
Merge branch 'main' into precommit/prettier
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Nov 15, 2023
2 parents 019a469 + b244fb8 commit 838d20b
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,4 @@ compare.py
checkpoints/
docs/examples/basic/
examples/test_save/
tests/.datasets/occupancy_data.zip
8 changes: 4 additions & 4 deletions src/pytorch_tabular/tabular_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
)
return DataLoader(
dataset,
batch_size if batch_size is not None else self.batch_size,
batch_size or self.batch_size,
shuffle=True if self.train_sampler is None else False,
num_workers=self.config.num_workers,
sampler=self.train_sampler,
Expand All @@ -547,7 +547,7 @@ def val_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
)
return DataLoader(
dataset,
batch_size if batch_size is not None else self.batch_size,
batch_size or self.batch_size,
shuffle=False,
num_workers=self.config.num_workers,
pin_memory=self.config.pin_memory,
Expand All @@ -574,7 +574,7 @@ def test_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:
)
return DataLoader(
dataset,
batch_size if batch_size is not None else self.batch_size,
batch_size or self.batch_size,
shuffle=False,
num_workers=self.config.num_workers,
pin_memory=self.config.pin_memory,
Expand Down Expand Up @@ -613,7 +613,7 @@ def prepare_inference_dataloader(self, df: pd.DataFrame, batch_size: Optional[in
)
return DataLoader(
dataset,
batch_size if batch_size is not None else self.batch_size,
batch_size or self.batch_size,
shuffle=False,
num_workers=self.config.num_workers,
)
Expand Down
15 changes: 9 additions & 6 deletions src/pytorch_tabular/tabular_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,8 +667,9 @@ def fit(
assert len(metrics) == len(
metrics_prob_inputs
), "The length of `metrics` and `metrics_prob_inputs` should be equal"
seed = seed if seed is not None else self.config.seed
seed_everything(seed)
seed = seed or self.config.seed
if seed:
seed_everything(seed)
if datamodule is None:
datamodule = self.prepare_dataloader(train, validation, test, train_sampler, target_transform, seed)
else:
Expand Down Expand Up @@ -738,8 +739,9 @@ def pretrain(
assert (
self.config.task == "ssl"
), f"`pretrain` is not valid for {self.config.task} task. Please use `fit` instead."
seed = seed if seed is not None else self.config.seed
seed_everything(seed)
seed = seed or self.config.seed
if seed:
seed_everything(seed)
if datamodule is None:
datamodule = self.prepare_dataloader(
train,
Expand Down Expand Up @@ -973,8 +975,9 @@ def finetune(
assert (
self._is_finetune_model
), "finetune() can only be called on a finetune model created using `TabularModel.create_finetune_model()`"
seed = seed if seed is not None else self.config.seed
seed_everything(seed)
seed = seed or self.config.seed
if seed:
seed_everything(seed)
if datamodule is None:
target_transform = self._check_and_set_target_transform(target_transform)
self.datamodule._set_target_transform(target_transform)
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_tabular/utils/nn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def to_one_hot(y, depth=None):
depth (int): the size of the one hot dimension
"""
y_flat = y.to(torch.int64).view(-1, 1)
depth = depth if depth is not None else int(torch.max(y_flat)) + 1
depth = depth or int(torch.max(y_flat)) + 1
y_one_hot = torch.zeros(y_flat.size()[0], depth, device=y.device).scatter_(1, y_flat, 1)
y_one_hot = y_one_hot.view(*(tuple(y.shape) + (-1,)))
return y_one_hot
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import urllib.request

urllib.request.urlretrieve(
"https://archive.ics.uci.edu/ml/machine-learning-databases/00357/occupancy_data.zip", DATASET_ZIP_OCCUPANCY
"http://archive.ics.uci.edu/ml/machine-learning-databases/00357/occupancy_data.zip", DATASET_ZIP_OCCUPANCY
)


Expand Down

0 comments on commit 838d20b

Please sign in to comment.