Skip to content

Commit

Permalink
[FEAT]: sql concat and stddev (#3153)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 authored Oct 31, 2024
1 parent f966e02 commit 301cd48
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 7 deletions.
54 changes: 49 additions & 5 deletions src/daft-sql/src/modules/aggs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub struct SQLModuleAggs;

impl SQLModule for SQLModuleAggs {
fn register(parent: &mut SQLFunctions) {
use AggExpr::{Count, Max, Mean, Min, Sum};
use AggExpr::{Count, Max, Mean, Min, Stddev, Sum};
// HACK TO USE AggExpr as an enum rather than a
let nil = Arc::new(Expr::Literal(LiteralValue::Null));
parent.add_fn(
Expand All @@ -27,7 +27,9 @@ impl SQLModule for SQLModuleAggs {
parent.add_fn("avg", Mean(nil.clone()));
parent.add_fn("mean", Mean(nil.clone()));
parent.add_fn("min", Min(nil.clone()));
parent.add_fn("max", Max(nil));
parent.add_fn("max", Max(nil.clone()));
parent.add_fn("stddev", Stddev(nil.clone()));
parent.add_fn("stddev_samp", Stddev(nil));
}
}

Expand All @@ -49,15 +51,19 @@ impl SQLFunction for AggExpr {
Self::Mean(_) => static_docs::AVG_DOCSTRING.replace("{}", alias),
Self::Min(_) => static_docs::MIN_DOCSTRING.to_string(),
Self::Max(_) => static_docs::MAX_DOCSTRING.to_string(),
Self::Stddev(_) => static_docs::STDDEV_DOCSTRING.to_string(),
e => unimplemented!("Need to implement docstrings for {e}"),
}
}

fn arg_names(&self) -> &'static [&'static str] {
match self {
Self::Count(_, _) | Self::Sum(_) | Self::Mean(_) | Self::Min(_) | Self::Max(_) => {
&["input"]
}
Self::Count(_, _)
| Self::Sum(_)
| Self::Mean(_)
| Self::Min(_)
| Self::Max(_)
| Self::Stddev(_) => &["input"],
e => unimplemented!("Need to implement arg names for {e}"),
}
}
Expand Down Expand Up @@ -324,4 +330,42 @@ Example:
│ 200 │
╰───────╯
(Showing first 1 of 1 rows)";

pub(crate) const STDDEV_DOCSTRING: &str =
"Calculates the standard deviation of non-null elements in the input expression.
Example:
.. code-block:: sql
:caption: SQL
SELECT stddev(x) FROM tbl
.. code-block:: text
:caption: Input
╭───────╮
│ x │
│ --- │
│ Int64 │
╞═══════╡
│ 100 │
├╌╌╌╌╌╌╌┤
│ 200 │
├╌╌╌╌╌╌╌┤
│ null │
╰───────╯
(Showing first 3 of 3 rows)
.. code-block:: text
:caption: Output
╭──────────────╮
│ x │
│ --- │
│ Float64 │
╞══════════════╡
│ 70.710678118 │
╰──────────────╯
(Showing first 1 of 1 rows)";
}
33 changes: 32 additions & 1 deletion src/daft-sql/src/modules/utf8.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use daft_core::array::ops::Utf8NormalizeOptions;
use daft_dsl::{
binary_op,
functions::{
self,
utf8::{normalize, Utf8Expr},
},
ExprRef, LiteralValue,
ExprRef, LiteralValue, Operator,
};
use daft_functions::{
count_matches::{utf8_count_matches, CountMatchesFunction},
Expand Down Expand Up @@ -62,6 +63,7 @@ impl SQLModule for SQLModuleUtf8 {
parent.add_fn("normalize", SQLNormalize);
parent.add_fn("tokenize_encode", SQLTokenizeEncode);
parent.add_fn("tokenize_decode", SQLTokenizeDecode);
parent.add_fn("concat", SQLConcat);
}
}

Expand Down Expand Up @@ -539,3 +541,32 @@ impl SQLFunction for SQLTokenizeDecode {
}
}
}

pub struct SQLConcat;

impl SQLFunction for SQLConcat {
fn to_expr(
&self,
inputs: &[sqlparser::ast::FunctionArg],
planner: &crate::planner::SQLPlanner,
) -> SQLPlannerResult<ExprRef> {
let inputs = inputs
.iter()
.map(|input| planner.plan_function_arg(input))
.collect::<SQLPlannerResult<Vec<_>>>()?;
let mut inputs = inputs.into_iter();

let Some(mut first) = inputs.next() else {
invalid_operation_err!("concat requires at least one argument")
};
for input in inputs {
first = binary_op(Operator::Plus, first, input);
}

Ok(first)
}

fn docstrings(&self, _: &str) -> String {
"Concatenate the inputs into a single string".to_string()
}
}
43 changes: 43 additions & 0 deletions tests/sql/test_aggs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import daft
from daft import col


def test_aggs_sql():
df = daft.from_pydict(
{
"id": [1, 2, 3, 3, 3, 3, 2, 1, 3, 1],
"values": [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5],
"floats": [0.01, 0.011, 0.01047, 0.02, 0.019, 0.018, 0.017, 0.016, 0.015, 0.014],
}
)
expected = (
df.agg(
[
col("values").sum().alias("sum"),
col("values").mean().alias("mean"),
col("values").min().alias("min"),
col("values").max().alias("max"),
col("values").count().alias("count"),
col("values").stddev().alias("std"),
]
)
.collect()
.to_pydict()
)

actual = (
daft.sql("""
SELECT
sum(values) as sum,
mean(values) as mean,
min(values) as min,
max(values) as max,
count(values) as count,
stddev(values) as std
FROM df
""")
.collect()
.to_pydict()
)

assert actual == expected
11 changes: 10 additions & 1 deletion tests/sql/test_utf8_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def test_utf8_exprs():
normalize(a, remove_punct:=true, lowercase:=true) as normalize_remove_punct_lower_a,
normalize(a, remove_punct:=true, lowercase:=true, white_space:=true) as normalize_remove_punct_lower_ws_a,
tokenize_encode(a, 'r50k_base') as tokenize_encode_a,
tokenize_decode(tokenize_encode(a, 'r50k_base'), 'r50k_base') as tokenize_decode_a
tokenize_decode(tokenize_encode(a, 'r50k_base'), 'r50k_base') as tokenize_decode_a,
concat(a, '---') as concat_a,
concat('--', a, a, a, '--') as concat_multi_a
FROM df
"""
actual = daft.sql(sql).collect()
Expand Down Expand Up @@ -105,6 +107,13 @@ def test_utf8_exprs():
.alias("normalize_remove_punct_lower_ws_a"),
col("a").str.tokenize_encode("r50k_base").alias("tokenize_encode_a"),
col("a").str.tokenize_encode("r50k_base").str.tokenize_decode("r50k_base").alias("tokenize_decode_a"),
col("a").str.concat("---").alias("concat_a"),
daft.lit("--")
.str.concat(col("a"))
.str.concat(col("a"))
.str.concat(col("a"))
.str.concat("--")
.alias("concat_multi_a"),
)
.collect()
.to_pydict()
Expand Down

0 comments on commit 301cd48

Please sign in to comment.