Skip to content

Commit

Permalink
fix: Fix struct arithmetic schema
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 22, 2024
1 parent b5a8a50 commit 9219e44
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 17 deletions.
45 changes: 28 additions & 17 deletions crates/polars-plan/src/logical_plan/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,25 +298,36 @@ fn get_arithmetic_field(
_ => {
let right_type = right_ae.get_type(schema, ctxt, arena)?;

// Avoid needlessly type casting numeric columns during arithmetic
// with literals.
if (left_field.dtype.is_integer() && right_type.is_integer())
|| (left_field.dtype.is_float() && right_type.is_float())
{
match (left_ae, right_ae) {
(AExpr::Literal(_), AExpr::Literal(_)) => {},
(AExpr::Literal(_), _) => {
// literal will be coerced to match right type
left_field.coerce(right_type);
match (&left_field.dtype, &right_type) {
#[cfg(feature = "dtype-struct")]
(Struct(_), Struct(_)) => {
if op.is_arithmetic() {
return Ok(left_field);
},
(_, AExpr::Literal(_)) => {
// literal will be coerced to match right type
return Ok(left_field);
},
_ => {},
}
}
},
_ => {
// Avoid needlessly type casting numeric columns during arithmetic
// with literals.
if (left_field.dtype.is_integer() && right_type.is_integer())
|| (left_field.dtype.is_float() && right_type.is_float())
{
match (left_ae, right_ae) {
(AExpr::Literal(_), AExpr::Literal(_)) => {},
(AExpr::Literal(_), _) => {
// literal will be coerced to match right type
left_field.coerce(right_type);
return Ok(left_field);
},
(_, AExpr::Literal(_)) => {
// literal will be coerced to match right type
return Ok(left_field);
},
_ => {},
}
}
},
}

try_get_supertype(&left_field.dtype, &right_type)?
},
};
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/datatypes/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,3 +868,11 @@ def test_struct_null_count_10130() -> None:

s = pl.Series([{"a": None}])
assert s.null_count() == 1


def test_struct_arithmetic_schema() -> None:
q = pl.LazyFrame({"A": [1], "B": [2]})

assert q.select(pl.struct("A") - pl.struct("B")).schema["A"] == pl.Struct(
{"A": pl.Int64}
)

0 comments on commit 9219e44

Please sign in to comment.