Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Update DataFrame.pivot to allow index=None when values is set #17126

Merged
merged 1 commit into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/frame/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl PhysicalAggExpr for PivotExpr {
pub fn pivot<I0, I1, I2, S0, S1, S2>(
df: &DataFrame,
on: I0,
index: I1,
index: Option<I1>,
values: Option<I2>,
sort_columns: bool,
agg_expr: Option<Expr>,
Expand All @@ -59,7 +59,7 @@ where
pub fn pivot_stable<I0, I1, I2, S0, S1, S2>(
df: &DataFrame,
on: I0,
index: I1,
index: Option<I1>,
values: Option<I2>,
sort_columns: bool,
agg_expr: Option<Expr>,
Expand Down
74 changes: 45 additions & 29 deletions crates/polars-ops/src/frame/pivot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ fn restore_logical_type(s: &Series, logical_type: &DataType) -> Series {
pub fn pivot<I0, I1, I2, S0, S1, S2>(
pivot_df: &DataFrame,
on: I0,
index: I1,
index: Option<I1>,
values: Option<I2>,
sort_columns: bool,
agg_fn: Option<PivotAgg>,
Expand All @@ -99,15 +99,11 @@ where
S1: AsRef<str>,
S2: AsRef<str>,
{
let index = index
.into_iter()
.map(|s| s.as_ref().to_string())
.collect::<Vec<_>>();
let on = on
.into_iter()
.map(|s| s.as_ref().to_string())
.collect::<Vec<_>>();
let values = get_values_columns(pivot_df, &index, &on, values);
let (index, values) = assign_remaining_columns(pivot_df, &on, index, values)?;
pivot_impl(
pivot_df,
&on,
Expand All @@ -128,7 +124,7 @@ where
pub fn pivot_stable<I0, I1, I2, S0, S1, S2>(
pivot_df: &DataFrame,
on: I0,
index: I1,
index: Option<I1>,
values: Option<I2>,
sort_columns: bool,
agg_fn: Option<PivotAgg>,
Expand All @@ -142,15 +138,11 @@ where
S1: AsRef<str>,
S2: AsRef<str>,
{
let index = index
.into_iter()
.map(|s| s.as_ref().to_string())
.collect::<Vec<_>>();
let on = on
.into_iter()
.map(|s| s.as_ref().to_string())
.collect::<Vec<_>>();
let values = get_values_columns(pivot_df, &index, &on, values);
let (index, values) = assign_remaining_columns(pivot_df, &on, index, values)?;
pivot_impl(
pivot_df,
&on,
Expand All @@ -163,28 +155,52 @@ where
)
}

/// Determine `values` columns, which is optional in `pivot` calls.
/// Ensure both `index` and `values` are populated with `Vec<String>`.
///
/// If not specified (i.e. is `None`), use all remaining columns in the
/// `DataFrame` after `index` and `columns` have been excluded.
fn get_values_columns<I, S>(
/// - If `index` is None, assign columns not in `on` and `values` to it.
/// - If `values` is None, assign columns not in `on` and `index` to it.
/// - At least one of `index` and `values` must be non-null.
fn assign_remaining_columns<I1, I2, S1, S2>(
df: &DataFrame,
index: &[String],
on: &[String],
values: Option<I>,
) -> Vec<String>
index: Option<I1>,
values: Option<I2>,
) -> PolarsResult<(Vec<String>, Vec<String>)>
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
I1: IntoIterator<Item = S1>,
I2: IntoIterator<Item = S2>,
S1: AsRef<str>,
S2: AsRef<str>,
{
match values {
Some(v) => v.into_iter().map(|s| s.as_ref().to_string()).collect(),
None => df
.get_column_names()
.into_iter()
.map(|c| c.to_string())
.filter(|c| !(index.contains(c) | on.contains(c)))
.collect(),
match (index, values) {
(Some(index), Some(values)) => {
let index = index.into_iter().map(|s| s.as_ref().to_string()).collect();
let values = values.into_iter().map(|s| s.as_ref().to_string()).collect();
Ok((index, values))
},
(Some(index), None) => {
let index: Vec<String> = index.into_iter().map(|s| s.as_ref().to_string()).collect();
let values = df
.get_column_names()
.into_iter()
.map(|s| s.to_string())
.filter(|c| !(index.contains(c) | on.contains(c)))
.collect();
Ok((index, values))
},
(None, Some(values)) => {
let values: Vec<String> = values.into_iter().map(|s| s.as_ref().to_string()).collect();
let index = df
.get_column_names()
.into_iter()
.map(|s| s.to_string())
.filter(|c| !(values.contains(c) | on.contains(c)))
.collect();
Ok((index, values))
},
(None, None) => {
polars_bail!(InvalidOperation: "`index` and `values` cannot both be None in `pivot` operation")
},
}
}

Expand Down
24 changes: 12 additions & 12 deletions crates/polars/tests/it/core/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fn test_pivot_date_() -> PolarsResult<()> {
let out = pivot(
&df,
["values1"],
["index"],
Some(["index"]),
Some(["values2"]),
true,
Some(PivotAgg::Count),
Expand All @@ -34,7 +34,7 @@ fn test_pivot_date_() -> PolarsResult<()> {
let mut out = pivot_stable(
&df,
["values2"],
["index"],
Some(["index"]),
Some(["values1"]),
true,
Some(PivotAgg::First),
Expand Down Expand Up @@ -64,7 +64,7 @@ fn test_pivot_old() {
let pvt = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::Sum),
Expand All @@ -79,7 +79,7 @@ fn test_pivot_old() {
let pvt = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::Min),
Expand All @@ -93,7 +93,7 @@ fn test_pivot_old() {
let pvt = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::Max),
Expand All @@ -107,7 +107,7 @@ fn test_pivot_old() {
let pvt = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::Mean),
Expand All @@ -121,7 +121,7 @@ fn test_pivot_old() {
let pvt = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::Count),
Expand Down Expand Up @@ -149,7 +149,7 @@ fn test_pivot_categorical() -> PolarsResult<()> {
let out = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
true,
Some(PivotAgg::Count),
Expand All @@ -174,7 +174,7 @@ fn test_pivot_new() -> PolarsResult<()> {
let out = (pivot_stable(
&df,
["cols1"],
["index1", "index2"],
Some(["index1", "index2"]),
Some(["values1"]),
true,
Some(PivotAgg::Sum),
Expand All @@ -191,7 +191,7 @@ fn test_pivot_new() -> PolarsResult<()> {
let out = pivot_stable(
&df,
["cols1", "cols2"],
["index1", "index2"],
Some(["index1", "index2"]),
Some(["values1"]),
true,
Some(PivotAgg::Sum),
Expand Down Expand Up @@ -222,7 +222,7 @@ fn test_pivot_2() -> PolarsResult<()> {
let out = pivot_stable(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::First),
Expand Down Expand Up @@ -255,7 +255,7 @@ fn test_pivot_datetime() -> PolarsResult<()> {
let out = pivot(
&df,
["columns"],
["index"],
Some(["index"]),
Some(["values"]),
false,
Some(PivotAgg::Sum),
Expand Down
9 changes: 8 additions & 1 deletion docs/releases/upgrade/1.md
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ After:
... "test_2": [100, 100, 60, 60],
... }
... )
>>> df.pivot(index='name', on='subject', values=['test_1', 'test_2'])
>>> df.pivot('subject', index='name')
┌───────┬──────────────┬────────────────┬──────────────┬────────────────┐
│ name ┆ test_1_maths ┆ test_1_physics ┆ test_2_maths ┆ test_2_physics │
---------------
Expand All @@ -404,6 +404,13 @@ After:
└───────┴──────────────┴────────────────┴──────────────┴────────────────┘
```

Note that the function signature has also changed:

- `columns` has been renamed to `on`, and is now the first positional argument.
- `index` and `values` are both optional. If `index` is not specified, then it
will use all columns not specified in `on` and `values`. If `values` is
not specified, it will use all columns not specified in `on` and `index`.

### Support Decimal types by default when converting from Arrow

Update conversion from Arrow to always convert Decimals into Polars Decimals, rather than cast to Float64.
Expand Down
4 changes: 2 additions & 2 deletions docs/src/rust/user-guide/transformations/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// --8<-- [end:df]

// --8<-- [start:eager]
let out = pivot(&df, ["foo"], ["bar"], Some(["N"]), false, None, None)?;
let out = pivot(&df, ["foo"], Some(["bar"]), Some(["N"]), false, None, None)?;
println!("{}", &out);
// --8<-- [end:eager]

Expand All @@ -23,7 +23,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let q2 = pivot(
&q.collect()?,
["foo"],
["bar"],
Some(["bar"]),
Some(["N"]),
false,
None,
Expand Down
Loading