Skip to content

Commit

Permalink
feat: Support per-column nulls_last on sort operations (#16639)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Jun 1, 2024
1 parent 4c80271 commit 8710274
Show file tree
Hide file tree
Showing 33 changed files with 279 additions and 213 deletions.
10 changes: 8 additions & 2 deletions crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,14 @@ pub fn _arg_bottom_k(
sort_options: &mut SortMultipleOptions,
) -> PolarsResult<NoNull<IdxCa>> {
let from_n_rows = by_column[0].len();
_broadcast_descending(by_column.len(), &mut sort_options.descending);
let encoded = _get_rows_encoded(by_column, &sort_options.descending, sort_options.nulls_last)?;
_broadcast_bools(by_column.len(), &mut sort_options.descending);
_broadcast_bools(by_column.len(), &mut sort_options.nulls_last);

let encoded = _get_rows_encoded(
by_column,
&sort_options.descending,
&sort_options.nulls_last,
)?;
let arr = encoded.into_array();
let mut rows = arr
.values_iter()
Expand Down
34 changes: 21 additions & 13 deletions crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,26 @@ pub(crate) fn arg_sort_multiple_impl<T: NullOrderCmp + Send + Copy>(
by: &[Series],
options: &SortMultipleOptions,
) -> PolarsResult<IdxCa> {
let nulls_last = &options.nulls_last;
let descending = &options.descending;

debug_assert_eq!(descending.len() - 1, by.len());
debug_assert_eq!(nulls_last.len() - 1, by.len());

let compare_inner: Vec<_> = by
.iter()
.map(|s| s.into_total_ord_inner())
.collect_trusted();

let first_descending = descending[0];
let first_nulls_last = nulls_last[0];

let compare = |tpl_a: &(_, T), tpl_b: &(_, T)| -> Ordering {
match (
first_descending,
tpl_a
.1
.null_order_cmp(&tpl_b.1, options.nulls_last ^ first_descending),
.null_order_cmp(&tpl_b.1, first_nulls_last ^ first_descending),
) {
// if ordering is equal, we check the other arrays until we find a non-equal ordering
// if we have exhausted all arrays, we keep the equal ordering.
Expand All @@ -52,7 +57,7 @@ pub(crate) fn arg_sort_multiple_impl<T: NullOrderCmp + Send + Copy>(
ordering_other_columns(
&compare_inner,
descending.get_unchecked(1..),
options.nulls_last,
nulls_last.get_unchecked(1..),
idx_a,
idx_b,
)
Expand Down Expand Up @@ -184,17 +189,19 @@ pub fn _get_rows_encoded_unordered(by: &[Series]) -> PolarsResult<RowsEncoded> {
pub fn _get_rows_encoded(
by: &[Series],
descending: &[bool],
nulls_last: bool,
nulls_last: &[bool],
) -> PolarsResult<RowsEncoded> {
debug_assert_eq!(by.len(), descending.len());
debug_assert_eq!(by.len(), nulls_last.len());

let mut cols = Vec::with_capacity(by.len());
let mut fields = Vec::with_capacity(by.len());
for (by, descending) in by.iter().zip(descending) {
let arr = _get_rows_encoded_compat_array(by)?;

for ((by, desc), null_last) in by.iter().zip(descending).zip(nulls_last) {
let arr = _get_rows_encoded_compat_array(by)?;
let sort_field = EncodingField {
descending: *descending,
nulls_last,
descending: *desc,
nulls_last: *null_last,
no_order: false,
};
match arr.data_type() {
Expand All @@ -203,12 +210,12 @@ pub fn _get_rows_encoded(
let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
for arr in arr.values() {
cols.push(arr.clone() as ArrayRef);
fields.push(sort_field)
fields.push(sort_field);
}
},
_ => {
cols.push(arr);
fields.push(sort_field)
fields.push(sort_field);
},
}
}
Expand All @@ -219,7 +226,7 @@ pub fn _get_rows_encoded_ca(
name: &str,
by: &[Series],
descending: &[bool],
nulls_last: bool,
nulls_last: &[bool],
) -> PolarsResult<BinaryOffsetChunked> {
_get_rows_encoded(by, descending, nulls_last)
.map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array()))
Expand All @@ -236,12 +243,13 @@ pub fn _get_rows_encoded_ca_unordered(
pub(crate) fn argsort_multiple_row_fmt(
by: &[Series],
mut descending: Vec<bool>,
nulls_last: bool,
mut nulls_last: Vec<bool>,
parallel: bool,
) -> PolarsResult<IdxCa> {
_broadcast_descending(by.len(), &mut descending);
_broadcast_bools(by.len(), &mut descending);
_broadcast_bools(by.len(), &mut nulls_last);

let rows_encoded = _get_rows_encoded(by, &descending, nulls_last)?;
let rows_encoded = _get_rows_encoded(by, &descending, &nulls_last)?;
let mut items: Vec<_> = rows_encoded.iter().enumerate_idx().collect();

if parallel {
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-core/src/chunked_array/ops/sort/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,15 @@ mod test {

let out = df.sort(
["cat", "vals"],
SortMultipleOptions::default().with_order_descendings([false, false]),
SortMultipleOptions::default().with_order_descending_multi([false, false]),
)?;
let out = out.column("cat")?;
let cat = out.categorical()?;
assert_order(cat, &["a", "a", "b", "c"]);

let out = df.sort(
["vals", "cat"],
SortMultipleOptions::default().with_order_descendings([false, false]),
SortMultipleOptions::default().with_order_descending_multi([false, false]),
)?;
let out = out.column("cat")?;
let cat = out.categorical()?;
Expand Down
29 changes: 14 additions & 15 deletions crates/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,13 @@ where
fn ordering_other_columns<'a>(
compare_inner: &'a [Box<dyn TotalOrdInner + 'a>],
descending: &[bool],
nulls_last: bool,
nulls_last: &[bool],
idx_a: usize,
idx_b: usize,
) -> Ordering {
for (cmp, descending) in compare_inner.iter().zip(descending) {
// SAFETY:
// indices are in bounds
let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b, nulls_last ^ descending) };
for ((cmp, descending), null_last) in compare_inner.iter().zip(descending).zip(nulls_last) {
// SAFETY: indices are in bounds
let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b, null_last ^ descending) };
match (ordering, descending) {
(Ordering::Equal, _) => continue,
(_, true) => return ordering.reverse(),
Expand Down Expand Up @@ -557,7 +556,7 @@ impl StructChunked {
self.name(),
&[self.clone().into_series()],
&[options.descending],
options.nulls_last,
&[options.nulls_last],
)
.unwrap();
bin.arg_sort(Default::default())
Expand Down Expand Up @@ -670,10 +669,10 @@ pub(crate) fn convert_sort_column_multi_sort(s: &Series) -> PolarsResult<Series>
Ok(out)
}

pub fn _broadcast_descending(n_cols: usize, descending: &mut Vec<bool>) {
if n_cols > descending.len() && descending.len() == 1 {
while n_cols != descending.len() {
descending.push(descending[0]);
/// Broadcast a single flag to one flag per sort column.
///
/// When `values` holds exactly one element and `n_cols` is greater than one,
/// `values` is grown in place to `n_cols` copies of that element. In every
/// other case (empty, already multi-valued, or `n_cols <= 1`) `values` is left
/// untouched; no length validation is performed here.
pub fn _broadcast_bools(n_cols: usize, values: &mut Vec<bool>) {
    if values.len() == 1 && n_cols > 1 {
        let flag = values[0];
        // `resize` reserves once and fills, replacing the push-in-a-loop.
        values.resize(n_cols, flag);
    }
}
Expand All @@ -689,10 +688,10 @@ pub(crate) fn prepare_arg_sort(
.map(convert_sort_column_multi_sort)
.collect::<PolarsResult<Vec<_>>>()?;

let first = columns.remove(0);
_broadcast_bools(n_cols, &mut sort_options.descending);
_broadcast_bools(n_cols, &mut sort_options.nulls_last);

// broadcast ordering
_broadcast_descending(n_cols, &mut sort_options.descending);
let first = columns.remove(0);
Ok((first, columns))
}

Expand Down Expand Up @@ -831,7 +830,7 @@ mod test {

let out = df.sort(
["groups", "values"],
SortMultipleOptions::default().with_order_descendings([true, false]),
SortMultipleOptions::default().with_order_descending_multi([true, false]),
)?;
let expected = df!(
"groups" => [3, 2, 1],
Expand All @@ -841,7 +840,7 @@ mod test {

let out = df.sort(
["values", "groups"],
SortMultipleOptions::default().with_order_descendings([false, true]),
SortMultipleOptions::default().with_order_descending_multi([false, true]),
)?;
let expected = df!(
"groups" => [2, 1, 3],
Expand Down
41 changes: 27 additions & 14 deletions crates/polars-core/src/chunked_array/ops/sort/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub struct SortOptions {
/// SortMultipleOptions::default()
/// .with_maintain_order(true)
/// .with_multithreaded(false)
/// .with_order_descendings([false, true])
/// .with_order_descending_multi([false, true])
/// .with_nulls_last(true),
/// )?;
///
Expand All @@ -83,15 +83,15 @@ pub struct SortMultipleOptions {
///
/// If only one value is given, it will broadcast to all columns.
///
/// Use [`SortMultipleOptions::with_order_descendings`]
/// Use [`SortMultipleOptions::with_order_descending_multi`]
/// or [`SortMultipleOptions::with_order_descending`] to modify.
///
/// # Safety
///
/// Len must matches the number of columns or equal to 1.
/// Len must match the number of columns, or equal 1.
pub descending: Vec<bool>,
/// Whether to place null values last. Default `false`.
pub nulls_last: bool,
pub nulls_last: Vec<bool>,
/// Whether to sort in multiple threads. Default `true`.
pub multithreaded: bool,
/// Whether to maintain the order of equal elements. Default `false`.
Expand All @@ -113,7 +113,7 @@ impl Default for SortMultipleOptions {
fn default() -> Self {
Self {
descending: vec![false],
nulls_last: false,
nulls_last: vec![false],
multithreaded: true,
maintain_order: false,
}
Expand All @@ -126,12 +126,15 @@ impl SortMultipleOptions {
Self::default()
}

/// Specify order for each columns. Default all `false`.
/// Specify order for each column. Defaults all `false`.
///
/// # Safety
///
/// Len must matches the number of columns or equal to 1.
pub fn with_order_descendings(mut self, descending: impl IntoIterator<Item = bool>) -> Self {
/// Len must match the number of columns, or be equal to 1.
pub fn with_order_descending_multi(
mut self,
descending: impl IntoIterator<Item = bool>,
) -> Self {
self.descending = descending.into_iter().collect();
self
}
Expand All @@ -142,19 +145,29 @@ impl SortMultipleOptions {
self
}

/// Whether place null values last. Default `false`.
/// Set the nulls-last flag for each sort column. All flags default to `false`.
///
/// # Safety
///
/// The number of flags must equal the number of columns, or be exactly 1.
pub fn with_nulls_last_multi(mut self, nulls_last: impl IntoIterator<Item = bool>) -> Self {
    // Materialise the flags first, then store them on the options.
    let flags = nulls_last.into_iter().collect();
    self.nulls_last = flags;
    self
}

/// Whether to place null values last. Default `false`.
pub fn with_nulls_last(mut self, enabled: bool) -> Self {
self.nulls_last = enabled;
self.nulls_last = vec![enabled];
self
}

/// Whether sort in multiple threads. Default `true`.
/// Whether to sort in multiple threads. Default `true`.
pub fn with_multithreaded(mut self, enabled: bool) -> Self {
self.multithreaded = enabled;
self
}

/// Whether maintain the order of equal elements. Default `false`.
/// Whether to maintain the order of equal elements. Default `false`.
pub fn with_maintain_order(mut self, enabled: bool) -> Self {
self.maintain_order = enabled;
self
Expand Down Expand Up @@ -208,7 +221,7 @@ impl From<&SortOptions> for SortMultipleOptions {
fn from(value: &SortOptions) -> Self {
SortMultipleOptions {
descending: vec![value.descending],
nulls_last: value.nulls_last,
nulls_last: vec![value.nulls_last],
multithreaded: value.multithreaded,
maintain_order: value.maintain_order,
}
Expand All @@ -219,7 +232,7 @@ impl From<&SortMultipleOptions> for SortOptions {
fn from(value: &SortMultipleOptions) -> Self {
SortOptions {
descending: value.descending.first().copied().unwrap_or(false),
nulls_last: value.nulls_last,
nulls_last: value.nulls_last.first().copied().unwrap_or(false),
multithreaded: value.multithreaded,
maintain_order: value.maintain_order,
}
Expand Down
19 changes: 7 additions & 12 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1773,18 +1773,17 @@ impl DataFrame {
mut sort_options: SortMultipleOptions,
slice: Option<(i64, usize)>,
) -> PolarsResult<Self> {
// note that the by_column argument also contains evaluated expression from polars-lazy
// that may not even be present in this dataframe.
// note that the by_column argument also contains evaluated expression from
// polars-lazy that may not even be present in this dataframe.

// therefore when we try to set the first columns as sorted, we ignore the error
// as expressions are not present (they are renamed to _POLARS_SORT_COLUMN_i).
let first_descending = sort_options.descending[0];
let first_by_column = by_column[0].name().to_string();

let set_sorted = |df: &mut DataFrame| {
// Mark the first sort column as sorted
// if the column did not exists it is ok, because we sorted by an expression
// not present in the dataframe
// Mark the first sort column as sorted; if the column does not exist it
// is ok, because we sorted by an expression not present in the dataframe
let _ = df.apply(&first_by_column, |s| {
let mut s = s.clone();
if first_descending {
Expand All @@ -1795,14 +1794,11 @@ impl DataFrame {
s
});
};

if self.is_empty() {
let mut out = self.clone();
set_sorted(&mut out);

return Ok(out);
}

if let Some((0, k)) = slice {
return self.bottom_k_impl(k, by_column, sort_options);
}
Expand All @@ -1824,7 +1820,7 @@ impl DataFrame {
let s = &by_column[0];
let options = SortOptions {
descending: sort_options.descending[0],
nulls_last: sort_options.nulls_last,
nulls_last: sort_options.nulls_last[0],
multithreaded: sort_options.multithreaded,
maintain_order: sort_options.maintain_order,
};
Expand All @@ -1836,13 +1832,12 @@ impl DataFrame {
if let Some((offset, len)) = slice {
out = out.slice(offset, len);
}

return Ok(out.into_frame());
}
s.arg_sort(options)
},
_ => {
if sort_options.nulls_last
if sort_options.nulls_last.iter().all(|&x| x)
|| has_struct
|| std::env::var("POLARS_ROW_FMT_SORT").is_ok()
{
Expand Down Expand Up @@ -1899,7 +1894,7 @@ impl DataFrame {
/// df.sort(
/// &["sepal_width", "sepal_length"],
/// SortMultipleOptions::new()
/// .with_order_descendings([false, true])
/// .with_order_descending_multi([false, true])
/// )
/// }
/// ```
Expand Down
Loading

0 comments on commit 8710274

Please sign in to comment.