diff --git a/src/equisolve/numpy/preprocessing/_base.py b/src/equisolve/numpy/preprocessing/_base.py index 7e17a01..6066776 100644 --- a/src/equisolve/numpy/preprocessing/_base.py +++ b/src/equisolve/numpy/preprocessing/_base.py @@ -29,6 +29,15 @@ class StandardScaler: :param atol: The relative tolerance for the optimization: variance is considered zero when it is less than abs(mean) * rtol + atol. + :param gradient_compliant: Sets whether the operations in this transform + keep the values and gradients consistent, + i.e. the stored gradients remain the true + gradients ∇transform(X) of the transformed values. + Effectively, by setting this flag to False + the gradients are also centered; + otherwise only the values are centered + (assuming the with_mean flag is True). + It is recommended to leave this flag set to True. + """ def __init__( @@ -39,6 +48,7 @@ def __init__( column_wise: bool = False, rtol: float = 0.0, atol: float = 1e-12, + gradient_compliant: bool = True, ): if type(parameter_keys) not in (list, tuple, np.ndarray): self.parameter_keys = [parameter_keys] @@ -50,6 +60,7 @@ def __init__( self.column_wise = column_wise self.rtol = rtol self.atol = atol + self.gradient_compliant = gradient_compliant def _validate_data(self, X: TensorMap, y: TensorMap = None): """Validates :class:`equistore.TensorBlock`'s for the usage in models. 
@@ -183,7 +194,7 @@ def fit( if sample_weights is not None: sw_block = sample_weights.block(key) sample_weights = block_to_array(sw_block, [parameter]) - if self.with_mean: + if self.with_mean and not (self.gradient_compliant): mean_values = np.average(X_mat, weights=sample_weights, axis=0) mean_values = mean_values.reshape((1, 1) + mean_values.shape) else: diff --git a/tests/numpy/preprocessing/test_base.py b/tests/numpy/preprocessing/test_base.py index 450a744..973364d 100644 --- a/tests/numpy/preprocessing/test_base.py +++ b/tests/numpy/preprocessing/test_base.py @@ -64,12 +64,14 @@ class TestStandardScaler: @pytest.mark.parametrize("with_mean", [True, False]) @pytest.mark.parametrize("with_std", [True, False]) - def test_standard_scaler_transform(self, with_mean, with_std): + @pytest.mark.parametrize("gradient_compliant", [True, False]) + def test_standard_scaler_transform(self, with_mean, with_std, gradient_compliant): st = StandardScaler( parameter_keys=["values", "positions"], with_mean=with_mean, with_std=with_std, column_wise=False, + gradient_compliant=gradient_compliant, ).fit(self.X) X_t = st.transform(self.X) @@ -91,7 +93,7 @@ def test_standard_scaler_transform(self, with_mean, with_std): for _, X_grad in X_t.block().gradients(): X_grad = X_grad.data.reshape(-1, X_grad.data.shape[-1]) - if with_mean: + if with_mean and not (gradient_compliant): assert_allclose(np.mean(X_grad, axis=0), 0, atol=1e-14, rtol=1e-14) if with_std: assert_allclose( @@ -103,12 +105,16 @@ def test_standard_scaler_transform(self, with_mean, with_std): @pytest.mark.parametrize("with_mean", [True, False]) @pytest.mark.parametrize("with_std", [True, False]) - def test_standard_scaler_inverse_transform(self, with_mean, with_std): + @pytest.mark.parametrize("gradient_compliant", [True, False]) + def test_standard_scaler_inverse_transform( + self, with_mean, with_std, gradient_compliant + ): st = StandardScaler( parameter_keys=["values", "positions"], with_mean=with_mean, 
with_std=with_std, column_wise=False, + gradient_compliant=gradient_compliant, ).fit(self.X) X_t_inv_t = st.inverse_transform(st.transform(self.X))