Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 8, 2024
1 parent d4962c8 commit dfa4844
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions src/pytorch_tabular/categorical_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def transform(self, X):
not X[self.cols].isnull().any().any()
), "`handle_missing` = `error` and missing values found in columns to encode."
X_encoded = X.copy(deep=True)
category_cols = X_encoded.select_dtypes(include='category').columns
X_encoded[category_cols] = X_encoded[category_cols].astype('object')
category_cols = X_encoded.select_dtypes(include="category").columns
X_encoded[category_cols] = X_encoded[category_cols].astype("object")
for col, mapping in self._mapping.items():
X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])

Expand Down Expand Up @@ -269,4 +269,4 @@ def save_as_object_file(self, path):

def load_from_object_file(self, path):
for k, v in pickle.load(open(path, "rb")).items():
setattr(self, k, v)
setattr(self, k, v)
8 changes: 4 additions & 4 deletions src/pytorch_tabular/tabular_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,14 @@ def _update_config(self, config) -> InferredConfig:
else:
raise ValueError(f"{config.task} is an unsupported task.")
if self.train is not None:
category_cols = self.train[config.categorical_cols].select_dtypes(include='category').columns
self.train[category_cols] = self.train[category_cols].astype('object')
category_cols = self.train[config.categorical_cols].select_dtypes(include="category").columns
self.train[category_cols] = self.train[category_cols].astype("object")
categorical_cardinality = [
int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
]
else:
category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include='category').columns
self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype('object')
category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include="category").columns
self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype("object")
categorical_cardinality = [
int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
]
Expand Down

0 comments on commit dfa4844

Please sign in to comment.