From e09720e8723455807cb798b392c25c2f2c2d88bd Mon Sep 17 00:00:00 2001 From: conradsoon <18553610+conradsoon@users.noreply.github.com> Date: Sun, 3 Nov 2024 18:43:21 +0800 Subject: [PATCH] fix(clip): fix tests --- daft/expressions/expressions.py | 2 - src/daft-core/src/datatypes/infer_datatype.rs | 3 - src/daft-functions/src/numeric/clip.rs | 25 ++++- tests/expressions/test_expressions.py | 8 +- tests/expressions/typing/test_arithmetic.py | 2 +- tests/table/numeric/test_numeric.py | 104 +++++++++++++++--- 6 files changed, 116 insertions(+), 28 deletions(-) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 15ace02f25..4f39b1455e 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -634,8 +634,6 @@ def clip(self, min: Expression, max: Expression) -> Expression: """ min_expr = Expression._to_expression(min) max_expr = Expression._to_expression(max) - if max_expr < min_expr: - raise ValueError("max must be greater than or equal to min") return Expression._from_pyexpr(native.clip(self._expr, min_expr._expr, max_expr._expr)) def sign(self) -> Expression: diff --git a/src/daft-core/src/datatypes/infer_datatype.rs b/src/daft-core/src/datatypes/infer_datatype.rs index f6a95e78da..db182e983b 100644 --- a/src/daft-core/src/datatypes/infer_datatype.rs +++ b/src/daft-core/src/datatypes/infer_datatype.rs @@ -65,9 +65,6 @@ impl<'a> InferDataType<'a> { (DataType::Null, _, _) => { Err(DaftError::TypeError("Cannot clip null values".to_string())) } // These checks are for situations where the Python bindings pass in a None directly. - (_, DataType::Null, DataType::Null) => Err(DaftError::TypeError( - "Cannot clip values with null min and max".to_string(), - )), // As above. (input_type, min_type, max_type) => { // This path gets called when the Python bindings pass in a Series, but note that there can still be nulls within the series. let mut output_type = (*input_type).clone(); diff --git a/src/daft-functions/src/numeric/clip.rs b/src/daft-functions/src/numeric/clip.rs index 663be98252..946987b9e4 100644 --- a/src/daft-functions/src/numeric/clip.rs +++ b/src/daft-functions/src/numeric/clip.rs @@ -1,6 +1,6 @@ use common_error::{DaftError, DaftResult}; use daft_core::{ - datatypes::InferDataType, + datatypes::{DataType, InferDataType}, prelude::{Field, Schema}, series::Series, }; @@ -34,6 +34,29 @@ impl ScalarUDF for Clip { let min_field = inputs[1].to_field(schema)?; let max_field = inputs[2].to_field(schema)?; + // Check if the array_field is numeric + if !array_field.dtype.is_numeric() { + return Err(DaftError::TypeError(format!( + "Expected array input to be numeric, got {}", + array_field.dtype + ))); + } + + // Check if min_field and max_field are numeric or null + if !(min_field.dtype.is_numeric() || min_field.dtype == DataType::Null) { + return Err(DaftError::TypeError(format!( + "Expected min input to be numeric or null, got {}", + min_field.dtype + ))); + } + + if !(max_field.dtype.is_numeric() || max_field.dtype == DataType::Null) { + return Err(DaftError::TypeError(format!( + "Expected max input to be numeric or null, got {}", + max_field.dtype + ))); + } + let output_type = InferDataType::clip_op( &InferDataType::from(&array_field.dtype), &InferDataType::from(&min_field.dtype), diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 59131a870f..51fe60583d 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -131,13 +131,9 @@ def test_repr_functions_clip() -> None: a = col("a") b = col("b") c = col("c") - y = a.clip( - a, - b, - c, - ) + y = a.clip(b, c) repr_out = repr(y) - assert repr_out == "clip(col(a), col(b), col(c)))" + assert repr_out == "clip(col(a), col(b), col(c))" copied = copy.deepcopy(y) assert repr_out == repr(copied) diff --git a/tests/expressions/typing/test_arithmetic.py b/tests/expressions/typing/test_arithmetic.py index a18903d2c5..c10ddb20fd 100644 --- a/tests/expressions/typing/test_arithmetic.py +++ b/tests/expressions/typing/test_arithmetic.py @@ -124,7 +124,7 @@ def test_round(unary_data_fixture): def test_clip(ternary_data_fixture): data, min, max = ternary_data_fixture assert_typing_resolve_vs_runtime_behavior( - data=(data,), + data=ternary_data_fixture, expr=col(data.name()).clip(col(min.name()), col(max.name())), run_kernel=lambda: data.clip(min, max), resolvable=is_numeric(data.datatype()) diff --git a/tests/table/numeric/test_numeric.py b/tests/table/numeric/test_numeric.py index f23dcbda71..88479816b8 100644 --- a/tests/table/numeric/test_numeric.py +++ b/tests/table/numeric/test_numeric.py @@ -390,9 +390,8 @@ def test_clip_zero_handling(): def test_clip_empty_array(): table = MicroPartition.from_pydict({"a": []}) - clip_table = table.eval_expression_list([col("a").clip(0, 1)]) - expected = [] - assert clip_table.get_column("a").to_pylist() == expected + with pytest.raises(ValueError): + table.eval_expression_list([col("a").clip(0, 1)]) def test_clip_all_within_bounds(): @@ -418,32 +417,107 @@ def test_clip_nan_handling(): def test_clip_column_with_scalar(): - table = MicroPartition.from_pydict({"a": [1, 2, 3, 4, 5]}) - # Clip with column as lower bound and scalar as upper bound - clip_table = table.eval_expression_list([col("a").clip(col("a"), 4)]) - expected = [1, 2, 3, 4, 4] - assert clip_table.get_column("a").to_pylist() == expected + # Initialize the table with data, lower bounds, and upper bounds + table = MicroPartition.from_pydict( + { + "data": [1.0, 2.5, None, 4.7, 5.0, float("nan")], + "lower_bound": [0.5, 2.0, 1.0, None, 4.0, 0.0], + "upper_bound": [2.0, 3.0, 5.0, None, None, float("inf")], + } + ) - # Clip with scalar as lower bound and column as upper bound - clip_table = table.eval_expression_list([col("a").clip(2, col("a"))]) - expected = [2, 2, 3, 4, 5] - assert clip_table.get_column("a").to_pylist() == expected + # Clip with column lower bound and scalar upper bound (5) + clip_table = table.eval_expression_list([col("data").clip(col("lower_bound"), 5)]) + expected = [ + 1.0, # 1.0 clipped between 0.5 and 5 -> 1.0 + 2.5, # 2.5 clipped between 2.0 and 5 -> 2.5 + None, # data is None + 4.7, # lower_bound is None, no lower bound applied + 5.0, # 5.0 clipped between 4.0 and 5 -> 5.0 + float("nan"), # data is NaN + ] + actual = clip_table.get_column("data").to_pylist() + assert all( + (a == b) or (a is None and b is None) or (math.isnan(a) and math.isnan(b)) for a, b in zip(actual, expected) + ), f"Expected {expected}, got {actual}" + + # Clip with scalar lower bound (2.0) and column upper bound + clip_table = table.eval_expression_list([col("data").clip(2.0, col("upper_bound"))]) + expected = [ + 2.0, # 1.0 clipped to 2.0 (upper_bound is 2.0) + 2.5, # 2.5 remains (between 2.0 and 3.0) + None, # data is None + 4.7, # upper_bound is None, no upper bound applied + 5.0, # 5.0 remains (upper_bound is None) + float("nan"), # data is NaN + ] + actual = clip_table.get_column("data").to_pylist() + assert all( + (a == b) or (a is None and b is None) or (math.isnan(a) and math.isnan(b)) for a, b in zip(actual, expected) + ), f"Expected {expected}, got {actual}" + + # Clip with column lower bound and column upper bound + clip_table = table.eval_expression_list([col("data").clip(col("lower_bound"), col("upper_bound"))]) + expected = [ + 1.0, # Clipped between 0.5 and 2.0 -> 1.0 + 2.5, # Clipped between 2.0 and 3.0 -> 2.5 + None, # data is None + 4.7, # lower and upper bounds are None, data remains unchanged + 5.0, # Clipped between 4.0 and 5.0 -> 5.0 + float("nan"), # data is NaN + ] + actual = clip_table.get_column("data").to_pylist() + assert all( + (a == b) or (a is None and b is None) or (math.isnan(a) and math.isnan(b)) for a, b in zip(actual, expected) + ), f"Expected {expected}, got {actual}" + + # Clip with scalar lower bound (-inf) and upper bound (inf) + clip_table = table.eval_expression_list([col("data").clip(float("-inf"), float("inf"))]) + expected = [1.0, 2.5, None, 4.7, 5.0, float("nan")] # Data remains unchanged + actual = clip_table.get_column("data").to_pylist() + assert all( + (a == b) or (a is None and b is None) or (math.isnan(a) and math.isnan(b)) for a, b in zip(actual, expected) + ), f"Expected {expected}, got {actual}" + + # Clip with None lower bound and None upper bound + clip_table = table.eval_expression_list([col("data").clip(None, None)]) + expected = [1.0, 2.5, None, 4.7, 5.0, float("nan")] # Data remains unchanged + actual = clip_table.get_column("data").to_pylist() + assert all( + (a == b) or (a is None and b is None) or (math.isnan(a) and math.isnan(b)) for a, b in zip(actual, expected) + ), f"Expected {expected}, got {actual}" + + # Clip with scalar lower bound (2.0) and scalar upper bound (5.0) + clip_table = table.eval_expression_list([col("data").clip(2.0, 5.0)]) + expected = [ + 2.0, # 1.0 clipped to 2.0 + 2.5, # 2.5 remains + None, # data is None + 4.7, # 4.7 remains + 5.0, # 5.0 remains + float("nan"), # data is NaN + ] + actual = clip_table.get_column("data").to_pylist() + assert all( + (a == b) or (a is None and b is None) or (math.isnan(a) and math.isnan(b)) for a, b in zip(actual, expected) + ), f"Expected {expected}, got {actual}" def test_clip_invalid_bounds(): table = MicroPartition.from_pydict({"a": [1, 2, 3, 4, 5], "b": [2, 3, 4, 5, 6]}) + ### NOTE: This is meant to catch PanicException in pyo3. # Test with column as lower bound and scalar as upper bound where upper < lower - with pytest.raises(ValueError, match="Upper bound must be greater than or equal to lower bound"): + with pytest.raises(BaseException): table.eval_expression_list([col("a").clip(col("b"), 1)]) # Test with scalar as lower bound and column as upper bound where upper < lower - with pytest.raises(ValueError, match="Upper bound must be greater than or equal to lower bound"): + with pytest.raises(BaseException): table.eval_expression_list([col("a").clip(6, col("b"))]) # Test with both bounds as columns where some upper values < lower values table = MicroPartition.from_pydict({"a": [1, 2, 3, 4, 5], "b": [2, 1, 4, 3, 6]}) - with pytest.raises(ValueError, match="Upper bound must be greater than or equal to lower bound"): + with pytest.raises(BaseException): table.eval_expression_list([col("a").clip(col("b"), col("a"))])