Skip to content

Commit

Permalink
fix(clip): fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
conradsoon committed Nov 3, 2024
1 parent 503ae13 commit e09720e
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 28 deletions.
2 changes: 0 additions & 2 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,6 @@ def clip(self, min: Expression, max: Expression) -> Expression:
"""
min_expr = Expression._to_expression(min)
max_expr = Expression._to_expression(max)
if max_expr < min_expr:
raise ValueError("max must be greater than or equal to min")
return Expression._from_pyexpr(native.clip(self._expr, min_expr._expr, max_expr._expr))

def sign(self) -> Expression:
Expand Down
3 changes: 0 additions & 3 deletions src/daft-core/src/datatypes/infer_datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ impl<'a> InferDataType<'a> {
(DataType::Null, _, _) => {
Err(DaftError::TypeError("Cannot clip null values".to_string()))
} // These checks are for situations where the Python bindings pass in a None directly.
(_, DataType::Null, DataType::Null) => Err(DaftError::TypeError(
"Cannot clip values with null min and max".to_string(),
)), // As above.
(input_type, min_type, max_type) => {
// This path gets called when the Python bindings pass in a Series, but note that there can still be nulls within the series.
let mut output_type = (*input_type).clone();
Expand Down
25 changes: 24 additions & 1 deletion src/daft-functions/src/numeric/clip.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use common_error::{DaftError, DaftResult};
use daft_core::{
datatypes::InferDataType,
datatypes::{DataType, InferDataType},
prelude::{Field, Schema},
series::Series,
};
Expand Down Expand Up @@ -34,6 +34,29 @@ impl ScalarUDF for Clip {
let min_field = inputs[1].to_field(schema)?;
let max_field = inputs[2].to_field(schema)?;

// Check if the array_field is numeric
if !array_field.dtype.is_numeric() {
return Err(DaftError::TypeError(format!(
"Expected array input to be numeric, got {}",
array_field.dtype
)));
}

// Check if min_field and max_field are numeric or null
if !(min_field.dtype.is_numeric() || min_field.dtype == DataType::Null) {
return Err(DaftError::TypeError(format!(
"Expected min input to be numeric or null, got {}",
min_field.dtype
)));
}

if !(max_field.dtype.is_numeric() || max_field.dtype == DataType::Null) {
return Err(DaftError::TypeError(format!(
"Expected max input to be numeric or null, got {}",
max_field.dtype
)));
}

let output_type = InferDataType::clip_op(
&InferDataType::from(&array_field.dtype),
&InferDataType::from(&min_field.dtype),
Expand Down
8 changes: 2 additions & 6 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,9 @@ def test_repr_functions_clip() -> None:
a = col("a")
b = col("b")
c = col("c")
y = a.clip(
a,
b,
c,
)
y = a.clip(b, c)
repr_out = repr(y)
assert repr_out == "clip(col(a), col(b), col(c)))"
assert repr_out == "clip(col(a), col(b), col(c))"
copied = copy.deepcopy(y)
assert repr_out == repr(copied)

Expand Down
2 changes: 1 addition & 1 deletion tests/expressions/typing/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_round(unary_data_fixture):
def test_clip(ternary_data_fixture):
data, min, max = ternary_data_fixture
assert_typing_resolve_vs_runtime_behavior(
data=(data,),
data=ternary_data_fixture,
expr=col(data.name()).clip(col(min.name()), col(max.name())),
run_kernel=lambda: data.clip(min, max),
resolvable=is_numeric(data.datatype())
Expand Down
104 changes: 89 additions & 15 deletions tests/table/numeric/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,8 @@ def test_clip_zero_handling():

def test_clip_empty_array():
table = MicroPartition.from_pydict({"a": []})
clip_table = table.eval_expression_list([col("a").clip(0, 1)])
expected = []
assert clip_table.get_column("a").to_pylist() == expected
with pytest.raises(ValueError):
table.eval_expression_list([col("a").clip(0, 1)])


def test_clip_all_within_bounds():
Expand All @@ -418,32 +417,107 @@ def test_clip_nan_handling():


def test_clip_column_with_scalar():
table = MicroPartition.from_pydict({"a": [1, 2, 3, 4, 5]})
# Clip with column as lower bound and scalar as upper bound
clip_table = table.eval_expression_list([col("a").clip(col("a"), 4)])
expected = [1, 2, 3, 4, 4]
assert clip_table.get_column("a").to_pylist() == expected
# Initialize the table with data, lower bounds, and upper bounds
table = MicroPartition.from_pydict(
{
"data": [1.0, 2.5, None, 4.7, 5.0, float("nan")],
"lower_bound": [0.5, 2.0, 1.0, None, 4.0, 0.0],
"upper_bound": [2.0, 3.0, 5.0, None, None, float("inf")],
}
)

# Clip with scalar as lower bound and column as upper bound
clip_table = table.eval_expression_list([col("a").clip(2, col("a"))])
expected = [2, 2, 3, 4, 5]
assert clip_table.get_column("a").to_pylist() == expected
# Clip with column lower bound and scalar upper bound (5)
clip_table = table.eval_expression_list([col("data").clip(col("lower_bound"), 5)])
expected = [
1.0, # 1.0 clipped between 0.5 and 5 -> 1.0
2.5, # 2.5 clipped between 2.0 and 5 -> 2.5
None, # data is None
4.7, # lower_bound is None, no lower bound applied
5.0, # 5.0 clipped between 4.0 and 5 -> 5.0
float("nan"), # data is NaN
]
actual = clip_table.get_column("data").to_pylist()
assert all(
(a == b) or (a is None and b is None) or (math.isnan(a) and math.isnan(b)) for a, b in zip(actual, expected)
), f"Expected {expected}, got {actual}"

# Clip with scalar lower bound (2.0) and column upper bound
clip_table = table.eval_expression_list([col("data").clip(2.0, col("upper_bound"))])
expected = [
2.0, # 1.0 raised to the scalar lower bound 2.0 (upper_bound is also 2.0)
2.5, # 2.5 remains (between 2.0 and 3.0)
None, # data is None
4.7, # upper_bound is None, no upper bound applied
5.0, # 5.0 remains (upper_bound is None)
float("nan"), # data is NaN
]
actual = clip_table.get_column("data").to_pylist()
assert all(
(a == b) or (a is None and b is None) or (math.isnan(a) and math.isnan(b)) for a, b in zip(actual, expected)
), f"Expected {expected}, got {actual}"

# Clip with column lower bound and column upper bound
clip_table = table.eval_expression_list([col("data").clip(col("lower_bound"), col("upper_bound"))])
expected = [
1.0, # Clipped between 0.5 and 2.0 -> 1.0
2.5, # Clipped between 2.0 and 3.0 -> 2.5
None, # data is None
4.7, # lower and upper bounds are None, data remains unchanged
5.0, # lower bound is 4.0, upper bound is None (not applied) -> 5.0
float("nan"), # data is NaN
]
actual = clip_table.get_column("data").to_pylist()
assert all(
(a == b) or (a is None and b is None) or (math.isnan(a) and math.isnan(b)) for a, b in zip(actual, expected)
), f"Expected {expected}, got {actual}"

# Clip with scalar lower bound (-inf) and upper bound (inf)
clip_table = table.eval_expression_list([col("data").clip(float("-inf"), float("inf"))])
expected = [1.0, 2.5, None, 4.7, 5.0, float("nan")] # Data remains unchanged
actual = clip_table.get_column("data").to_pylist()
assert all(
(a == b) or (a is None and b is None) or (math.isnan(a) and math.isnan(b)) for a, b in zip(actual, expected)
), f"Expected {expected}, got {actual}"

# Clip with None lower bound and None upper bound
clip_table = table.eval_expression_list([col("data").clip(None, None)])
expected = [1.0, 2.5, None, 4.7, 5.0, float("nan")] # Data remains unchanged
actual = clip_table.get_column("data").to_pylist()
assert all(
(a == b) or (a is None and b is None) or (math.isnan(a) and math.isnan(b)) for a, b in zip(actual, expected)
), f"Expected {expected}, got {actual}"

# Clip with scalar lower bound (2.0) and scalar upper bound (5.0)
clip_table = table.eval_expression_list([col("data").clip(2.0, 5.0)])
expected = [
2.0, # 1.0 clipped to 2.0
2.5, # 2.5 remains
None, # data is None
4.7, # 4.7 remains
5.0, # 5.0 remains
float("nan"), # data is NaN
]
actual = clip_table.get_column("data").to_pylist()
assert all(
(a == b) or (a is None and b is None) or (math.isnan(a) and math.isnan(b)) for a, b in zip(actual, expected)
), f"Expected {expected}, got {actual}"


def test_clip_invalid_bounds():
table = MicroPartition.from_pydict({"a": [1, 2, 3, 4, 5], "b": [2, 3, 4, 5, 6]})

### NOTE: BaseException is used here to catch pyo3's PanicException, which derives from BaseException rather than Exception.
# Test with column as lower bound and scalar as upper bound where upper < lower
with pytest.raises(ValueError, match="Upper bound must be greater than or equal to lower bound"):
with pytest.raises(BaseException):
table.eval_expression_list([col("a").clip(col("b"), 1)])

# Test with scalar as lower bound and column as upper bound where upper < lower
with pytest.raises(ValueError, match="Upper bound must be greater than or equal to lower bound"):
with pytest.raises(BaseException):
table.eval_expression_list([col("a").clip(6, col("b"))])

# Test with both bounds as columns where some upper values < lower values
table = MicroPartition.from_pydict({"a": [1, 2, 3, 4, 5], "b": [2, 1, 4, 3, 6]})
with pytest.raises(ValueError, match="Upper bound must be greater than or equal to lower bound"):
with pytest.raises(BaseException):
table.eval_expression_list([col("a").clip(col("b"), col("a"))])


Expand Down

0 comments on commit e09720e

Please sign in to comment.