Skip to content
This repository has been archived by the owner on Apr 24, 2024. It is now read-only.

Commit

Permalink
adding a gradient_compliant flag to StandardScaler
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski committed Feb 23, 2023
1 parent 15fa2e8 commit f120d35
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
13 changes: 12 additions & 1 deletion src/equisolve/numpy/preprocessing/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ class StandardScaler:
:param atol:
The absolute tolerance for the optimization: variance is
considered zero when it is less than abs(mean) * rtol + atol.
:param gradient_compliant: Controls whether the operations in this
                           transform keep values and gradients consistent,
                           i.e. the transform is applied to the gradients
                           in the same way as to the values:
                           ∇transform(X) = transform(∇X).
                           Effectively, by setting this flag to False
                           the gradients are also centered,
                           otherwise only the values are centered
                           (assuming the with_mean flag is True).
                           It is recommended to leave this flag set to True.
"""

def __init__(
Expand All @@ -39,6 +48,7 @@ def __init__(
column_wise: bool = False,
rtol: float = 0.0,
atol: float = 1e-12,
gradient_compliant: bool = True,
):
if type(parameter_keys) not in (list, tuple, np.ndarray):
self.parameter_keys = [parameter_keys]
Expand All @@ -50,6 +60,7 @@ def __init__(
self.column_wise = column_wise
self.rtol = rtol
self.atol = atol
self.gradient_compliant = gradient_compliant

def _validate_data(self, X: TensorMap, y: TensorMap = None):
"""Validates :class:`equistore.TensorBlock`'s for the usage in models.
Expand Down Expand Up @@ -183,7 +194,7 @@ def fit(
if sample_weights is not None:
sw_block = sample_weights.block(key)
sample_weights = block_to_array(sw_block, [parameter])
if self.with_mean:
if self.with_mean and not (self.gradient_compliant):
mean_values = np.average(X_mat, weights=sample_weights, axis=0)
mean_values = mean_values.reshape((1, 1) + mean_values.shape)
else:
Expand Down
12 changes: 9 additions & 3 deletions tests/numpy/preprocessing/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ class TestStandardScaler:

@pytest.mark.parametrize("with_mean", [True, False])
@pytest.mark.parametrize("with_std", [True, False])
def test_standard_scaler_transform(self, with_mean, with_std):
@pytest.mark.parametrize("gradient_compliant", [True, False])
def test_standard_scaler_transform(self, with_mean, with_std, gradient_compliant):
st = StandardScaler(
parameter_keys=["values", "positions"],
with_mean=with_mean,
with_std=with_std,
column_wise=False,
gradient_compliant=gradient_compliant,
).fit(self.X)
X_t = st.transform(self.X)

Expand All @@ -91,7 +93,7 @@ def test_standard_scaler_transform(self, with_mean, with_std):

for _, X_grad in X_t.block().gradients():
X_grad = X_grad.data.reshape(-1, X_grad.data.shape[-1])
if with_mean:
if with_mean and not (gradient_compliant):
assert_allclose(np.mean(X_grad, axis=0), 0, atol=1e-14, rtol=1e-14)
if with_std:
assert_allclose(
Expand All @@ -103,12 +105,16 @@ def test_standard_scaler_transform(self, with_mean, with_std):

@pytest.mark.parametrize("with_mean", [True, False])
@pytest.mark.parametrize("with_std", [True, False])
def test_standard_scaler_inverse_transform(self, with_mean, with_std):
@pytest.mark.parametrize("gradient_compliant", [True, False])
def test_standard_scaler_inverse_transform(
self, with_mean, with_std, gradient_compliant
):
st = StandardScaler(
parameter_keys=["values", "positions"],
with_mean=with_mean,
with_std=with_std,
column_wise=False,
gradient_compliant=gradient_compliant,
).fit(self.X)
X_t_inv_t = st.inverse_transform(st.transform(self.X))

Expand Down

0 comments on commit f120d35

Please sign in to comment.