Skip to content

Commit

Permalink
feat: Sort decimal fields (pola-rs#14649)
Browse files Browse the repository at this point in the history
  • Loading branch information
flisky authored Mar 6, 2024
1 parent 72a6f89 commit d562759
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 0 deletions.
3 changes: 3 additions & 0 deletions crates/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,9 @@ pub(crate) fn convert_sort_column_multi_sort(s: &Series) -> PolarsResult<Series>
.collect::<PolarsResult<Vec<_>>>()?;
return StructChunked::new(ca.name(), &new_fields).map(|ca| ca.into_series());
},
// We could fall back to the default branch, but Decimal is not a numeric dtype for now, so handle it explicitly here.
#[cfg(feature = "dtype-decimal")]
Decimal(_, _) => s.clone(),
_ => {
let phys = s.to_physical_repr().into_owned();
polars_ensure!(
Expand Down
17 changes: 17 additions & 0 deletions crates/polars-core/src/series/implementations/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ impl private::PrivateSeries for SeriesWrap<DecimalChunked> {
.into_decimal_unchecked(self.0.precision(), self.0.scale())
.into_series())
}
/// Returns a boxed [`TotalEqInner`] obtained by delegating to the wrapped
/// array (`self.0`).
///
/// The returned box borrows from `self` for the lifetime `'a`.
fn into_total_eq_inner<'a>(&'a self) -> Box<dyn TotalEqInner + 'a> {
// The call is made on an explicit reference to the inner array rather than
// on the wrapper itself.
(&self.0).into_total_eq_inner()
}
/// Returns a boxed [`TotalOrdInner`] obtained by delegating to the wrapped
/// array (`self.0`).
///
/// The returned box borrows from `self` for the lifetime `'a`.
fn into_total_ord_inner<'a>(&'a self) -> Box<dyn TotalOrdInner + 'a> {
// The call is made on an explicit reference to the inner array rather than
// on the wrapper itself.
(&self.0).into_total_ord_inner()
}

#[cfg(feature = "algorithm_group_by")]
unsafe fn agg_sum(&self, groups: &GroupsProxy) -> Series {
Expand Down Expand Up @@ -211,6 +217,17 @@ impl SeriesTrait for SeriesWrap<DecimalChunked> {
self.0.get_any_value_unchecked(index)
}

/// Sorts this decimal series according to `options` and returns the result
/// as a new [`Series`].
///
/// The wrapped array is sorted first; the sorted values are then turned back
/// into a decimal using this series' own precision and scale, so the output
/// keeps the same decimal parameters as the input.
fn sort_with(&self, options: SortOptions) -> Series {
    let sorted = self.0.sort_with(options);
    let precision = self.0.precision();
    let scale = self.0.scale();
    let decimal = sorted.into_decimal_unchecked(precision, scale);
    decimal.into_series()
}

/// Returns the [`IdxCa`] produced by calling `arg_sort` on the wrapped array
/// (`self.0`) with the given `options`.
///
/// NOTE(review): presumably these are the row indices that would sort the
/// series — confirm against the inner array's `arg_sort` contract.
fn arg_sort(&self, options: SortOptions) -> IdxCa {
self.0.arg_sort(options)
}

/// Returns the null count reported by the wrapped array (`self.0`).
fn null_count(&self) -> usize {
self.0.null_count()
}
Expand Down
1 change: 1 addition & 0 deletions crates/polars-row/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ pub fn encoded_size(data_type: &ArrowDataType) -> usize {
Int16 => i16::ENCODED_LEN,
Int32 => i32::ENCODED_LEN,
Int64 => i64::ENCODED_LEN,
Decimal(_, _) => i128::ENCODED_LEN,
Float32 => f32::ENCODED_LEN,
Float64 => f64::ENCODED_LEN,
Boolean => bool::ENCODED_LEN,
Expand Down
1 change: 1 addition & 0 deletions crates/polars-row/src/fixed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ encode_signed!(1, i8);
encode_signed!(2, i16);
encode_signed!(4, i32);
encode_signed!(8, i64);
encode_signed!(16, i128);

impl FixedLengthEncoding for f32 {
type Encoded = [u8; 4];
Expand Down
1 change: 1 addition & 0 deletions crates/polars-row/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ macro_rules! with_match_arrow_primitive_type {(
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Int64 => __with_ty__! { i64 },
Decimal(_, _) => __with_ty__! { i128 },
UInt8 => __with_ty__! { u8 },
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
Expand Down
27 changes: 27 additions & 0 deletions py-polars/tests/unit/datatypes/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,33 @@ def test_decimal_in_filter() -> None:
}


def test_decimal_sort() -> None:
    """Check sorting by a Decimal column, alone and combined with other keys."""
    frame = pl.DataFrame(
        {
            "foo": [1, 2, 3],
            "bar": [D("3.4"), D("2.1"), D("4.5")],
            "baz": [1, 1, 2],
        }
    )

    # Frame-level sort keyed on the decimal column only.
    by_bar = frame.sort("bar").to_dict(as_series=False)
    assert by_bar == {
        "foo": [2, 1, 3],
        "bar": [D("2.1"), D("3.4"), D("4.5")],
        "baz": [1, 1, 2],
    }

    # Frame-level sort with the decimal column as the secondary key.
    by_foo_then_bar = frame.sort(["foo", "bar"]).to_dict(as_series=False)
    assert by_foo_then_bar == {
        "foo": [1, 2, 3],
        "bar": [D("3.4"), D("2.1"), D("4.5")],
        "baz": [1, 1, 2],
    }

    # Expression-level sort_by on the decimal column, descending.
    s1_expr = pl.col("foo").sort_by("bar", descending=True).alias("s1")
    s1_values = frame.select([s1_expr])["s1"].to_list()
    assert s1_values == [3, 1, 2]

    # Expression-level sort_by on two keys, the decimal column being the second.
    s2_expr = pl.col("foo").sort_by(["baz", "bar"]).alias("s2")
    s2_values = frame.select([s2_expr])["s2"].to_list()
    assert s2_values == [2, 1, 3]


def test_decimal_write_parquet_12375() -> None:
f = io.BytesIO()
df = pl.DataFrame(
Expand Down

0 comments on commit d562759

Please sign in to comment.