Skip to content

Commit

Permalink
[FIX] Pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
lechevaa committed Sep 27, 2024
1 parent 273939c commit db3c8a9
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions tests/test_scaler_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,33 +135,34 @@ def test_minmaxscaler_with_multiple_features_and_samples(input_shape):
pred = model(data)
# Check if model is identity
assert np.allclose(pred, data, rtol=1e-3)
# Check if all features have their own min and max. Data is shape (Ndata, Nfeatures) and min is shape (1, Nfeatures)
# Check if all features have their own min and max.
# Data is shape (Ndata, Nfeatures) and min is shape (1, Nfeatures)
assert scalerlayer.data_min.shape[1] == input_shape[1]
assert scalerlayer.data_max.shape[1] == input_shape[1]


@pytest.mark.parametrize(
"feature_ranges",
"feature_range",
[[(0.0, 1.0), (2.0, 3.0), (4.0, 5.0)], [(-1.0, 0.0), (-2.0, 0)], [(0.0, 1.0)]],
)
def test_minmaxscaler_with_multiple_features_ranges(feature_ranges):
data = [rng.uniform(-500, 500, (10, len(feature_ranges)))]
def test_minmaxscaler_with_multiple_features_ranges(feature_range):
data = [rng.uniform(-500, 500, (10, len(feature_range)))]

scalerlayer = MinMaxScalerLayer(feature_range=feature_ranges)
unscalerlayer = MinMaxUnScalerLayer(feature_range=feature_ranges)
scalerlayer = MinMaxScalerLayer(feature_range=feature_range)
unscalerlayer = MinMaxUnScalerLayer(feature_range=feature_range)

scalerlayer.adapt(data)
unscalerlayer.adapt(data)

model = keras.Sequential(
[keras.layers.Input([len(feature_ranges)]), scalerlayer, unscalerlayer]
[keras.layers.Input([len(feature_range)]), scalerlayer, unscalerlayer]
)

pred = model(data)
# Check if model is identity
assert np.allclose(pred, data, rtol=1e-3)

# Check that feature_ranges have correct shape
assert len(scalerlayer.feature_range) == len(feature_ranges)
for fr1, fr2 in zip(scalerlayer.feature_range, feature_ranges):
assert len(scalerlayer.feature_range) == len(feature_range)
for fr1, fr2 in zip(scalerlayer.feature_range, feature_range):
assert fr1[0] == fr2[0] and fr1[1] == fr2[1]

0 comments on commit db3c8a9

Please sign in to comment.