Skip to content

Commit

Permalink
feat: Quantile function in SQL (#18047)
Browse files Browse the repository at this point in the history
  • Loading branch information
pomo-mondreganto authored Oct 8, 2024
1 parent 1f48036 commit 9dada18
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 1 deletion.
39 changes: 38 additions & 1 deletion crates/polars-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use std::ops::Sub;

use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions};
use polars_core::export::regex;
use polars_core::prelude::{polars_bail, polars_err, DataType, PolarsResult, Schema, TimeUnit};
use polars_core::prelude::{
polars_bail, polars_err, DataType, PolarsResult, QuantileInterpolOptions, Schema, TimeUnit,
};
use polars_lazy::dsl::Expr;
#[cfg(feature = "list_eval")]
use polars_lazy::dsl::ListNameSpaceExtension;
Expand Down Expand Up @@ -504,6 +506,13 @@ pub(crate) enum PolarsSQLFunctions {
/// SELECT MEDIAN(column_1) FROM df;
/// ```
Median,
/// SQL 'quantile_cont' function
/// Returns the continuous quantile element from the grouping
/// (interpolated value between two closest values).
/// ```sql
/// SELECT QUANTILE_CONT(column_1) FROM df;
/// ```
QuantileCont,
/// SQL 'min' function
/// Returns the smallest (minimum) of all the elements in the grouping.
/// ```sql
Expand Down Expand Up @@ -686,6 +695,7 @@ impl PolarsSQLFunctions {
"pi",
"pow",
"power",
"quantile_cont",
"radians",
"regexp_like",
"replace",
Expand Down Expand Up @@ -818,6 +828,7 @@ impl PolarsSQLFunctions {
"last" => Self::Last,
"max" => Self::Max,
"median" => Self::Median,
"quantile_cont" => Self::QuantileCont,
"min" => Self::Min,
"stdev" | "stddev" | "stdev_samp" | "stddev_samp" => Self::StdDev,
"sum" => Self::Sum,
Expand Down Expand Up @@ -1243,6 +1254,32 @@ impl SQLFunctionVisitor<'_> {
Last => self.visit_unary(Expr::last),
Max => self.visit_unary_with_opt_cumulative(Expr::max, Expr::cum_max),
Median => self.visit_unary(Expr::median),
QuantileCont => {
let args = extract_args(function)?;
match args.len() {
2 => self.try_visit_binary(|e, q| {
let value = match q {
Expr::Literal(LiteralValue::Float(f)) => {
if (0.0..=1.0).contains(&f) {
Expr::from(f)
} else {
polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
}
},
Expr::Literal(LiteralValue::Int(n)) => {
if (0..=1).contains(&n) {
Expr::from(n as f64)
} else {
polars_bail!(SQLSyntax: "QUANTILE_CONT value must be between 0 and 1 ({})", args[1])
}
},
_ => polars_bail!(SQLSyntax: "invalid value for QUANTILE_CONT ({})", args[1])
};
Ok(e.quantile(value, QuantileInterpolOptions::Linear))
}),
_ => polars_bail!(SQLSyntax: "QUANTILE_CONT expects 2 arguments (found {})", args.len()),
}
},
Min => self.visit_unary_with_opt_cumulative(Expr::min, Expr::cum_min),
StdDev => self.visit_unary(|e| e.std(1)),
Sum => self.visit_unary_with_opt_cumulative(Expr::sum, Expr::cum_sum),
Expand Down
65 changes: 65 additions & 0 deletions crates/polars-sql/tests/functions_aggregate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use polars_core::prelude::*;
use polars_lazy::prelude::*;
use polars_plan::dsl::Expr;
use polars_sql::*;

fn create_df() -> LazyFrame {
df! {
"Year" => [2018, 2018, 2019, 2019, 2020, 2020],
"Country" => ["US", "UK", "US", "UK", "US", "UK"],
"Sales" => [1000, 2000, 3000, 4000, 5000, 6000]
}
.unwrap()
.lazy()
}

fn create_expected(expr: Expr, sql: &str) -> (DataFrame, DataFrame) {
let df = create_df();
let alias = "TEST";

let query = format!(
r#"
SELECT
{sql} as {alias}
FROM
df
"#
);

let expected = df
.clone()
.select(&[expr.alias(alias)])
.sort([alias], Default::default())
.collect()
.unwrap();
let mut ctx = SQLContext::new();
ctx.register("df", df);

let actual = ctx.execute(&query).unwrap().collect().unwrap();
(expected, actual)
}

#[test]
fn test_median() {
let expr = col("Sales").median();

let sql_expr = "MEDIAN(Sales)";
let (expected, actual) = create_expected(expr, sql_expr);

assert!(expected.equals(&actual))
}

#[test]
fn test_quantile_cont() {
for &q in &[0.25, 0.5, 0.75] {
let expr = col("Sales").quantile(lit(q), QuantileInterpolOptions::Linear);

let sql_expr = format!("QUANTILE_CONT(Sales, {})", q);
let (expected, actual) = create_expected(expr, &sql_expr);

assert!(
expected.equals(&actual),
"q: {q}: expected {expected:?}, got {actual:?}"
)
}
}

0 comments on commit 9dada18

Please sign in to comment.