Skip to content

Commit

Permalink
feat: Update DataFrame.pivot to allow index=None when values is…
Browse files Browse the repository at this point in the history
… set (#17126)
  • Loading branch information
MarcoGorelli authored Jun 22, 2024
1 parent 9391c03 commit 3117ab1
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 116 deletions.
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/frame/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl PhysicalAggExpr for PivotExpr {
pub fn pivot<I0, I1, I2, S0, S1, S2>(
df: &DataFrame,
on: I0,
index: I1,
index: Option<I1>,
values: Option<I2>,
sort_columns: bool,
agg_expr: Option<Expr>,
Expand All @@ -59,7 +59,7 @@ where
pub fn pivot_stable<I0, I1, I2, S0, S1, S2>(
df: &DataFrame,
on: I0,
index: I1,
index: Option<I1>,
values: Option<I2>,
sort_columns: bool,
agg_expr: Option<Expr>,
Expand Down
74 changes: 45 additions & 29 deletions crates/polars-ops/src/frame/pivot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ fn restore_logical_type(s: &Series, logical_type: &DataType) -> Series {
pub fn pivot<I0, I1, I2, S0, S1, S2>(
pivot_df: &DataFrame,
on: I0,
index: I1,
index: Option<I1>,
values: Option<I2>,
sort_columns: bool,
agg_fn: Option<PivotAgg>,
Expand All @@ -99,15 +99,11 @@ where
S1: AsRef<str>,
S2: AsRef<str>,
{
let index = index
.into_iter()
.map(|s| s.as_ref().to_string())
.collect::<Vec<_>>();
let on = on
.into_iter()
.map(|s| s.as_ref().to_string())
.collect::<Vec<_>>();
let values = get_values_columns(pivot_df, &index, &on, values);
let (index, values) = assign_remaining_columns(pivot_df, &on, index, values)?;
pivot_impl(
pivot_df,
&on,
Expand All @@ -128,7 +124,7 @@ where
pub fn pivot_stable<I0, I1, I2, S0, S1, S2>(
pivot_df: &DataFrame,
on: I0,
index: I1,
index: Option<I1>,
values: Option<I2>,
sort_columns: bool,
agg_fn: Option<PivotAgg>,
Expand All @@ -142,15 +138,11 @@ where
S1: AsRef<str>,
S2: AsRef<str>,
{
let index = index
.into_iter()
.map(|s| s.as_ref().to_string())
.collect::<Vec<_>>();
let on = on
.into_iter()
.map(|s| s.as_ref().to_string())
.collect::<Vec<_>>();
let values = get_values_columns(pivot_df, &index, &on, values);
let (index, values) = assign_remaining_columns(pivot_df, &on, index, values)?;
pivot_impl(
pivot_df,
&on,
Expand All @@ -163,28 +155,52 @@ where
)
}

/// Determine `values` columns, which is optional in `pivot` calls.
/// Ensure both `index` and `values` are populated with `Vec<String>`.
///
/// If not specified (i.e. is `None`), use all remaining columns in the
/// `DataFrame` after `index` and `columns` have been excluded.
fn get_values_columns<I, S>(
/// - If `index` is None, assign columns not in `on` and `values` to it.
/// - If `values` is None, assign columns not in `on` and `index` to it.
/// - At least one of `index` and `values` must be `Some`; if both are `None`, an error is returned.
fn assign_remaining_columns<I1, I2, S1, S2>(
df: &DataFrame,
index: &[String],
on: &[String],
values: Option<I>,
) -> Vec<String>
index: Option<I1>,
values: Option<I2>,
) -> PolarsResult<(Vec<String>, Vec<String>)>
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
I1: IntoIterator<Item = S1>,
I2: IntoIterator<Item = S2>,
S1: AsRef<str>,
S2: AsRef<str>,
{
match values {
Some(v) => v.into_iter().map(|s| s.as_ref().to_string()).collect(),
None => df
.get_column_names()
.into_iter()
.map(|c| c.to_string())
.filter(|c| !(index.contains(c) | on.contains(c)))
.collect(),
match (index, values) {
(Some(index), Some(values)) => {
let index = index.into_iter().map(|s| s.as_ref().to_string()).collect();
let values = values.into_iter().map(|s| s.as_ref().to_string()).collect();
Ok((index, values))
},
(Some(index), None) => {
let index: Vec<String> = index.into_iter().map(|s| s.as_ref().to_string()).collect();
let values = df
.get_column_names()
.into_iter()
.map(|s| s.to_string())
.filter(|c| !(index.contains(c) | on.contains(c)))
.collect();
Ok((index, values))
},
(None, Some(values)) => {
let values: Vec<String> = values.into_iter().map(|s| s.as_ref().to_string()).collect();
let index = df
.get_column_names()
.into_iter()
.map(|s| s.to_string())
.filter(|c| !(values.contains(c) | on.contains(c)))
.collect();
Ok((index, values))
},
(None, None) => {
polars_bail!(InvalidOperation: "`index` and `values` cannot both be None in `pivot` operation")
},
}
}

Expand Down
24 changes: 12 additions & 12 deletions crates/polars/tests/it/core/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fn test_pivot_date_() -> PolarsResult<()> {
let out = pivot(
&df,
["values1"],
["index"],
Some(["index"]),
Some(["values2"]),
true,
Some(PivotAgg::Count),
Expand All @@ -34,7 +34,7 @@ fn test_pivot_date_() -> PolarsResult<()> {
let mut out = pivot_stable(
&df,
["values2"],
["index"],
Some(["index"]),
Some(["values1"]),
true,
Some(PivotAgg::First),
Expand Down Expand Up @@ -64,7 +64,7 @@ fn test_pivot_old() {
let pvt = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::Sum),
Expand All @@ -79,7 +79,7 @@ fn test_pivot_old() {
let pvt = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::Min),
Expand All @@ -93,7 +93,7 @@ fn test_pivot_old() {
let pvt = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::Max),
Expand All @@ -107,7 +107,7 @@ fn test_pivot_old() {
let pvt = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::Mean),
Expand All @@ -121,7 +121,7 @@ fn test_pivot_old() {
let pvt = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::Count),
Expand Down Expand Up @@ -149,7 +149,7 @@ fn test_pivot_categorical() -> PolarsResult<()> {
let out = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
true,
Some(PivotAgg::Count),
Expand All @@ -174,7 +174,7 @@ fn test_pivot_new() -> PolarsResult<()> {
let out = (pivot_stable(
&df,
["cols1"],
["index1", "index2"],
Some(["index1", "index2"]),
Some(["values1"]),
true,
Some(PivotAgg::Sum),
Expand All @@ -191,7 +191,7 @@ fn test_pivot_new() -> PolarsResult<()> {
let out = pivot_stable(
&df,
["cols1", "cols2"],
["index1", "index2"],
Some(["index1", "index2"]),
Some(["values1"]),
true,
Some(PivotAgg::Sum),
Expand Down Expand Up @@ -222,7 +222,7 @@ fn test_pivot_2() -> PolarsResult<()> {
let out = pivot_stable(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::First),
Expand Down Expand Up @@ -255,7 +255,7 @@ fn test_pivot_datetime() -> PolarsResult<()> {
let out = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::Sum),
Expand Down
9 changes: 8 additions & 1 deletion docs/releases/upgrade/1.md
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ After:
... "test_2": [100, 100, 60, 60],
... }
... )
>>> df.pivot(index='name', on='subject', values=['test_1', 'test_2'])
>>> df.pivot('subject', index='name')
┌───────┬──────────────┬────────────────┬──────────────┬────────────────┐
│ name ┆ test_1_maths ┆ test_1_physics ┆ test_2_maths ┆ test_2_physics │
---------------
Expand All @@ -404,6 +404,13 @@ After:
└───────┴──────────────┴────────────────┴──────────────┴────────────────┘
```

Note that the function signature has also changed:

- `columns` has been renamed to `on`, and is now the first positional argument.
- `index` and `values` are each optional, but at least one of them must be
specified. If `index` is not specified, it will use all columns not specified
in `on` and `values`. If `values` is not specified, it will use all columns
not specified in `on` and `index`.

### Support Decimal types by default when converting from Arrow

Update conversion from Arrow to always convert Decimals into Polars Decimals, rather than cast to Float64.
Expand Down
4 changes: 2 additions & 2 deletions docs/src/rust/user-guide/transformations/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// --8<-- [end:df]

// --8<-- [start:eager]
let out = pivot(&df, ["foo"], ["bar"], Some(["N"]), false, None, None)?;
let out = pivot(&df, ["foo"], Some(["bar"]), Some(["N"]), false, None, None)?;
println!("{}", &out);
// --8<-- [end:eager]

Expand All @@ -23,7 +23,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let q2 = pivot(
&q.collect()?,
["foo"],
["bar"],
Some(["bar"]),
Some(["N"]),
false,
None,
Expand Down
Loading

0 comments on commit 3117ab1

Please sign in to comment.