diff --git a/src/pytorch_tabular/categorical_encoders.py b/src/pytorch_tabular/categorical_encoders.py
index 62effb0a..35b771fe 100644
--- a/src/pytorch_tabular/categorical_encoders.py
+++ b/src/pytorch_tabular/categorical_encoders.py
@@ -62,8 +62,8 @@ def transform(self, X):
                 not X[self.cols].isnull().any().any()
             ), "`handle_missing` = `error` and missing values found in columns to encode."
         X_encoded = X.copy(deep=True)
-        category_cols = X_encoded.select_dtypes(include='category').columns
-        X_encoded[category_cols] = X_encoded[category_cols].astype('object')
+        category_cols = X_encoded.select_dtypes(include="category").columns
+        X_encoded[category_cols] = X_encoded[category_cols].astype("object")
         for col, mapping in self._mapping.items():
             X_encoded[col] = X_encoded[col].fillna(NAN_CATEGORY).map(mapping["value"])
 
@@ -269,4 +269,4 @@ def save_as_object_file(self, path):
 
     def load_from_object_file(self, path):
        for k, v in pickle.load(open(path, "rb")).items():
-            setattr(self, k, v)
\ No newline at end of file
+            setattr(self, k, v)
diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py
index d6ecf34f..81f13a69 100644
--- a/src/pytorch_tabular/tabular_datamodule.py
+++ b/src/pytorch_tabular/tabular_datamodule.py
@@ -301,14 +301,14 @@ def _update_config(self, config) -> InferredConfig:
         else:
             raise ValueError(f"{config.task} is an unsupported task.")
         if self.train is not None:
-            category_cols = self.train[config.categorical_cols].select_dtypes(include='category').columns
-            self.train[category_cols] = self.train[category_cols].astype('object')
+            category_cols = self.train[config.categorical_cols].select_dtypes(include="category").columns
+            self.train[category_cols] = self.train[category_cols].astype("object")
             categorical_cardinality = [
                 int(x) + 1 for x in list(self.train[config.categorical_cols].fillna("NA").nunique().values)
             ]
         else:
-            category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include='category').columns
-            self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype('object')
+            category_cols = self.train_dataset.data[config.categorical_cols].select_dtypes(include="category").columns
+            self.train_dataset.data[category_cols] = self.train_dataset.data[category_cols].astype("object")
             categorical_cardinality = [
                 int(x) + 1 for x in list(self.train_dataset.data[config.categorical_cols].nunique().values)
             ]
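Note (not part of the patch): the reason both call sites cast `category`-dtype columns to `object` before the subsequent `fillna(...)` calls is that pandas refuses to fill a categorical Series with a value that is not among its declared categories. A minimal sketch of that behaviour, using a hypothetical placeholder string in place of the library's actual `NAN_CATEGORY` constant:

```python
import pandas as pd

# Hypothetical stand-in for pytorch_tabular's NAN_CATEGORY constant.
NAN_CATEGORY = "NAN_CATEGORY"

s = pd.Series(["a", "b", None], dtype="category")

# Filling a categorical with a value outside its categories raises
# (TypeError or ValueError, depending on the pandas version).
try:
    s.fillna(NAN_CATEGORY)
except (TypeError, ValueError) as exc:
    print(f"categorical fillna failed: {exc}")

# Casting to object first, as the patch does, lets the fill succeed.
print(s.astype("object").fillna(NAN_CATEGORY).tolist())
# ['a', 'b', 'NAN_CATEGORY']
```

The same reasoning applies to the `fillna("NA")` used for the cardinality computation in `_update_config`: with a `category` column whose categories do not include "NA", the fill would fail, so the cast to `object` is done up front.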