diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 3598a4042d..5fe1307747 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1095,34 +1095,6 @@ class PyExpr: def __repr__(self) -> str: ... def __hash__(self) -> int: ... def __reduce__(self) -> tuple: ... - def utf8_endswith(self, pattern: PyExpr) -> PyExpr: ... - def utf8_startswith(self, pattern: PyExpr) -> PyExpr: ... - def utf8_contains(self, pattern: PyExpr) -> PyExpr: ... - def utf8_match(self, pattern: PyExpr) -> PyExpr: ... - def utf8_split(self, pattern: PyExpr, regex: bool) -> PyExpr: ... - def utf8_extract(self, pattern: PyExpr, index: int) -> PyExpr: ... - def utf8_extract_all(self, pattern: PyExpr, index: int) -> PyExpr: ... - def utf8_replace(self, pattern: PyExpr, replacement: PyExpr, regex: bool) -> PyExpr: ... - def utf8_length(self) -> PyExpr: ... - def utf8_length_bytes(self) -> PyExpr: ... - def utf8_lower(self) -> PyExpr: ... - def utf8_upper(self) -> PyExpr: ... - def utf8_lstrip(self) -> PyExpr: ... - def utf8_rstrip(self) -> PyExpr: ... - def utf8_reverse(self) -> PyExpr: ... - def utf8_capitalize(self) -> PyExpr: ... - def utf8_left(self, nchars: PyExpr) -> PyExpr: ... - def utf8_right(self, nchars: PyExpr) -> PyExpr: ... - def utf8_find(self, substr: PyExpr) -> PyExpr: ... - def utf8_rpad(self, length: PyExpr, pad: PyExpr) -> PyExpr: ... - def utf8_lpad(self, length: PyExpr, pad: PyExpr) -> PyExpr: ... - def utf8_repeat(self, n: PyExpr) -> PyExpr: ... - def utf8_like(self, pattern: PyExpr) -> PyExpr: ... - def utf8_ilike(self, pattern: PyExpr) -> PyExpr: ... - def utf8_substr(self, start: PyExpr, length: PyExpr) -> PyExpr: ... - def utf8_to_date(self, format: str) -> PyExpr: ... - def utf8_to_datetime(self, format: str, timezone: str | None = None) -> PyExpr: ... - def utf8_normalize(self, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool) -> PyExpr: ... def struct_get(self, name: str) -> PyExpr: ... 
def map_get(self, key: PyExpr) -> PyExpr: ... def partitioning_days(self) -> PyExpr: ... @@ -1320,6 +1292,40 @@ def list_max(expr: PyExpr) -> PyExpr: ... def list_slice(expr: PyExpr, start: PyExpr, end: PyExpr | None = None) -> PyExpr: ... def list_chunk(expr: PyExpr, size: int) -> PyExpr: ... +# --- +# expr.utf8 namespace +# --- +def utf8_endswith(expr: PyExpr, pattern: PyExpr) -> PyExpr: ... +def utf8_startswith(expr: PyExpr, pattern: PyExpr) -> PyExpr: ... +def utf8_contains(expr: PyExpr, pattern: PyExpr) -> PyExpr: ... +def utf8_match(expr: PyExpr, pattern: PyExpr) -> PyExpr: ... +def utf8_split(expr: PyExpr, pattern: PyExpr, regex: bool) -> PyExpr: ... +def utf8_extract(expr: PyExpr, pattern: PyExpr, index: int) -> PyExpr: ... +def utf8_extract_all(expr: PyExpr, pattern: PyExpr, index: int) -> PyExpr: ... +def utf8_replace(expr: PyExpr, pattern: PyExpr, replacement: PyExpr, regex: bool) -> PyExpr: ... +def utf8_length(expr: PyExpr) -> PyExpr: ... +def utf8_length_bytes(expr: PyExpr) -> PyExpr: ... +def utf8_lower(expr: PyExpr) -> PyExpr: ... +def utf8_upper(expr: PyExpr) -> PyExpr: ... +def utf8_lstrip(expr: PyExpr) -> PyExpr: ... +def utf8_rstrip(expr: PyExpr) -> PyExpr: ... +def utf8_reverse(expr: PyExpr) -> PyExpr: ... +def utf8_capitalize(expr: PyExpr) -> PyExpr: ... +def utf8_left(expr: PyExpr, nchars: PyExpr) -> PyExpr: ... +def utf8_right(expr: PyExpr, nchars: PyExpr) -> PyExpr: ... +def utf8_find(expr: PyExpr, substr: PyExpr) -> PyExpr: ... +def utf8_rpad(expr: PyExpr, length: PyExpr, pad: PyExpr) -> PyExpr: ... +def utf8_lpad(expr: PyExpr, length: PyExpr, pad: PyExpr) -> PyExpr: ... +def utf8_repeat(expr: PyExpr, n: PyExpr) -> PyExpr: ... +def utf8_like(expr: PyExpr, pattern: PyExpr) -> PyExpr: ... +def utf8_ilike(expr: PyExpr, pattern: PyExpr) -> PyExpr: ... +def utf8_substr(expr: PyExpr, start: PyExpr, length: PyExpr) -> PyExpr: ... +def utf8_to_date(expr: PyExpr, format: str) -> PyExpr: ... 
+def utf8_to_datetime(expr: PyExpr, format: str, timezone: str | None = None) -> PyExpr: ... +def utf8_normalize( + expr: PyExpr, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool +) -> PyExpr: ... + class PyCatalog: @staticmethod def new() -> PyCatalog: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 44bbc302e8..5a4cb1c7c6 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1887,7 +1887,7 @@ def contains(self, substr: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value contains the provided pattern """ substr_expr = Expression._to_expression(substr) - return Expression._from_pyexpr(self._expr.utf8_contains(substr_expr._expr)) + return Expression._from_pyexpr(native.utf8_contains(self._expr, substr_expr._expr)) def match(self, pattern: str | Expression) -> Expression: """Checks whether each string matches the given regular expression pattern in a string column @@ -1917,7 +1917,7 @@ def match(self, pattern: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value matches the provided pattern """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(self._expr.utf8_match(pattern_expr._expr)) + return Expression._from_pyexpr(native.utf8_match(self._expr, pattern_expr._expr)) def endswith(self, suffix: str | Expression) -> Expression: """Checks whether each string ends with the given pattern in a string column @@ -1947,7 +1947,7 @@ def endswith(self, suffix: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value ends with the provided pattern """ suffix_expr = Expression._to_expression(suffix) - return Expression._from_pyexpr(self._expr.utf8_endswith(suffix_expr._expr)) + return Expression._from_pyexpr(native.utf8_endswith(self._expr, suffix_expr._expr)) def startswith(self, prefix: str | Expression) -> Expression: 
"""Checks whether each string starts with the given pattern in a string column @@ -1977,7 +1977,7 @@ def startswith(self, prefix: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value starts with the provided pattern """ prefix_expr = Expression._to_expression(prefix) - return Expression._from_pyexpr(self._expr.utf8_startswith(prefix_expr._expr)) + return Expression._from_pyexpr(native.utf8_startswith(self._expr, prefix_expr._expr)) def split(self, pattern: str | Expression, regex: bool = False) -> Expression: r"""Splits each string on the given literal or regex pattern, into a list of strings. @@ -2028,7 +2028,7 @@ def split(self, pattern: str | Expression, regex: bool = False) -> Expression: Expression: A List[Utf8] expression containing the string splits for each string in the column. """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(self._expr.utf8_split(pattern_expr._expr, regex)) + return Expression._from_pyexpr(native.utf8_split(self._expr, pattern_expr._expr, regex)) def concat(self, other: str | Expression) -> Expression: """Concatenates two string expressions together @@ -2119,7 +2119,7 @@ def extract(self, pattern: str | Expression, index: int = 0) -> Expression: `extract_all` """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(self._expr.utf8_extract(pattern_expr._expr, index)) + return Expression._from_pyexpr(native.utf8_extract(self._expr, pattern_expr._expr, index)) def extract_all(self, pattern: str | Expression, index: int = 0) -> Expression: r"""Extracts the specified match group from all regex matches in each string in a string column. 
@@ -2175,7 +2175,7 @@ def extract_all(self, pattern: str | Expression, index: int = 0) -> Expression: `extract` """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(self._expr.utf8_extract_all(pattern_expr._expr, index)) + return Expression._from_pyexpr(native.utf8_extract_all(self._expr, pattern_expr._expr, index)) def replace( self, @@ -2232,7 +2232,9 @@ def replace( """ pattern_expr = Expression._to_expression(pattern) replacement_expr = Expression._to_expression(replacement) - return Expression._from_pyexpr(self._expr.utf8_replace(pattern_expr._expr, replacement_expr._expr, regex)) + return Expression._from_pyexpr( + native.utf8_replace(self._expr, pattern_expr._expr, replacement_expr._expr, regex) + ) def length(self) -> Expression: """Retrieves the length for a UTF-8 string column @@ -2259,7 +2261,7 @@ def length(self) -> Expression: Returns: Expression: an UInt64 expression with the length of each string """ - return Expression._from_pyexpr(self._expr.utf8_length()) + return Expression._from_pyexpr(native.utf8_length(self._expr)) def length_bytes(self) -> Expression: """Retrieves the length for a UTF-8 string column in bytes. 
@@ -2286,7 +2288,7 @@ def length_bytes(self) -> Expression: Returns: Expression: an UInt64 expression with the length of each string """ - return Expression._from_pyexpr(self._expr.utf8_length_bytes()) + return Expression._from_pyexpr(native.utf8_length_bytes(self._expr)) def lower(self) -> Expression: """Convert UTF-8 string to all lowercase @@ -2313,7 +2315,7 @@ def lower(self) -> Expression: Returns: Expression: a String expression which is `self` lowercased """ - return Expression._from_pyexpr(self._expr.utf8_lower()) + return Expression._from_pyexpr(native.utf8_lower(self._expr)) def upper(self) -> Expression: """Convert UTF-8 string to all upper @@ -2340,7 +2342,7 @@ def upper(self) -> Expression: Returns: Expression: a String expression which is `self` uppercased """ - return Expression._from_pyexpr(self._expr.utf8_upper()) + return Expression._from_pyexpr(native.utf8_upper(self._expr)) def lstrip(self) -> Expression: """Strip whitespace from the left side of a UTF-8 string @@ -2367,7 +2369,7 @@ def lstrip(self) -> Expression: Returns: Expression: a String expression which is `self` with leading whitespace stripped """ - return Expression._from_pyexpr(self._expr.utf8_lstrip()) + return Expression._from_pyexpr(native.utf8_lstrip(self._expr)) def rstrip(self) -> Expression: """Strip whitespace from the right side of a UTF-8 string @@ -2394,7 +2396,7 @@ def rstrip(self) -> Expression: Returns: Expression: a String expression which is `self` with trailing whitespace stripped """ - return Expression._from_pyexpr(self._expr.utf8_rstrip()) + return Expression._from_pyexpr(native.utf8_rstrip(self._expr)) def reverse(self) -> Expression: """Reverse a UTF-8 string @@ -2421,7 +2423,7 @@ def reverse(self) -> Expression: Returns: Expression: a String expression which is `self` reversed """ - return Expression._from_pyexpr(self._expr.utf8_reverse()) + return Expression._from_pyexpr(native.utf8_reverse(self._expr)) def capitalize(self) -> Expression: """Capitalize a UTF-8 
string @@ -2448,7 +2450,7 @@ def capitalize(self) -> Expression: Returns: Expression: a String expression which is `self` uppercased with the first character and lowercased the rest """ - return Expression._from_pyexpr(self._expr.utf8_capitalize()) + return Expression._from_pyexpr(native.utf8_capitalize(self._expr)) def left(self, nchars: int | Expression) -> Expression: """Gets the n (from nchars) left-most characters of each string @@ -2476,7 +2478,7 @@ def left(self, nchars: int | Expression) -> Expression: Expression: a String expression which is the `n` left-most characters of `self` """ nchars_expr = Expression._to_expression(nchars) - return Expression._from_pyexpr(self._expr.utf8_left(nchars_expr._expr)) + return Expression._from_pyexpr(native.utf8_left(self._expr, nchars_expr._expr)) def right(self, nchars: int | Expression) -> Expression: """Gets the n (from nchars) right-most characters of each string @@ -2504,7 +2506,7 @@ def right(self, nchars: int | Expression) -> Expression: Expression: a String expression which is the `n` right-most characters of `self` """ nchars_expr = Expression._to_expression(nchars) - return Expression._from_pyexpr(self._expr.utf8_right(nchars_expr._expr)) + return Expression._from_pyexpr(native.utf8_right(self._expr, nchars_expr._expr)) def find(self, substr: str | Expression) -> Expression: """Returns the index of the first occurrence of the substring in each string @@ -2536,7 +2538,7 @@ def find(self, substr: str | Expression) -> Expression: Expression: an Int64 expression with the index of the first occurrence of the substring in each string """ substr_expr = Expression._to_expression(substr) - return Expression._from_pyexpr(self._expr.utf8_find(substr_expr._expr)) + return Expression._from_pyexpr(native.utf8_find(self._expr, substr_expr._expr)) def rpad(self, length: int | Expression, pad: str | Expression) -> Expression: """Right-pads each string by truncating or padding with the character @@ -2569,7 +2571,7 @@ def 
rpad(self, length: int | Expression, pad: str | Expression) -> Expression: """ length_expr = Expression._to_expression(length) pad_expr = Expression._to_expression(pad) - return Expression._from_pyexpr(self._expr.utf8_rpad(length_expr._expr, pad_expr._expr)) + return Expression._from_pyexpr(native.utf8_rpad(self._expr, length_expr._expr, pad_expr._expr)) def lpad(self, length: int | Expression, pad: str | Expression) -> Expression: """Left-pads each string by truncating on the right or padding with the character @@ -2602,7 +2604,7 @@ def lpad(self, length: int | Expression, pad: str | Expression) -> Expression: """ length_expr = Expression._to_expression(length) pad_expr = Expression._to_expression(pad) - return Expression._from_pyexpr(self._expr.utf8_lpad(length_expr._expr, pad_expr._expr)) + return Expression._from_pyexpr(native.utf8_lpad(self._expr, length_expr._expr, pad_expr._expr)) def repeat(self, n: int | Expression) -> Expression: """Repeats each string n times @@ -2630,7 +2632,7 @@ def repeat(self, n: int | Expression) -> Expression: Expression: a String expression which is `self` repeated `n` times """ n_expr = Expression._to_expression(n) - return Expression._from_pyexpr(self._expr.utf8_repeat(n_expr._expr)) + return Expression._from_pyexpr(native.utf8_repeat(self._expr, n_expr._expr)) def like(self, pattern: str | Expression) -> Expression: """Checks whether each string matches the given SQL LIKE pattern, case sensitive @@ -2661,7 +2663,7 @@ def like(self, pattern: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value matches the provided pattern """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(self._expr.utf8_like(pattern_expr._expr)) + return Expression._from_pyexpr(native.utf8_like(self._expr, pattern_expr._expr)) def ilike(self, pattern: str | Expression) -> Expression: """Checks whether each string matches the given SQL LIKE pattern, case insensitive @@ -2692,7 +2694,7 
@@ def ilike(self, pattern: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value matches the provided pattern """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(self._expr.utf8_ilike(pattern_expr._expr)) + return Expression._from_pyexpr(native.utf8_ilike(self._expr, pattern_expr._expr)) def substr(self, start: int | Expression, length: int | Expression | None = None) -> Expression: """Extract a substring from a string, starting at a specified index and extending for a given length. @@ -2724,7 +2726,7 @@ def substr(self, start: int | Expression, length: int | Expression | None = None """ start_expr = Expression._to_expression(start) length_expr = Expression._to_expression(length) - return Expression._from_pyexpr(self._expr.utf8_substr(start_expr._expr, length_expr._expr)) + return Expression._from_pyexpr(native.utf8_substr(self._expr, start_expr._expr, length_expr._expr)) def to_date(self, format: str) -> Expression: """Converts a string to a date using the specified format @@ -2755,7 +2757,7 @@ def to_date(self, format: str) -> Expression: Returns: Expression: a Date expression which is parsed by given format """ - return Expression._from_pyexpr(self._expr.utf8_to_date(format)) + return Expression._from_pyexpr(native.utf8_to_date(self._expr, format)) def to_datetime(self, format: str, timezone: str | None = None) -> Expression: """Converts a string to a datetime using the specified format and timezone @@ -2805,7 +2807,7 @@ def to_datetime(self, format: str, timezone: str | None = None) -> Expression: Returns: Expression: a DateTime expression which is parsed by given format and timezone """ - return Expression._from_pyexpr(self._expr.utf8_to_datetime(format, timezone)) + return Expression._from_pyexpr(native.utf8_to_datetime(self._expr, format, timezone)) def normalize( self, @@ -2849,7 +2851,9 @@ def normalize( Returns: Expression: a String expression which is normalized. 
""" - return Expression._from_pyexpr(self._expr.utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space)) + return Expression._from_pyexpr( + native.utf8_normalize(self._expr, remove_punct, lowercase, nfd_unicode, white_space) + ) def tokenize_encode( self, diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs index 48a7751197..33962c0cc9 100644 --- a/src/daft-dsl/src/functions/mod.rs +++ b/src/daft-dsl/src/functions/mod.rs @@ -5,7 +5,6 @@ pub mod python; pub mod scalar; pub mod sketch; pub mod struct_; -pub mod utf8; use std::{ fmt::{Display, Formatter, Result, Write}, @@ -18,15 +17,11 @@ use python::PythonUDF; pub use scalar::*; use serde::{Deserialize, Serialize}; -use self::{ - map::MapExpr, partitioning::PartitioningExpr, sketch::SketchExpr, struct_::StructExpr, - utf8::Utf8Expr, -}; +use self::{map::MapExpr, partitioning::PartitioningExpr, sketch::SketchExpr, struct_::StructExpr}; use crate::{Expr, ExprRef, Operator}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum FunctionExpr { - Utf8(Utf8Expr), Map(MapExpr), Sketch(SketchExpr), Struct(StructExpr), @@ -49,7 +44,6 @@ impl FunctionExpr { #[inline] fn get_evaluator(&self) -> &dyn FunctionEvaluator { match self { - Self::Utf8(expr) => expr.get_evaluator(), Self::Map(expr) => expr.get_evaluator(), Self::Sketch(expr) => expr.get_evaluator(), Self::Struct(expr) => expr.get_evaluator(), diff --git a/src/daft-dsl/src/functions/utf8/capitalize.rs b/src/daft-dsl/src/functions/utf8/capitalize.rs deleted file mode 100644 index caa3c25359..0000000000 --- a/src/daft-dsl/src/functions/utf8/capitalize.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct CapitalizeEvaluator {} - -impl FunctionEvaluator for CapitalizeEvaluator { - fn fn_name(&self) -> &'static str { - "capitalize" - } - - fn 
to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to capitalize to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_capitalize(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/extract.rs b/src/daft-dsl/src/functions/utf8/extract.rs deleted file mode 100644 index abe9d4df16..0000000000 --- a/src/daft-dsl/src/functions/utf8/extract.rs +++ /dev/null @@ -1,51 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ExtractEvaluator {} - -impl FunctionEvaluator for ExtractEvaluator { - fn fn_name(&self) -> &'static str { - "extract" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { - (Ok(data_field), Ok(pattern_field)) => { - match (&data_field.dtype, &pattern_field.dtype) { - (DataType::Utf8, DataType::Utf8) => { - Ok(Field::new(data_field.name, DataType::Utf8)) - } - _ => Err(DaftError::TypeError(format!( - "Expects inputs to extract to be utf8, but received {data_field} and {pattern_field}", - ))), - } - } - (Err(e), _) | (_, Err(e)) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } - - fn 
evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - match inputs { - [data, pattern] => { - let index = match expr { - FunctionExpr::Utf8(Utf8Expr::Extract(index)) => index, - _ => panic!("Expected Utf8 Extract Expr, got {expr}"), - }; - data.utf8_extract(pattern, *index) - } - _ => Err(DaftError::ValueError(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/extract_all.rs b/src/daft-dsl/src/functions/utf8/extract_all.rs deleted file mode 100644 index e2395e8c19..0000000000 --- a/src/daft-dsl/src/functions/utf8/extract_all.rs +++ /dev/null @@ -1,51 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ExtractAllEvaluator {} - -impl FunctionEvaluator for ExtractAllEvaluator { - fn fn_name(&self) -> &'static str { - "extractall" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { - (Ok(data_field), Ok(pattern_field)) => { - match (&data_field.dtype, &pattern_field.dtype) { - (DataType::Utf8, DataType::Utf8) => { - Ok(Field::new(data_field.name, DataType::List(Box::new(DataType::Utf8)))) - } - _ => Err(DaftError::TypeError(format!( - "Expects inputs to extractAll to be utf8, but received {data_field} and {pattern_field}", - ))), - } - } - (Err(e), _) | (_, Err(e)) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - match inputs { - [data, pattern] => { - let index = match expr { - FunctionExpr::Utf8(Utf8Expr::ExtractAll(index)) => index, - _ => panic!("Expected Utf8 ExtractAll Expr, got {expr}"), - }; - data.utf8_extract_all(pattern, 
*index) - } - _ => Err(DaftError::ValueError(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/length.rs b/src/daft-dsl/src/functions/utf8/length.rs deleted file mode 100644 index 9f4729ac76..0000000000 --- a/src/daft-dsl/src/functions/utf8/length.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct LengthEvaluator {} - -impl FunctionEvaluator for LengthEvaluator { - fn fn_name(&self) -> &'static str { - "length" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::UInt64)), - _ => Err(DaftError::TypeError(format!( - "Expects input to length to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_length(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/length_bytes.rs b/src/daft-dsl/src/functions/utf8/length_bytes.rs deleted file mode 100644 index cdf0af383a..0000000000 --- a/src/daft-dsl/src/functions/utf8/length_bytes.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct LengthBytesEvaluator {} - -impl FunctionEvaluator for LengthBytesEvaluator { - fn fn_name(&self) -> &'static str { - "length_bytes" - } - - fn 
to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::UInt64)), - _ => Err(DaftError::TypeError(format!( - "Expects input to length_bytes to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_length_bytes(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/lower.rs b/src/daft-dsl/src/functions/utf8/lower.rs deleted file mode 100644 index f3fd7a8c47..0000000000 --- a/src/daft-dsl/src/functions/utf8/lower.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct LowerEvaluator {} - -impl FunctionEvaluator for LowerEvaluator { - fn fn_name(&self) -> &'static str { - "lower" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to lower to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_lower(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 
input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/lstrip.rs b/src/daft-dsl/src/functions/utf8/lstrip.rs deleted file mode 100644 index 534aa1cd37..0000000000 --- a/src/daft-dsl/src/functions/utf8/lstrip.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct LstripEvaluator {} - -impl FunctionEvaluator for LstripEvaluator { - fn fn_name(&self) -> &'static str { - "lstrip" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to lstrip to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_lstrip(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/mod.rs b/src/daft-dsl/src/functions/utf8/mod.rs deleted file mode 100644 index 7a795250ff..0000000000 --- a/src/daft-dsl/src/functions/utf8/mod.rs +++ /dev/null @@ -1,356 +0,0 @@ -mod capitalize; -mod contains; -mod endswith; -mod extract; -mod extract_all; -mod find; -mod ilike; -mod left; -mod length; -mod length_bytes; -mod like; -mod lower; -mod lpad; -mod lstrip; -mod match_; -mod normalize; -mod repeat; -mod replace; -mod reverse; -mod right; -mod rpad; -mod rstrip; -mod split; -mod startswith; -mod substr; -mod to_date; -mod to_datetime; -mod upper; - -use capitalize::CapitalizeEvaluator; -use 
contains::ContainsEvaluator; -use daft_core::array::ops::Utf8NormalizeOptions; -use endswith::EndswithEvaluator; -use extract::ExtractEvaluator; -use extract_all::ExtractAllEvaluator; -use find::FindEvaluator; -use ilike::IlikeEvaluator; -use left::LeftEvaluator; -use length::LengthEvaluator; -use length_bytes::LengthBytesEvaluator; -use like::LikeEvaluator; -use lower::LowerEvaluator; -use lpad::LpadEvaluator; -use lstrip::LstripEvaluator; -use normalize::NormalizeEvaluator; -use repeat::RepeatEvaluator; -use replace::ReplaceEvaluator; -use reverse::ReverseEvaluator; -use right::RightEvaluator; -use rpad::RpadEvaluator; -use rstrip::RstripEvaluator; -use serde::{Deserialize, Serialize}; -use split::SplitEvaluator; -use startswith::StartswithEvaluator; -use substr::SubstrEvaluator; -use to_date::ToDateEvaluator; -use to_datetime::ToDatetimeEvaluator; -use upper::UpperEvaluator; - -use super::FunctionEvaluator; -use crate::{functions::utf8::match_::MatchEvaluator, Expr, ExprRef}; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub enum Utf8Expr { - EndsWith, - StartsWith, - Contains, - Split(bool), - Match, - Extract(usize), - ExtractAll(usize), - Replace(bool), - Length, - LengthBytes, - Lower, - Upper, - Lstrip, - Rstrip, - Reverse, - Capitalize, - Left, - Right, - Find, - Rpad, - Lpad, - Repeat, - Like, - Ilike, - Substr, - ToDate(String), - ToDatetime(String, Option), - Normalize(Utf8NormalizeOptions), -} - -impl Utf8Expr { - #[inline] - pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - match self { - Self::EndsWith => &EndswithEvaluator {}, - Self::StartsWith => &StartswithEvaluator {}, - Self::Contains => &ContainsEvaluator {}, - Self::Split(_) => &SplitEvaluator {}, - Self::Match => &MatchEvaluator {}, - Self::Extract(_) => &ExtractEvaluator {}, - Self::ExtractAll(_) => &ExtractAllEvaluator {}, - Self::Replace(_) => &ReplaceEvaluator {}, - Self::Length => &LengthEvaluator {}, - Self::LengthBytes => &LengthBytesEvaluator {}, - 
Self::Lower => &LowerEvaluator {}, - Self::Upper => &UpperEvaluator {}, - Self::Lstrip => &LstripEvaluator {}, - Self::Rstrip => &RstripEvaluator {}, - Self::Reverse => &ReverseEvaluator {}, - Self::Capitalize => &CapitalizeEvaluator {}, - Self::Left => &LeftEvaluator {}, - Self::Right => &RightEvaluator {}, - Self::Find => &FindEvaluator {}, - Self::Rpad => &RpadEvaluator {}, - Self::Lpad => &LpadEvaluator {}, - Self::Repeat => &RepeatEvaluator {}, - Self::Like => &LikeEvaluator {}, - Self::Ilike => &IlikeEvaluator {}, - Self::Substr => &SubstrEvaluator {}, - Self::ToDate(_) => &ToDateEvaluator {}, - Self::ToDatetime(_, _) => &ToDatetimeEvaluator {}, - Self::Normalize(_) => &NormalizeEvaluator {}, - } - } -} - -pub fn endswith(data: ExprRef, pattern: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::EndsWith), - inputs: vec![data, pattern], - } - .into() -} - -pub fn startswith(data: ExprRef, pattern: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::StartsWith), - inputs: vec![data, pattern], - } - .into() -} - -pub fn contains(data: ExprRef, pattern: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Contains), - inputs: vec![data, pattern], - } - .into() -} - -pub fn match_(data: ExprRef, pattern: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Match), - inputs: vec![data, pattern], - } - .into() -} - -pub fn split(data: ExprRef, pattern: ExprRef, regex: bool) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Split(regex)), - inputs: vec![data, pattern], - } - .into() -} - -pub fn extract(data: ExprRef, pattern: ExprRef, index: usize) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Extract(index)), - inputs: vec![data, pattern], - } - .into() -} - -pub fn extract_all(data: ExprRef, pattern: ExprRef, index: usize) -> ExprRef { - Expr::Function { - func: 
super::FunctionExpr::Utf8(Utf8Expr::ExtractAll(index)), - inputs: vec![data, pattern], - } - .into() -} - -pub fn replace(data: ExprRef, pattern: ExprRef, replacement: ExprRef, regex: bool) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Replace(regex)), - inputs: vec![data, pattern, replacement], - } - .into() -} - -pub fn length(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Length), - inputs: vec![data], - } - .into() -} - -pub fn length_bytes(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::LengthBytes), - inputs: vec![data], - } - .into() -} - -pub fn lower(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Lower), - inputs: vec![data], - } - .into() -} - -pub fn upper(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Upper), - inputs: vec![data], - } - .into() -} - -pub fn lstrip(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Lstrip), - inputs: vec![data], - } - .into() -} - -pub fn rstrip(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Rstrip), - inputs: vec![data], - } - .into() -} - -pub fn reverse(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Reverse), - inputs: vec![data], - } - .into() -} - -pub fn capitalize(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Capitalize), - inputs: vec![data], - } - .into() -} - -pub fn left(data: ExprRef, count: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Left), - inputs: vec![data, count], - } - .into() -} - -pub fn right(data: ExprRef, count: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Right), - inputs: vec![data, count], - } - .into() -} - -pub fn find(data: ExprRef, pattern: 
ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Find), - inputs: vec![data, pattern], - } - .into() -} - -pub fn rpad(data: ExprRef, length: ExprRef, pad: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Rpad), - inputs: vec![data, length, pad], - } - .into() -} - -pub fn lpad(data: ExprRef, length: ExprRef, pad: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Lpad), - inputs: vec![data, length, pad], - } - .into() -} - -pub fn repeat(data: ExprRef, count: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Repeat), - inputs: vec![data, count], - } - .into() -} - -pub fn like(data: ExprRef, pattern: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Like), - inputs: vec![data, pattern], - } - .into() -} - -pub fn ilike(data: ExprRef, pattern: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Ilike), - inputs: vec![data, pattern], - } - .into() -} - -pub fn substr(data: ExprRef, start: ExprRef, length: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Substr), - inputs: vec![data, start, length], - } - .into() -} - -pub fn to_date(data: ExprRef, format: &str) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::ToDate(format.to_string())), - inputs: vec![data], - } - .into() -} - -pub fn to_datetime(data: ExprRef, format: &str, timezone: Option<&str>) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::ToDatetime( - format.to_string(), - timezone.map(|s| s.to_string()), - )), - inputs: vec![data], - } - .into() -} - -pub fn normalize(data: ExprRef, opts: Utf8NormalizeOptions) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Normalize(opts)), - inputs: vec![data], - } - .into() -} diff --git a/src/daft-dsl/src/functions/utf8/normalize.rs 
b/src/daft-dsl/src/functions/utf8/normalize.rs deleted file mode 100644 index b693e2c017..0000000000 --- a/src/daft-dsl/src/functions/utf8/normalize.rs +++ /dev/null @@ -1,47 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct NormalizeEvaluator {} - -impl FunctionEvaluator for NormalizeEvaluator { - fn fn_name(&self) -> &'static str { - "normalize" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to normalize to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - match inputs { - [data] => { - let opts = match expr { - FunctionExpr::Utf8(Utf8Expr::Normalize(opts)) => opts, - _ => panic!("Expected Utf8 Normalize Expr, got {expr}"), - }; - data.utf8_normalize(*opts) - } - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/reverse.rs b/src/daft-dsl/src/functions/utf8/reverse.rs deleted file mode 100644 index cff9363a82..0000000000 --- a/src/daft-dsl/src/functions/utf8/reverse.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ReverseEvaluator {} - -impl FunctionEvaluator for ReverseEvaluator { - fn fn_name(&self) -> &'static str { - "reverse" - } - - fn to_field(&self, 
inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to reverse to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_reverse(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/right.rs b/src/daft-dsl/src/functions/utf8/right.rs deleted file mode 100644 index 892c0f7341..0000000000 --- a/src/daft-dsl/src/functions/utf8/right.rs +++ /dev/null @@ -1,45 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct RightEvaluator {} - -impl FunctionEvaluator for RightEvaluator { - fn fn_name(&self) -> &'static str { - "right" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data, nchars] => match (data.to_field(schema), nchars.to_field(schema)) { - (Ok(data_field), Ok(nchars_field)) => { - match (&data_field.dtype, &nchars_field.dtype) { - (DataType::Utf8, dt) if dt.is_integer() => { - Ok(Field::new(data_field.name, DataType::Utf8)) - } - _ => Err(DaftError::TypeError(format!( - "Expects inputs to left to be utf8 and integer, but received {data_field} and {nchars_field}", - ))), - } - } - (Err(e), _) | (_, Err(e)) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: 
&FunctionExpr) -> DaftResult { - match inputs { - [data, nchars] => data.utf8_right(nchars), - _ => Err(DaftError::ValueError(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/rstrip.rs b/src/daft-dsl/src/functions/utf8/rstrip.rs deleted file mode 100644 index c138d4c86c..0000000000 --- a/src/daft-dsl/src/functions/utf8/rstrip.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct RstripEvaluator {} - -impl FunctionEvaluator for RstripEvaluator { - fn fn_name(&self) -> &'static str { - "rstrip" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to rstrip to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_rstrip(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/to_date.rs b/src/daft-dsl/src/functions/utf8/to_date.rs deleted file mode 100644 index 58adecbc05..0000000000 --- a/src/daft-dsl/src/functions/utf8/to_date.rs +++ /dev/null @@ -1,47 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ToDateEvaluator {} - -impl FunctionEvaluator for ToDateEvaluator { - fn 
fn_name(&self) -> &'static str { - "to_date" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Date)), - _ => Err(DaftError::TypeError(format!( - "Expects inputs to to_date to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - match inputs { - [data] => { - let format = match expr { - FunctionExpr::Utf8(Utf8Expr::ToDate(format)) => format, - _ => panic!("Expected Utf8 ToDate Expr, got {expr}"), - }; - data.utf8_to_date(format) - } - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/to_datetime.rs b/src/daft-dsl/src/functions/utf8/to_datetime.rs deleted file mode 100644 index 25368c8e64..0000000000 --- a/src/daft-dsl/src/functions/utf8/to_datetime.rs +++ /dev/null @@ -1,66 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::{datatypes::infer_timeunit_from_format_string, prelude::*}; - -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ToDatetimeEvaluator {} - -impl FunctionEvaluator for ToDatetimeEvaluator { - fn fn_name(&self) -> &'static str { - "to_datetime" - } - - fn to_field( - &self, - inputs: &[ExprRef], - schema: &Schema, - expr: &FunctionExpr, - ) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => { - let (format, timezone) = match expr { - FunctionExpr::Utf8(Utf8Expr::ToDatetime(format, timezone)) => { - (format, timezone) - } - _ => panic!("Expected Utf8 
ToDatetime Expr, got {expr}"), - }; - let timeunit = infer_timeunit_from_format_string(format); - Ok(Field::new( - data_field.name, - DataType::Timestamp(timeunit, timezone.clone()), - )) - } - _ => Err(DaftError::TypeError(format!( - "Expects inputs to to_datetime to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - match inputs { - [data] => { - let (format, timezone) = match expr { - FunctionExpr::Utf8(Utf8Expr::ToDatetime(format, timezone)) => { - (format, timezone) - } - _ => panic!("Expected Utf8 ToDatetime Expr, got {expr}"), - }; - data.utf8_to_datetime(format, timezone.as_deref()) - } - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/upper.rs b/src/daft-dsl/src/functions/utf8/upper.rs deleted file mode 100644 index a02438b495..0000000000 --- a/src/daft-dsl/src/functions/utf8/upper.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct UpperEvaluator {} - -impl FunctionEvaluator for UpperEvaluator { - fn fn_name(&self) -> &'static str { - "upper" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to upper to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn 
evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_upper(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 0ba4ad8b92..01d6866b94 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -9,7 +9,6 @@ use common_error::DaftError; use common_py_serde::impl_bincode_py_state_serialization; use common_resource_request::ResourceRequest; use daft_core::{ - array::ops::Utf8NormalizeOptions, datatypes::{IntervalValue, IntervalValueBuilder}, prelude::*, python::{PyDataType, PyField, PySchema, PySeries, PyTimeUnit}, @@ -487,165 +486,6 @@ impl PyExpr { hasher.finish() } - pub fn utf8_endswith(&self, pattern: &Self) -> PyResult { - use crate::functions::utf8::endswith; - Ok(endswith(self.into(), pattern.expr.clone()).into()) - } - - pub fn utf8_startswith(&self, pattern: &Self) -> PyResult { - use crate::functions::utf8::startswith; - Ok(startswith(self.into(), pattern.expr.clone()).into()) - } - - pub fn utf8_contains(&self, pattern: &Self) -> PyResult { - use crate::functions::utf8::contains; - Ok(contains(self.into(), pattern.expr.clone()).into()) - } - - pub fn utf8_match(&self, pattern: &Self) -> PyResult { - use crate::functions::utf8::match_; - Ok(match_(self.into(), pattern.expr.clone()).into()) - } - - pub fn utf8_split(&self, pattern: &Self, regex: bool) -> PyResult { - use crate::functions::utf8::split; - Ok(split(self.into(), pattern.expr.clone(), regex).into()) - } - - pub fn utf8_extract(&self, pattern: &Self, index: usize) -> PyResult { - use crate::functions::utf8::extract; - Ok(extract(self.into(), pattern.expr.clone(), index).into()) - } - - pub fn utf8_extract_all(&self, pattern: &Self, index: usize) -> PyResult { - use crate::functions::utf8::extract_all; - Ok(extract_all(self.into(), pattern.expr.clone(), index).into()) - } - - pub fn 
utf8_replace(&self, pattern: &Self, replacement: &Self, regex: bool) -> PyResult { - use crate::functions::utf8::replace; - Ok(replace( - self.into(), - pattern.expr.clone(), - replacement.expr.clone(), - regex, - ) - .into()) - } - - pub fn utf8_length(&self) -> PyResult { - use crate::functions::utf8::length; - Ok(length(self.into()).into()) - } - - pub fn utf8_length_bytes(&self) -> PyResult { - use crate::functions::utf8::length_bytes; - Ok(length_bytes(self.into()).into()) - } - - pub fn utf8_lower(&self) -> PyResult { - use crate::functions::utf8::lower; - Ok(lower(self.into()).into()) - } - - pub fn utf8_upper(&self) -> PyResult { - use crate::functions::utf8::upper; - Ok(upper(self.into()).into()) - } - - pub fn utf8_lstrip(&self) -> PyResult { - use crate::functions::utf8::lstrip; - Ok(lstrip(self.into()).into()) - } - - pub fn utf8_rstrip(&self) -> PyResult { - use crate::functions::utf8::rstrip; - Ok(rstrip(self.into()).into()) - } - - pub fn utf8_reverse(&self) -> PyResult { - use crate::functions::utf8::reverse; - Ok(reverse(self.into()).into()) - } - - pub fn utf8_capitalize(&self) -> PyResult { - use crate::functions::utf8::capitalize; - Ok(capitalize(self.into()).into()) - } - - pub fn utf8_left(&self, count: &Self) -> PyResult { - use crate::functions::utf8::left; - Ok(left(self.into(), count.into()).into()) - } - - pub fn utf8_right(&self, count: &Self) -> PyResult { - use crate::functions::utf8::right; - Ok(right(self.into(), count.into()).into()) - } - - pub fn utf8_find(&self, substr: &Self) -> PyResult { - use crate::functions::utf8::find; - Ok(find(self.into(), substr.into()).into()) - } - - pub fn utf8_rpad(&self, length: &Self, pad: &Self) -> PyResult { - use crate::functions::utf8::rpad; - Ok(rpad(self.into(), length.into(), pad.into()).into()) - } - - pub fn utf8_lpad(&self, length: &Self, pad: &Self) -> PyResult { - use crate::functions::utf8::lpad; - Ok(lpad(self.into(), length.into(), pad.into()).into()) - } - - pub fn 
utf8_repeat(&self, n: &Self) -> PyResult { - use crate::functions::utf8::repeat; - Ok(repeat(self.into(), n.into()).into()) - } - - pub fn utf8_like(&self, pattern: &Self) -> PyResult { - use crate::functions::utf8::like; - Ok(like(self.into(), pattern.into()).into()) - } - - pub fn utf8_ilike(&self, pattern: &Self) -> PyResult { - use crate::functions::utf8::ilike; - Ok(ilike(self.into(), pattern.into()).into()) - } - - pub fn utf8_substr(&self, start: &Self, length: &Self) -> PyResult { - use crate::functions::utf8::substr; - Ok(substr(self.into(), start.into(), length.into()).into()) - } - - pub fn utf8_to_date(&self, format: &str) -> PyResult { - use crate::functions::utf8::to_date; - Ok(to_date(self.into(), format).into()) - } - - pub fn utf8_to_datetime(&self, format: &str, timezone: Option<&str>) -> PyResult { - use crate::functions::utf8::to_datetime; - Ok(to_datetime(self.into(), format, timezone).into()) - } - - pub fn utf8_normalize( - &self, - remove_punct: bool, - lowercase: bool, - nfd_unicode: bool, - white_space: bool, - ) -> PyResult { - use crate::functions::utf8::normalize; - let opts = Utf8NormalizeOptions { - remove_punct, - lowercase, - nfd_unicode, - white_space, - }; - - Ok(normalize(self.into(), opts).into()) - } - pub fn struct_get(&self, name: &str) -> PyResult { use crate::functions::struct_::get; Ok(get(self.into(), name).into()) diff --git a/src/daft-functions/src/lib.rs b/src/daft-functions/src/lib.rs index a8b1c8d0cd..e73a73cda0 100644 --- a/src/daft-functions/src/lib.rs +++ b/src/daft-functions/src/lib.rs @@ -11,6 +11,7 @@ pub mod temporal; pub mod to_struct; pub mod tokenize; pub mod uri; +pub mod utf8; use common_error::DaftError; #[cfg(feature = "python")] @@ -50,6 +51,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { float::register_modules(parent)?; temporal::register_modules(parent)?; list::register_modules(parent)?; + utf8::register_modules(parent)?; Ok(()) } diff --git a/src/daft-functions/src/utf8/capitalize.rs 
b/src/daft-functions/src/utf8/capitalize.rs new file mode 100644 index 0000000000..abf770924c --- /dev/null +++ b/src/daft-functions/src/utf8/capitalize.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Capitalize {} + +#[typetag::serde] +impl ScalarUDF for Utf8Capitalize { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "capitalize" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to capitalize to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_capitalize(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_capitalize(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Capitalize {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_capitalize")] +pub fn py_utf8_capitalize(expr: PyExpr) -> PyResult { + Ok(utf8_capitalize(expr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/contains.rs b/src/daft-functions/src/utf8/contains.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/contains.rs rename 
to src/daft-functions/src/utf8/contains.rs index 8c63b17be3..2ea8708711 100644 --- a/src/daft-dsl/src/functions/utf8/contains.rs +++ b/src/daft-functions/src/utf8/contains.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Contains {} -pub(super) struct ContainsEvaluator {} - -impl FunctionEvaluator for ContainsEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Contains { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "contains" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { (Ok(data_field), Ok(pattern_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for ContainsEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, pattern] => data.utf8_contains(pattern), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for ContainsEvaluator { } } } + +#[must_use] +pub fn utf8_contains(input: ExprRef, pattern: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Contains {}, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_contains")] +pub fn py_utf8_contains(expr: PyExpr, pattern: PyExpr) -> PyResult { + 
Ok(utf8_contains(expr.into(), pattern.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/endswith.rs b/src/daft-functions/src/utf8/endswith.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/endswith.rs rename to src/daft-functions/src/utf8/endswith.rs index 5785f92257..8f11cb8db8 100644 --- a/src/daft-dsl/src/functions/utf8/endswith.rs +++ b/src/daft-functions/src/utf8/endswith.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Endswith {} -pub(super) struct EndswithEvaluator {} - -impl FunctionEvaluator for EndswithEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Endswith { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "endswith" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { (Ok(data_field), Ok(pattern_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for EndswithEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, pattern] => data.utf8_endswith(pattern), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for EndswithEvaluator { } } } + +#[must_use] +pub fn utf8_endswith(input: ExprRef, pattern: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Endswith {}, vec![input, pattern]).into() +} + 
+#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_endswith")] +pub fn py_utf8_endswith(expr: PyExpr, pattern: PyExpr) -> PyResult { + Ok(utf8_endswith(expr.into(), pattern.into()).into()) +} diff --git a/src/daft-functions/src/utf8/extract.rs b/src/daft-functions/src/utf8/extract.rs new file mode 100644 index 0000000000..f5a97b1c3e --- /dev/null +++ b/src/daft-functions/src/utf8/extract.rs @@ -0,0 +1,74 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Extract { + pub index: usize, +} + +#[typetag::serde] +impl ScalarUDF for Utf8Extract { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "extract" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { + (Ok(data_field), Ok(pattern_field)) => { + match (&data_field.dtype, &pattern_field.dtype) { + (DataType::Utf8, DataType::Utf8) => { + Ok(Field::new(data_field.name, DataType::Utf8)) + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to extract to be utf8, but received {data_field} and {pattern_field}", + ))), + } + } + (Err(e), _) | (_, Err(e)) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data, pattern] => data.utf8_extract(pattern, self.index), + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_extract(input: ExprRef, 
pattern: ExprRef, index: usize) -> ExprRef { + ScalarFunction::new(Utf8Extract { index }, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_extract")] +pub fn py_utf8_extract(expr: PyExpr, pattern: PyExpr, index: usize) -> PyResult { + Ok(utf8_extract(expr.into(), pattern.into(), index).into()) +} diff --git a/src/daft-functions/src/utf8/extract_all.rs b/src/daft-functions/src/utf8/extract_all.rs new file mode 100644 index 0000000000..b40c6ad47c --- /dev/null +++ b/src/daft-functions/src/utf8/extract_all.rs @@ -0,0 +1,74 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8ExtractAll { + pub index: usize, +} + +#[typetag::serde] +impl ScalarUDF for Utf8ExtractAll { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "extractall" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { + (Ok(data_field), Ok(pattern_field)) => { + match (&data_field.dtype, &pattern_field.dtype) { + (DataType::Utf8, DataType::Utf8) => { + Ok(Field::new(data_field.name, DataType::List(Box::new(DataType::Utf8)))) + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to extractAll to be utf8, but received {data_field} and {pattern_field}", + ))), + } + } + (Err(e), _) | (_, Err(e)) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data, pattern] => 
data.utf8_extract_all(pattern, self.index), + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_extract_all(input: ExprRef, pattern: ExprRef, index: usize) -> ExprRef { + ScalarFunction::new(Utf8ExtractAll { index }, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_extract_all")] +pub fn py_utf8_extract_all(expr: PyExpr, pattern: PyExpr, index: usize) -> PyResult { + Ok(utf8_extract_all(expr.into(), pattern.into(), index).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/find.rs b/src/daft-functions/src/utf8/find.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/find.rs rename to src/daft-functions/src/utf8/find.rs index d184d17c5f..3ec11bec97 100644 --- a/src/daft-dsl/src/functions/utf8/find.rs +++ b/src/daft-functions/src/utf8/find.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Find {} -pub(super) struct FindEvaluator {} - -impl FunctionEvaluator for FindEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Find { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "find" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, substr] => match (data.to_field(schema), substr.to_field(schema)) { 
(Ok(data_field), Ok(substr_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for FindEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, substr] => data.utf8_find(substr), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for FindEvaluator { } } } + +#[must_use] +pub fn utf8_find(input: ExprRef, substr: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Find {}, vec![input, substr]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_find")] +pub fn py_utf8_find(expr: PyExpr, substr: PyExpr) -> PyResult { + Ok(utf8_find(expr.into(), substr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/ilike.rs b/src/daft-functions/src/utf8/ilike.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/ilike.rs rename to src/daft-functions/src/utf8/ilike.rs index 35c0ce1e20..54c0d4ca8e 100644 --- a/src/daft-dsl/src/functions/utf8/ilike.rs +++ b/src/daft-functions/src/utf8/ilike.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Ilike {} -pub(super) struct IlikeEvaluator {} - -impl FunctionEvaluator for IlikeEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Ilike { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "ilike" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult 
{ + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { (Ok(data_field), Ok(pattern_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for IlikeEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, pattern] => data.utf8_ilike(pattern), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for IlikeEvaluator { } } } + +#[must_use] +pub fn utf8_ilike(input: ExprRef, pattern: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Ilike {}, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_ilike")] +pub fn py_utf8_ilike(expr: PyExpr, pattern: PyExpr) -> PyResult { + Ok(utf8_ilike(expr.into(), pattern.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/left.rs b/src/daft-functions/src/utf8/left.rs similarity index 54% rename from src/daft-dsl/src/functions/utf8/left.rs rename to src/daft-functions/src/utf8/left.rs index ffde503901..c055ad7ecb 100644 --- a/src/daft-dsl/src/functions/utf8/left.rs +++ b/src/daft-functions/src/utf8/left.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Left {} -pub(super) struct LeftEvaluator {} - -impl FunctionEvaluator for LeftEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Left { + fn 
as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "left" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, nchars] => match (data.to_field(schema), nchars.to_field(schema)) { (Ok(data_field), Ok(nchars_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for LeftEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, nchars] => data.utf8_left(nchars), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for LeftEvaluator { } } } + +#[must_use] +pub fn utf8_left(input: ExprRef, nchars: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Left {}, vec![input, nchars]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_left")] +pub fn py_utf8_left(expr: PyExpr, nchars: PyExpr) -> PyResult { + Ok(utf8_left(expr.into(), nchars.into()).into()) +} diff --git a/src/daft-functions/src/utf8/length.rs b/src/daft-functions/src/utf8/length.rs new file mode 100644 index 0000000000..8d58d8ae27 --- /dev/null +++ b/src/daft-functions/src/utf8/length.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Length {} + +#[typetag::serde] +impl ScalarUDF for Utf8Length { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "length" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => 
match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::UInt64)), + _ => Err(DaftError::TypeError(format!( + "Expects input to length to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_length(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_length(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Length {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_length")] +pub fn py_utf8_length(expr: PyExpr) -> PyResult { + Ok(utf8_length(expr.into()).into()) +} diff --git a/src/daft-functions/src/utf8/length_bytes.rs b/src/daft-functions/src/utf8/length_bytes.rs new file mode 100644 index 0000000000..dbcb841701 --- /dev/null +++ b/src/daft-functions/src/utf8/length_bytes.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8LengthBytes {} + +#[typetag::serde] +impl ScalarUDF for Utf8LengthBytes { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "length_bytes" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, 
DataType::UInt64)), + _ => Err(DaftError::TypeError(format!( + "Expects input to length_bytes to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_length_bytes(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_length_bytes(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8LengthBytes {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_length_bytes")] +pub fn py_utf8_length_bytes(expr: PyExpr) -> PyResult { + Ok(utf8_length_bytes(expr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/like.rs b/src/daft-functions/src/utf8/like.rs similarity index 54% rename from src/daft-dsl/src/functions/utf8/like.rs rename to src/daft-functions/src/utf8/like.rs index a2a2a96def..915a805a9b 100644 --- a/src/daft-dsl/src/functions/utf8/like.rs +++ b/src/daft-functions/src/utf8/like.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Like {} -pub(super) struct LikeEvaluator {} - -impl FunctionEvaluator for LikeEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Like { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "like" } - fn 
to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { (Ok(data_field), Ok(pattern_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for LikeEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, pattern] => data.utf8_like(pattern), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for LikeEvaluator { } } } + +#[must_use] +pub fn utf8_like(input: ExprRef, pattern: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Like {}, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_like")] +pub fn py_utf8_like(expr: PyExpr, pattern: PyExpr) -> PyResult { + Ok(utf8_like(expr.into(), pattern.into()).into()) +} diff --git a/src/daft-functions/src/utf8/lower.rs b/src/daft-functions/src/utf8/lower.rs new file mode 100644 index 0000000000..7935168d9b --- /dev/null +++ b/src/daft-functions/src/utf8/lower.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Lower {} + +#[typetag::serde] +impl ScalarUDF for Utf8Lower { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "lower" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => 
Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to lower to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_lower(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_lower(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Lower {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_lower")] +pub fn py_utf8_lower(expr: PyExpr) -> PyResult { + Ok(utf8_lower(expr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/lpad.rs b/src/daft-functions/src/utf8/lpad.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/lpad.rs rename to src/daft-functions/src/utf8/lpad.rs index 9880568aed..89808d645a 100644 --- a/src/daft-dsl/src/functions/utf8/lpad.rs +++ b/src/daft-functions/src/utf8/lpad.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Lpad {} -pub(super) struct LpadEvaluator {} - -impl FunctionEvaluator for LpadEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Lpad { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "lpad" } - fn to_field(&self, inputs: 
&[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, length, pad] => { let data = data.to_field(schema)?; @@ -35,7 +45,7 @@ impl FunctionEvaluator for LpadEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, length, pad] => data.utf8_lpad(length, pad), _ => Err(DaftError::ValueError(format!( @@ -45,3 +55,20 @@ impl FunctionEvaluator for LpadEvaluator { } } } + +#[must_use] +pub fn utf8_lpad(input: ExprRef, length: ExprRef, pad: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Lpad {}, vec![input, length, pad]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_lpad")] +pub fn py_utf8_lpad(expr: PyExpr, length: PyExpr, pad: PyExpr) -> PyResult { + Ok(utf8_lpad(expr.into(), length.into(), pad.into()).into()) +} diff --git a/src/daft-functions/src/utf8/lstrip.rs b/src/daft-functions/src/utf8/lstrip.rs new file mode 100644 index 0000000000..f7441a8ac2 --- /dev/null +++ b/src/daft-functions/src/utf8/lstrip.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Lstrip {} + +#[typetag::serde] +impl ScalarUDF for Utf8Lstrip { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "lstrip" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => 
Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to lstrip to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_lstrip(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_lstrip(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Lstrip {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_lstrip")] +pub fn py_utf8_lstrip(expr: PyExpr) -> PyResult { + Ok(utf8_lstrip(expr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/match_.rs b/src/daft-functions/src/utf8/match_.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/match_.rs rename to src/daft-functions/src/utf8/match_.rs index 7455aca17c..0a9cbc8a8c 100644 --- a/src/daft-dsl/src/functions/utf8/match_.rs +++ b/src/daft-functions/src/utf8/match_.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Match {} -pub(super) struct MatchEvaluator {} - -impl FunctionEvaluator for MatchEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Match { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "match" } - fn 
to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { (Ok(data_field), Ok(pattern_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for MatchEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, pattern] => data.utf8_match(pattern), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for MatchEvaluator { } } } + +#[must_use] +pub fn utf8_match(input: ExprRef, pattern: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Match {}, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_match")] +pub fn py_utf8_match(expr: PyExpr, pattern: PyExpr) -> PyResult { + Ok(utf8_match(expr.into(), pattern.into()).into()) +} diff --git a/src/daft-functions/src/utf8/mod.rs b/src/daft-functions/src/utf8/mod.rs new file mode 100644 index 0000000000..c6c4105b55 --- /dev/null +++ b/src/daft-functions/src/utf8/mod.rs @@ -0,0 +1,111 @@ +mod capitalize; +mod contains; +mod endswith; +mod extract; +mod extract_all; +mod find; +mod ilike; +mod left; +mod length; +mod length_bytes; +mod like; +mod lower; +mod lpad; +mod lstrip; +mod match_; +mod normalize; +mod repeat; +mod replace; +mod reverse; +mod right; +mod rpad; +mod rstrip; +mod split; +mod startswith; +mod substr; +mod to_date; +mod to_datetime; +mod upper; + +pub use capitalize::{utf8_capitalize as capitalize, Utf8Capitalize}; +pub use contains::{utf8_contains as contains, Utf8Contains}; +pub use endswith::{utf8_endswith as endswith, Utf8Endswith}; +pub use extract::{utf8_extract as extract, Utf8Extract}; +pub use extract_all::{utf8_extract_all as 
extract_all, Utf8ExtractAll}; +pub use find::{utf8_find as find, Utf8Find}; +pub use ilike::{utf8_ilike as ilike, Utf8Ilike}; +pub use left::{utf8_left as left, Utf8Left}; +pub use length::{utf8_length as length, Utf8Length}; +pub use length_bytes::{utf8_length_bytes as length_bytes, Utf8LengthBytes}; +pub use like::{utf8_like as like, Utf8Like}; +pub use lower::{utf8_lower as lower, Utf8Lower}; +pub use lpad::{utf8_lpad as lpad, Utf8Lpad}; +pub use lstrip::{utf8_lstrip as lstrip, Utf8Lstrip}; +pub use match_::{utf8_match as match_, Utf8Match}; +pub use normalize::{utf8_normalize as normalize, Utf8Normalize}; +#[cfg(feature = "python")] +use pyo3::prelude::*; +pub use repeat::{utf8_repeat as repeat, Utf8Repeat}; +pub use replace::{utf8_replace as replace, Utf8Replace}; +pub use reverse::{utf8_reverse as reverse, Utf8Reverse}; +pub use right::{utf8_right as right, Utf8Right}; +pub use rpad::{utf8_rpad as rpad, Utf8Rpad}; +pub use rstrip::{utf8_rstrip as rstrip, Utf8Rstrip}; +pub use split::{utf8_split as split, Utf8Split}; +pub use startswith::{utf8_startswith as startswith, Utf8Startswith}; +pub use substr::{utf8_substr as substr, Utf8Substr}; +pub use to_date::{utf8_to_date as to_date, Utf8ToDate}; +pub use to_datetime::{utf8_to_datetime as to_datetime, Utf8ToDatetime}; +pub use upper::{utf8_upper as upper, Utf8Upper}; + +#[cfg(feature = "python")] +pub fn register_modules(parent: &Bound) -> PyResult<()> { + parent.add_function(wrap_pyfunction_bound!( + capitalize::py_utf8_capitalize, + parent + )?)?; + parent.add_function(wrap_pyfunction_bound!(contains::py_utf8_contains, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(endswith::py_utf8_endswith, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(extract::py_utf8_extract, parent)?)?; + parent.add_function(wrap_pyfunction_bound!( + extract_all::py_utf8_extract_all, + parent + )?)?; + parent.add_function(wrap_pyfunction_bound!(find::py_utf8_find, parent)?)?; + 
parent.add_function(wrap_pyfunction_bound!(like::py_utf8_like, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(ilike::py_utf8_ilike, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(left::py_utf8_left, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(length::py_utf8_length, parent)?)?; + parent.add_function(wrap_pyfunction_bound!( + length_bytes::py_utf8_length_bytes, + parent + )?)?; + parent.add_function(wrap_pyfunction_bound!(lower::py_utf8_lower, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(lpad::py_utf8_lpad, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(lstrip::py_utf8_lstrip, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(match_::py_utf8_match, parent)?)?; + parent.add_function(wrap_pyfunction_bound!( + normalize::py_utf8_normalize, + parent + )?)?; + parent.add_function(wrap_pyfunction_bound!(repeat::py_utf8_repeat, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(replace::py_utf8_replace, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(reverse::py_utf8_reverse, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(right::py_utf8_right, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(rpad::py_utf8_rpad, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(rstrip::py_utf8_rstrip, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(split::py_utf8_split, parent)?)?; + parent.add_function(wrap_pyfunction_bound!( + startswith::py_utf8_startswith, + parent + )?)?; + parent.add_function(wrap_pyfunction_bound!(substr::py_utf8_substr, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(to_date::py_utf8_to_date, parent)?)?; + parent.add_function(wrap_pyfunction_bound!( + to_datetime::py_utf8_to_datetime, + parent + )?)?; + parent.add_function(wrap_pyfunction_bound!(upper::py_utf8_upper, parent)?)?; + + Ok(()) +} diff --git a/src/daft-functions/src/utf8/normalize.rs b/src/daft-functions/src/utf8/normalize.rs new file mode 100644 index 0000000000..b9455d23d3 --- 
/dev/null +++ b/src/daft-functions/src/utf8/normalize.rs @@ -0,0 +1,87 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Normalize { + pub opts: Utf8NormalizeOptions, +} + +#[typetag::serde] +impl ScalarUDF for Utf8Normalize { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "normalize" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to normalize to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_normalize(self.opts), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_normalize(input: ExprRef, opts: Utf8NormalizeOptions) -> ExprRef { + ScalarFunction::new(Utf8Normalize { opts }, vec![input]).into() +} + +use daft_core::array::ops::Utf8NormalizeOptions; +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_normalize")] +pub fn py_utf8_normalize( + expr: PyExpr, + remove_punct: bool, + lowercase: bool, + nfd_unicode: bool, + white_space: bool, +) -> PyResult { + Ok(utf8_normalize( + expr.into(), + Utf8NormalizeOptions { + remove_punct, + lowercase, + nfd_unicode, + 
white_space, + }, + ) + .into()) +} diff --git a/src/daft-dsl/src/functions/utf8/repeat.rs b/src/daft-functions/src/utf8/repeat.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/repeat.rs rename to src/daft-functions/src/utf8/repeat.rs index c321a6920a..dc74a4bfed 100644 --- a/src/daft-dsl/src/functions/utf8/repeat.rs +++ b/src/daft-functions/src/utf8/repeat.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Repeat {} -pub(super) struct RepeatEvaluator {} - -impl FunctionEvaluator for RepeatEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Repeat { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "repeat" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, ntimes] => match (data.to_field(schema), ntimes.to_field(schema)) { (Ok(data_field), Ok(ntimes_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for RepeatEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, ntimes] => data.utf8_repeat(ntimes), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for RepeatEvaluator { } } } + +#[must_use] +pub fn utf8_repeat(input: ExprRef, ntimes: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Repeat {}, vec![input, ntimes]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + 
pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_repeat")] +pub fn py_utf8_repeat(expr: PyExpr, ntimes: PyExpr) -> PyResult { + Ok(utf8_repeat(expr.into(), ntimes.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/replace.rs b/src/daft-functions/src/utf8/replace.rs similarity index 50% rename from src/daft-dsl/src/functions/utf8/replace.rs rename to src/daft-functions/src/utf8/replace.rs index 022f98ac17..76134c8136 100644 --- a/src/daft-dsl/src/functions/utf8/replace.rs +++ b/src/daft-functions/src/utf8/replace.rs @@ -1,17 +1,29 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ReplaceEvaluator {} +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Replace { + pub regex: bool, +} -impl FunctionEvaluator for ReplaceEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Replace { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "replace" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern, replacement] => match ( data.to_field(schema), @@ -37,15 +49,9 @@ impl FunctionEvaluator for ReplaceEvaluator { } } - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { - [data, pattern, replacement] => { - let regex = match expr { - FunctionExpr::Utf8(Utf8Expr::Replace(regex)) => regex, - _ => panic!("Expected Utf8 Replace Expr, got {expr}"), - 
}; - data.utf8_replace(pattern, replacement, *regex) - } + [data, pattern, replacement] => data.utf8_replace(pattern, replacement, self.regex), _ => Err(DaftError::ValueError(format!( "Expected 3 input args, got {}", inputs.len() @@ -53,3 +59,30 @@ impl FunctionEvaluator for ReplaceEvaluator { } } } + +#[must_use] +pub fn utf8_replace( + input: ExprRef, + pattern: ExprRef, + replacement: ExprRef, + regex: bool, +) -> ExprRef { + ScalarFunction::new(Utf8Replace { regex }, vec![input, pattern, replacement]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_replace")] +pub fn py_utf8_replace( + expr: PyExpr, + pattern: PyExpr, + replacement: PyExpr, + regex: bool, +) -> PyResult { + Ok(utf8_replace(expr.into(), pattern.into(), replacement.into(), regex).into()) +} diff --git a/src/daft-functions/src/utf8/reverse.rs b/src/daft-functions/src/utf8/reverse.rs new file mode 100644 index 0000000000..60674fc168 --- /dev/null +++ b/src/daft-functions/src/utf8/reverse.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Reverse {} + +#[typetag::serde] +impl ScalarUDF for Utf8Reverse { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "reverse" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to reverse to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), 
+ }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_reverse(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_reverse(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Reverse {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_reverse")] +pub fn py_utf8_reverse(expr: PyExpr) -> PyResult { + Ok(utf8_reverse(expr.into()).into()) +} diff --git a/src/daft-functions/src/utf8/right.rs b/src/daft-functions/src/utf8/right.rs new file mode 100644 index 0000000000..fbac7742b4 --- /dev/null +++ b/src/daft-functions/src/utf8/right.rs @@ -0,0 +1,72 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Right {} + +#[typetag::serde] +impl ScalarUDF for Utf8Right { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "right" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data, nchars] => match (data.to_field(schema), nchars.to_field(schema)) { + (Ok(data_field), Ok(nchars_field)) => { + match (&data_field.dtype, &nchars_field.dtype) { + (DataType::Utf8, dt) if dt.is_integer() => { + Ok(Field::new(data_field.name, DataType::Utf8)) + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to right to be utf8 and integer, but received {data_field} and {nchars_field}", + ))), + } + } + (Err(e), _) | (_, Err(e)) => Err(e), + }, + _ => 
Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data, nchars] => data.utf8_right(nchars), + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_right(input: ExprRef, nchars: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Right {}, vec![input, nchars]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_right")] +pub fn py_utf8_right(expr: PyExpr, nchars: PyExpr) -> PyResult { + Ok(utf8_right(expr.into(), nchars.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/rpad.rs b/src/daft-functions/src/utf8/rpad.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/rpad.rs rename to src/daft-functions/src/utf8/rpad.rs index f7c0769fac..2a0864a578 100644 --- a/src/daft-dsl/src/functions/utf8/rpad.rs +++ b/src/daft-functions/src/utf8/rpad.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Rpad {} -pub(super) struct RpadEvaluator {} - -impl FunctionEvaluator for RpadEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Rpad { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "rpad" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> 
DaftResult { match inputs { [data, length, pad] => { let data = data.to_field(schema)?; @@ -35,7 +45,7 @@ impl FunctionEvaluator for RpadEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, length, pad] => data.utf8_rpad(length, pad), _ => Err(DaftError::ValueError(format!( @@ -45,3 +55,20 @@ impl FunctionEvaluator for RpadEvaluator { } } } + +#[must_use] +pub fn utf8_rpad(input: ExprRef, length: ExprRef, pad: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Rpad {}, vec![input, length, pad]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_rpad")] +pub fn py_utf8_rpad(expr: PyExpr, length: PyExpr, pad: PyExpr) -> PyResult { + Ok(utf8_rpad(expr.into(), length.into(), pad.into()).into()) +} diff --git a/src/daft-functions/src/utf8/rstrip.rs b/src/daft-functions/src/utf8/rstrip.rs new file mode 100644 index 0000000000..b2528a99ac --- /dev/null +++ b/src/daft-functions/src/utf8/rstrip.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Rstrip {} + +#[typetag::serde] +impl ScalarUDF for Utf8Rstrip { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "rstrip" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to rstrip to be utf8, but received 
{data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_rstrip(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_rstrip(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Rstrip {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_rstrip")] +pub fn py_utf8_rstrip(expr: PyExpr) -> PyResult { + Ok(utf8_rstrip(expr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/split.rs b/src/daft-functions/src/utf8/split.rs similarity index 50% rename from src/daft-dsl/src/functions/utf8/split.rs rename to src/daft-functions/src/utf8/split.rs index 0518786055..60e8110ce1 100644 --- a/src/daft-dsl/src/functions/utf8/split.rs +++ b/src/daft-functions/src/utf8/split.rs @@ -1,17 +1,29 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct SplitEvaluator {} +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Split { + pub regex: bool, +} -impl FunctionEvaluator for SplitEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Split { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "split" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: 
&[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { (Ok(data_field), Ok(pattern_field)) => { @@ -33,15 +45,9 @@ impl FunctionEvaluator for SplitEvaluator { } } - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { - [data, pattern] => { - let regex = match expr { - FunctionExpr::Utf8(Utf8Expr::Split(regex)) => regex, - _ => panic!("Expected Utf8 Split Expr, got {expr}"), - }; - data.utf8_split(pattern, *regex) - } + [data, pattern] => data.utf8_split(pattern, self.regex), _ => Err(DaftError::ValueError(format!( "Expected 2 input args, got {}", inputs.len() @@ -49,3 +55,20 @@ impl FunctionEvaluator for SplitEvaluator { } } } + +#[must_use] +pub fn utf8_split(input: ExprRef, pattern: ExprRef, regex: bool) -> ExprRef { + ScalarFunction::new(Utf8Split { regex }, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_split")] +pub fn py_utf8_split(expr: PyExpr, pattern: PyExpr, regex: bool) -> PyResult { + Ok(utf8_split(expr.into(), pattern.into(), regex).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/startswith.rs b/src/daft-functions/src/utf8/startswith.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/startswith.rs rename to src/daft-functions/src/utf8/startswith.rs index 01ae5eda7e..3a0bb50d2b 100644 --- a/src/daft-dsl/src/functions/utf8/startswith.rs +++ b/src/daft-functions/src/utf8/startswith.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use 
crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Startswith {} -pub(super) struct StartswithEvaluator {} - -impl FunctionEvaluator for StartswithEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Startswith { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "startswith" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { (Ok(data_field), Ok(pattern_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for StartswithEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, pattern] => data.utf8_startswith(pattern), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for StartswithEvaluator { } } } + +#[must_use] +pub fn utf8_startswith(input: ExprRef, pattern: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Startswith {}, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_startswith")] +pub fn py_utf8_startswith(expr: PyExpr, pattern: PyExpr) -> PyResult { + Ok(utf8_startswith(expr.into(), pattern.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/substr.rs b/src/daft-functions/src/utf8/substr.rs similarity index 54% rename from src/daft-dsl/src/functions/utf8/substr.rs rename to src/daft-functions/src/utf8/substr.rs index d2ec60256a..90628e441c 100644 --- a/src/daft-dsl/src/functions/utf8/substr.rs +++ b/src/daft-functions/src/utf8/substr.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use 
daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Substr {} -pub(super) struct SubstrEvaluator {} - -impl FunctionEvaluator for SubstrEvaluator { - fn fn_name(&self) -> &'static str { +#[typetag::serde] +impl ScalarUDF for Utf8Substr { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { "substr" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, start, length] => { let data = data.to_field(schema)?; @@ -37,7 +47,7 @@ impl FunctionEvaluator for SubstrEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, start, length] => data.utf8_substr(start, length), _ => Err(DaftError::ValueError(format!( @@ -47,3 +57,20 @@ impl FunctionEvaluator for SubstrEvaluator { } } } + +#[must_use] +pub fn utf8_substr(input: ExprRef, start: ExprRef, length: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Substr {}, vec![input, start, length]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_substr")] +pub fn py_utf8_substr(expr: PyExpr, start: PyExpr, length: PyExpr) -> PyResult { + Ok(utf8_substr(expr.into(), start.into(), length.into()).into()) +} diff --git a/src/daft-functions/src/utf8/to_date.rs b/src/daft-functions/src/utf8/to_date.rs new file mode 100644 index 0000000000..911cca84c9 --- /dev/null +++ b/src/daft-functions/src/utf8/to_date.rs @@ 
-0,0 +1,77 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8ToDate { + pub format: String, +} + +#[typetag::serde] +impl ScalarUDF for Utf8ToDate { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "to_date" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Date)), + _ => Err(DaftError::TypeError(format!( + "Expects inputs to to_date to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_to_date(&self.format), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_to_date>(input: ExprRef, format: S) -> ExprRef { + ScalarFunction::new( + Utf8ToDate { + format: format.into(), + }, + vec![input], + ) + .into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_to_date")] +pub fn py_utf8_to_date(expr: PyExpr, format: &str) -> PyResult { + Ok(utf8_to_date::<&str>(expr.into(), format).into()) +} diff --git a/src/daft-functions/src/utf8/to_datetime.rs b/src/daft-functions/src/utf8/to_datetime.rs new file mode 100644 index 0000000000..862859699f --- /dev/null +++ b/src/daft-functions/src/utf8/to_datetime.rs @@ -0,0 +1,90 @@ +use 
common_error::{DaftError, DaftResult}; +use daft_core::{ + datatypes::infer_timeunit_from_format_string, + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8ToDatetime { + pub format: String, + pub timezone: Option<String>, +} + +#[typetag::serde] +impl ScalarUDF for Utf8ToDatetime { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "to_datetime" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => { + let timeunit = infer_timeunit_from_format_string(&self.format); + Ok(Field::new( + data_field.name, + DataType::Timestamp(timeunit, self.timezone.clone()), + )) + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to to_datetime to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> { + match inputs { + [data] => data.utf8_to_datetime(&self.format, self.timezone.as_deref()), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_to_datetime<S: Into<String>>( + input: ExprRef, + format: S, + timezone: Option<S>, +) -> ExprRef { + ScalarFunction::new( + Utf8ToDatetime { + format: format.into(), + timezone: timezone.map(|s| s.into()), + }, + vec![input], + ) + .into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_to_datetime")] +pub fn py_utf8_to_datetime(expr: PyExpr, format: &str, timezone:
Option<&str>) -> PyResult { + Ok(utf8_to_datetime::<&str>(expr.into(), format, timezone).into()) +} diff --git a/src/daft-functions/src/utf8/upper.rs b/src/daft-functions/src/utf8/upper.rs new file mode 100644 index 0000000000..b6500e825f --- /dev/null +++ b/src/daft-functions/src/utf8/upper.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Upper {} + +#[typetag::serde] +impl ScalarUDF for Utf8Upper { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "upper" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to upper to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_upper(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_upper(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Upper {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_upper")] +pub fn py_utf8_upper(expr: PyExpr) -> PyResult { + Ok(utf8_upper(expr.into()).into()) +} diff --git a/src/daft-logical-plan/src/display.rs b/src/daft-logical-plan/src/display.rs 
index 26c7470fa7..6bcdba360b 100644 --- a/src/daft-logical-plan/src/display.rs +++ b/src/daft-logical-plan/src/display.rs @@ -34,11 +34,8 @@ mod test { use common_display::mermaid::{MermaidDisplay, MermaidDisplayOptions, SubgraphOptions}; use common_error::DaftResult; use daft_core::prelude::*; - use daft_dsl::{ - col, - functions::utf8::{endswith, startswith}, - lit, - }; + use daft_dsl::{col, lit}; + use daft_functions::utf8::{endswith, startswith}; use pretty_assertions::assert_eq; use crate::{ diff --git a/src/daft-sql/src/modules/utf8.rs b/src/daft-sql/src/modules/utf8.rs index 084da08962..edf3b3133e 100644 --- a/src/daft-sql/src/modules/utf8.rs +++ b/src/daft-sql/src/modules/utf8.rs @@ -1,12 +1,5 @@ use daft_core::array::ops::Utf8NormalizeOptions; -use daft_dsl::{ - binary_op, - functions::{ - self, - utf8::{normalize, Utf8Expr}, - }, - ExprRef, LiteralValue, Operator, -}; +use daft_dsl::{binary_op, ExprRef, LiteralValue, Operator}; use daft_functions::{ count_matches::{utf8_count_matches, CountMatchesFunction}, tokenize::{tokenize_decode, tokenize_encode, TokenizeDecodeFunction, TokenizeEncodeFunction}, @@ -14,51 +7,176 @@ use daft_functions::{ use super::SQLModule; use crate::{ - ensure, error::{PlannerError, SQLPlannerResult}, functions::{SQLFunction, SQLFunctionArguments}, - invalid_operation_err, unsupported_sql_err, + invalid_operation_err, }; +fn utf8_unary( + func: impl Fn(ExprRef) -> ExprRef, + sql_name: &str, + arg_name: &str, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, +) -> SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(func(input)) + } + _ => invalid_operation_err!( + "invalid arguments for {sql_name}. 
Expected {sql_name}({arg_name})", + ), + } +} + +fn utf8_binary( + func: impl Fn(ExprRef, ExprRef) -> ExprRef, + sql_name: &str, + arg_name_1: &str, + arg_name_2: &str, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, +) -> SQLPlannerResult { + match inputs { + [input1, input2] => { + let input1 = planner.plan_function_arg(input1)?; + let input2 = planner.plan_function_arg(input2)?; + Ok(func(input1, input2)) + } + _ => invalid_operation_err!( + "invalid arguments for {sql_name}. Expected {sql_name}({arg_name_1}, {arg_name_2})", + ), + } +} + +fn utf8_ternary( + func: impl Fn(ExprRef, ExprRef, ExprRef) -> ExprRef, + sql_name: &str, + arg_name_1: &str, + arg_name_2: &str, + arg_name_3: &str, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, +) -> SQLPlannerResult { + match inputs { + [input1, input2, input3] => { + let input1 = planner.plan_function_arg(input1)?; + let input2 = planner.plan_function_arg(input2)?; + let input3 = planner.plan_function_arg(input3)?; + Ok(func(input1, input2, input3)) + }, + _ => invalid_operation_err!( + "invalid arguments for {sql_name}. Expected {sql_name}({arg_name_1}, {arg_name_2}, {arg_name_3})", + ), + } +} + +macro_rules! 
utf8_function { + ($name:ident, $sql_name:expr, $func:expr, $doc:expr, $arg_name:expr) => { + pub struct $name; + impl SQLFunction for $name { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + utf8_unary($func, $sql_name, $arg_name, inputs, planner) + } + + fn docstrings(&self, _alias: &str) -> String { + $doc.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &[$arg_name] + } + } + }; + ($name:ident, $sql_name:expr, $func:expr, $doc:expr, $arg_name_1:expr, $arg_name_2:expr) => { + pub struct $name; + impl SQLFunction for $name { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + utf8_binary($func, $sql_name, $arg_name_1, $arg_name_2, inputs, planner) + } + + fn docstrings(&self, _alias: &str) -> String { + $doc.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &[$arg_name_1, $arg_name_2] + } + } + }; + ($name:ident, $sql_name:expr, $func:expr, $doc:expr, $arg_name_1:expr, $arg_name_2:expr, $arg_name_3:expr) => { + pub struct $name; + impl SQLFunction for $name { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + utf8_ternary( + $func, + $sql_name, + $arg_name_1, + $arg_name_2, + $arg_name_3, + inputs, + planner, + ) + } + + fn docstrings(&self, _alias: &str) -> String { + $doc.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &[$arg_name_1, $arg_name_2, $arg_name_3] + } + } + }; +} + pub struct SQLModuleUtf8; impl SQLModule for SQLModuleUtf8 { fn register(parent: &mut crate::functions::SQLFunctions) { - use Utf8Expr::{ - Capitalize, Contains, EndsWith, Extract, ExtractAll, Find, Left, Length, LengthBytes, - Lower, Lpad, Lstrip, Match, Repeat, Replace, Reverse, Right, Rpad, Rstrip, Split, - StartsWith, ToDate, ToDatetime, Upper, - }; - parent.add_fn("ends_with", 
EndsWith); - parent.add_fn("starts_with", StartsWith); - parent.add_fn("contains", Contains); - parent.add_fn("split", Split(false)); + parent.add_fn("ends_with", SQLUtf8EndsWith); + parent.add_fn("starts_with", SQLUtf8StartsWith); + parent.add_fn("contains", SQLUtf8Contains); + parent.add_fn("split", SQLUtf8Split); // TODO add split variants // parent.add("split", f(Split(false))); - parent.add_fn("regexp_match", Match); - parent.add_fn("regexp_extract", Extract(0)); - parent.add_fn("regexp_extract_all", ExtractAll(0)); - parent.add_fn("regexp_replace", Replace(true)); - parent.add_fn("regexp_split", Split(true)); + parent.add_fn("regexp_match", SQLUtf8RegexpMatch); + parent.add_fn("regexp_extract", SQLUtf8RegexpExtract); + parent.add_fn("regexp_extract_all", SQLUtf8RegexpExtractAll); + parent.add_fn("regexp_replace", SQLUtf8RegexpReplace); + parent.add_fn("regexp_split", SQLUtf8RegexpSplit); // TODO add replace variants // parent.add("replace", f(Replace(false))); - parent.add_fn("length", Length); - parent.add_fn("length_bytes", LengthBytes); - parent.add_fn("lower", Lower); - parent.add_fn("upper", Upper); - parent.add_fn("lstrip", Lstrip); - parent.add_fn("rstrip", Rstrip); - parent.add_fn("reverse", Reverse); - parent.add_fn("capitalize", Capitalize); - parent.add_fn("left", Left); - parent.add_fn("right", Right); - parent.add_fn("find", Find); - parent.add_fn("rpad", Rpad); - parent.add_fn("lpad", Lpad); - parent.add_fn("repeat", Repeat); - - parent.add_fn("to_date", ToDate(String::new())); - parent.add_fn("to_datetime", ToDatetime(String::new(), None)); + parent.add_fn("length", SQLUtf8Length); + parent.add_fn("length_bytes", SQLUtf8LengthBytes); + parent.add_fn("lower", SQLUtf8Lower); + parent.add_fn("upper", SQLUtf8Upper); + parent.add_fn("lstrip", SQLUtf8Lstrip); + parent.add_fn("rstrip", SQLUtf8Rstrip); + parent.add_fn("reverse", SQLUtf8Reverse); + parent.add_fn("capitalize", SQLUtf8Capitalize); + parent.add_fn("left", SQLUtf8Left); + 
parent.add_fn("right", SQLUtf8Right); + parent.add_fn("find", SQLUtf8Find); + parent.add_fn("rpad", SQLUtf8Rpad); + parent.add_fn("lpad", SQLUtf8Lpad); + parent.add_fn("repeat", SQLUtf8Repeat); + + parent.add_fn("to_date", SQLUtf8ToDate); + parent.add_fn("to_datetime", SQLUtf8ToDatetime); parent.add_fn("count_matches", SQLCountMatches); parent.add_fn("normalize", SQLNormalize); parent.add_fn("tokenize_encode", SQLTokenizeEncode); @@ -67,255 +185,334 @@ impl SQLModule for SQLModuleUtf8 { } } -impl SQLModuleUtf8 {} - -impl SQLFunction for Utf8Expr { +utf8_function!( + SQLUtf8EndsWith, + "ends_with", + daft_functions::utf8::endswith, + "Returns true if the string ends with the specified substring", + "string_input", + "substring" +); + +utf8_function!( + SQLUtf8StartsWith, + "starts_with", + daft_functions::utf8::startswith, + "Returns true if the string starts with the specified substring", + "string_input", + "substring" +); + +utf8_function!( + SQLUtf8Contains, + "contains", + daft_functions::utf8::contains, + "Returns true if the string contains the specified substring", + "string_input", + "substring" +); + +utf8_function!( + SQLUtf8Split, + "split", + |input, pattern| daft_functions::utf8::split(input, pattern, false), + "Splits the string by the specified delimiter and returns an array of substrings", + "string_input", + "delimiter" +); + +utf8_function!( + SQLUtf8RegexpMatch, + "regexp_match", + daft_functions::utf8::match_, + "Returns true if the string matches the specified regular expression pattern", + "string_input", + "pattern" +); + +utf8_function!( + SQLUtf8RegexpReplace, + "regexp_replace", + |input, pattern, replacement| daft_functions::utf8::replace(input, pattern, replacement, true), + "Replaces all matches of the specified regular expression pattern with a new string", + "string_input", + "pattern", + "replacement" +); + +utf8_function!( + SQLUtf8RegexpSplit, + "regexp_split", + |input, pattern| daft_functions::utf8::split(input, pattern, true), + "Splits the string by
the specified delimiter and returns an array of substrings", + "string_input", + "delimiter" +); + +utf8_function!( + SQLUtf8Length, + "length", + daft_functions::utf8::length, + "Returns the length of the string", + "string_input" +); + +utf8_function!( + SQLUtf8LengthBytes, + "length_bytes", + daft_functions::utf8::length_bytes, + "Returns the length of the string in bytes", + "string_input" +); + +utf8_function!( + SQLUtf8Lower, + "lower", + daft_functions::utf8::lower, + "Converts the string to lowercase", + "string_input" +); + +utf8_function!( + SQLUtf8Upper, + "upper", + daft_functions::utf8::upper, + "Converts the string to uppercase", + "string_input" +); + +utf8_function!( + SQLUtf8Lstrip, + "lstrip", + daft_functions::utf8::lstrip, + "Removes leading whitespace from the string", + "string_input" +); + +utf8_function!( + SQLUtf8Rstrip, + "rstrip", + daft_functions::utf8::rstrip, + "Removes trailing whitespace from the string", + "string_input" +); + +utf8_function!( + SQLUtf8Reverse, + "reverse", + daft_functions::utf8::reverse, + "Reverses the order of characters in the string", + "string_input" +); + +utf8_function!( + SQLUtf8Capitalize, + "capitalize", + daft_functions::utf8::capitalize, + "Capitalizes the first character of the string", + "string_input" +); + +utf8_function!( + SQLUtf8Left, + "left", + daft_functions::utf8::left, + "Returns the specified number of leftmost characters from the string", + "string_input", + "length" +); + +utf8_function!( + SQLUtf8Right, + "right", + daft_functions::utf8::right, + "Returns the specified number of rightmost characters from the string", + "string_input", + "length" +); + +utf8_function!( + SQLUtf8Find, + "find", + daft_functions::utf8::find, + "Returns the index of the first occurrence of a substring within the string", + "string_input", + "substring" +); + +utf8_function!( + SQLUtf8Rpad, + "rpad", + daft_functions::utf8::rpad, + "Pads the string on the right side with the specified string until it reaches 
the specified length", + "string_input", "length", "pad" +); + +utf8_function!( + SQLUtf8Lpad, + "lpad", + daft_functions::utf8::lpad, + "Pads the string on the left side with the specified string until it reaches the specified length", + "string_input", "length", "pad" +); + +utf8_function!( + SQLUtf8Repeat, + "repeat", + daft_functions::utf8::repeat, + "Repeats the string the specified number of times", + "string_input", + "count" +); + +pub struct SQLUtf8RegexpExtract; + +impl SQLFunction for SQLUtf8RegexpExtract { fn to_expr( &self, inputs: &[sqlparser::ast::FunctionArg], planner: &crate::planner::SQLPlanner, ) -> SQLPlannerResult { - let inputs = self.args_to_expr_unnamed(inputs, planner)?; - to_expr(self, &inputs) + match inputs { + [input, pattern] => { + let input = planner.plan_function_arg(input)?; + let pattern = planner.plan_function_arg(pattern)?; + Ok(daft_functions::utf8::extract(input, pattern, 0)) + } + [input, pattern, idx] => { + let input = planner.plan_function_arg(input)?; + let pattern = planner.plan_function_arg(pattern)?; + let idx = planner.plan_function_arg(idx)?.as_literal().and_then(LiteralValue::as_i64).ok_or_else(|| { + PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract, found {idx:?}")) + })? 
as usize; + Ok(daft_functions::utf8::extract(input, pattern, idx)) + } + _ => invalid_operation_err!("regexp_extract takes exactly two or three arguments"), + } } fn docstrings(&self, _alias: &str) -> String { - match self { - Self::EndsWith => "Returns true if the string ends with the specified substring".to_string(), - Self::StartsWith => "Returns true if the string starts with the specified substring".to_string(), - Self::Contains => "Returns true if the string contains the specified substring".to_string(), - Self::Split(_) => "Splits the string by the specified delimiter and returns an array of substrings".to_string(), - Self::Match => "Returns true if the string matches the specified regular expression pattern".to_string(), - Self::Extract(_) => "Extracts the first substring that matches the specified regular expression pattern".to_string(), - Self::ExtractAll(_) => "Extracts all substrings that match the specified regular expression pattern".to_string(), - Self::Replace(_) => "Replaces all occurrences of a substring with a new string".to_string(), - Self::Like => "Returns true if the string matches the specified SQL LIKE pattern".to_string(), - Self::Ilike => "Returns true if the string matches the specified SQL LIKE pattern (case-insensitive)".to_string(), - Self::Length => "Returns the length of the string".to_string(), - Self::Lower => "Converts the string to lowercase".to_string(), - Self::Upper => "Converts the string to uppercase".to_string(), - Self::Lstrip => "Removes leading whitespace from the string".to_string(), - Self::Rstrip => "Removes trailing whitespace from the string".to_string(), - Self::Reverse => "Reverses the order of characters in the string".to_string(), - Self::Capitalize => "Capitalizes the first character of the string".to_string(), - Self::Left => "Returns the specified number of leftmost characters from the string".to_string(), - Self::Right => "Returns the specified number of rightmost characters from the string".to_string(), - 
Self::Find => "Returns the index of the first occurrence of a substring within the string".to_string(), - Self::Rpad => "Pads the string on the right side with the specified string until it reaches the specified length".to_string(), - Self::Lpad => "Pads the string on the left side with the specified string until it reaches the specified length".to_string(), - Self::Repeat => "Repeats the string the specified number of times".to_string(), - Self::Substr => "Returns a substring of the string starting at the specified position and length".to_string(), - Self::ToDate(_) => "Parses the string as a date using the specified format.".to_string(), - Self::ToDatetime(_, _) => "Parses the string as a datetime using the specified format.".to_string(), - Self::LengthBytes => "Returns the length of the string in bytes".to_string(), - Self::Normalize(_) => "Normalizes a string for more useful deduplication and data cleaning".to_string(), - } + "Extracts the first substring that matches the specified regular expression pattern" + .to_string() } fn arg_names(&self) -> &'static [&'static str] { - match self { - Self::EndsWith => &["string_input", "substring"], - Self::StartsWith => &["string_input", "substring"], - Self::Contains => &["string_input", "substring"], - Self::Split(_) => &["string_input", "delimiter"], - Self::Match => &["string_input", "pattern"], - Self::Extract(_) => &["string_input", "pattern"], - Self::ExtractAll(_) => &["string_input", "pattern"], - Self::Replace(_) => &["string_input", "pattern", "replacement"], - Self::Like => &["string_input", "pattern"], - Self::Ilike => &["string_input", "pattern"], - Self::Length => &["string_input"], - Self::Lower => &["string_input"], - Self::Upper => &["string_input"], - Self::Lstrip => &["string_input"], - Self::Rstrip => &["string_input"], - Self::Reverse => &["string_input"], - Self::Capitalize => &["string_input"], - Self::Left => &["string_input", "length"], - Self::Right => &["string_input", "length"], - Self::Find 
=> &["string_input", "substring"], - Self::Rpad => &["string_input", "length", "pad"], - Self::Lpad => &["string_input", "length", "pad"], - Self::Repeat => &["string_input", "count"], - Self::Substr => &["string_input", "start", "length"], - Self::ToDate(_) => &["string_input", "format"], - Self::ToDatetime(_, _) => &["string_input", "format"], - Self::LengthBytes => &["string_input"], - Self::Normalize(_) => &[ - "input", - "remove_punct", - "lowercase", - "nfd_unicode", - "white_space", - ], - } + &["string_input", "pattern"] } } -fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult { - use functions::utf8::{ - capitalize, contains, endswith, extract, extract_all, find, left, length, length_bytes, - lower, lpad, lstrip, match_, repeat, replace, reverse, right, rpad, rstrip, split, - startswith, to_date, to_datetime, upper, Utf8Expr, - }; - use Utf8Expr::{ - Capitalize, Contains, EndsWith, Extract, ExtractAll, Find, Ilike, Left, Length, - LengthBytes, Like, Lower, Lpad, Lstrip, Match, Normalize, Repeat, Replace, Reverse, Right, - Rpad, Rstrip, Split, StartsWith, Substr, ToDate, ToDatetime, Upper, - }; - match expr { - EndsWith => { - ensure!(args.len() == 2, "endswith takes exactly two arguments"); - Ok(endswith(args[0].clone(), args[1].clone())) - } - StartsWith => { - ensure!(args.len() == 2, "startswith takes exactly two arguments"); - Ok(startswith(args[0].clone(), args[1].clone())) - } - Contains => { - ensure!(args.len() == 2, "contains takes exactly two arguments"); - Ok(contains(args[0].clone(), args[1].clone())) - } - Split(true) => { - ensure!(args.len() == 2, "split takes exactly two arguments"); - Ok(split(args[0].clone(), args[1].clone(), true)) - } - Split(false) => { - ensure!(args.len() == 2, "split takes exactly two arguments"); - Ok(split(args[0].clone(), args[1].clone(), false)) - } - Match => { - ensure!(args.len() == 2, "regexp_match takes exactly two arguments"); - Ok(match_(args[0].clone(), args[1].clone())) - } - Extract(_) => 
match args { - [input, pattern] => Ok(extract(input.clone(), pattern.clone(), 0)), - [input, pattern, idx] => { - let idx = idx.as_literal().and_then(daft_dsl::LiteralValue::as_i64).ok_or_else(|| { - PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract, found {idx:?}")) - })?; +pub struct SQLUtf8RegexpExtractAll; - Ok(extract(input.clone(), pattern.clone(), idx as usize)) - } - _ => { - invalid_operation_err!("regexp_extract takes exactly two or three arguments") +impl SQLFunction for SQLUtf8RegexpExtractAll { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input, pattern] => { + let input = planner.plan_function_arg(input)?; + let pattern = planner.plan_function_arg(pattern)?; + Ok(daft_functions::utf8::extract_all(input, pattern, 0)) } - }, - ExtractAll(_) => match args { - [input, pattern] => Ok(extract_all(input.clone(), pattern.clone(), 0)), [input, pattern, idx] => { - let idx = idx.as_literal().and_then(daft_dsl::LiteralValue::as_i64).ok_or_else(|| { - PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract, found {idx:?}")) - })?; - - Ok(extract_all(input.clone(), pattern.clone(), idx as usize)) - } - _ => { - invalid_operation_err!("regexp_extract_all takes exactly two or three arguments") + let input = planner.plan_function_arg(input)?; + let pattern = planner.plan_function_arg(pattern)?; + let idx = planner.plan_function_arg(idx)?.as_literal().and_then(LiteralValue::as_i64).ok_or_else(|| { + PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract_all, found {idx:?}")) + })? 
as usize; + Ok(daft_functions::utf8::extract_all(input, pattern, idx)) } - }, - Replace(_) => { - ensure!(args.len() == 3, "replace takes exactly three arguments"); - Ok(replace( - args[0].clone(), - args[1].clone(), - args[2].clone(), - false, - )) - } - Like => { - unreachable!("like should be handled by the parser") + _ => invalid_operation_err!("regexp_extract_all takes exactly two or three arguments"), } - Ilike => { - unreachable!("ilike should be handled by the parser") - } - Length => { - ensure!(args.len() == 1, "length takes exactly one argument"); - Ok(length(args[0].clone())) - } - LengthBytes => { - ensure!(args.len() == 1, "length_bytes takes exactly one argument"); - Ok(length_bytes(args[0].clone())) - } - Lower => { - ensure!(args.len() == 1, "lower takes exactly one argument"); - Ok(lower(args[0].clone())) - } - Upper => { - ensure!(args.len() == 1, "upper takes exactly one argument"); - Ok(upper(args[0].clone())) - } - Lstrip => { - ensure!(args.len() == 1, "lstrip takes exactly one argument"); - Ok(lstrip(args[0].clone())) - } - Rstrip => { - ensure!(args.len() == 1, "rstrip takes exactly one argument"); - Ok(rstrip(args[0].clone())) - } - Reverse => { - ensure!(args.len() == 1, "reverse takes exactly one argument"); - Ok(reverse(args[0].clone())) - } - Capitalize => { - ensure!(args.len() == 1, "capitalize takes exactly one argument"); - Ok(capitalize(args[0].clone())) - } - Left => { - ensure!(args.len() == 2, "left takes exactly two arguments"); - Ok(left(args[0].clone(), args[1].clone())) - } - Right => { - ensure!(args.len() == 2, "right takes exactly two arguments"); - Ok(right(args[0].clone(), args[1].clone())) - } - Find => { - ensure!(args.len() == 2, "find takes exactly two arguments"); - Ok(find(args[0].clone(), args[1].clone())) - } - Rpad => { - ensure!(args.len() == 3, "rpad takes exactly three arguments"); - Ok(rpad(args[0].clone(), args[1].clone(), args[2].clone())) - } - Lpad => { - ensure!(args.len() == 3, "lpad takes exactly 
three arguments"); - Ok(lpad(args[0].clone(), args[1].clone(), args[2].clone())) - } - Repeat => { - ensure!(args.len() == 2, "repeat takes exactly two arguments"); - Ok(repeat(args[0].clone(), args[1].clone())) - } - Substr => { - unreachable!("substr should be handled by the parser") - } - ToDate(_) => { - ensure!(args.len() == 2, "to_date takes exactly two arguments"); - let fmt = match args[1].as_ref().as_literal() { - Some(LiteralValue::Utf8(s)) => s, - _ => invalid_operation_err!("to_date format must be a string"), - }; - Ok(to_date(args[0].clone(), fmt)) - } - ToDatetime(..) => { - ensure!( - args.len() >= 2, - "to_datetime takes either two or three arguments" - ); - let fmt = match args[1].as_ref().as_literal() { - Some(LiteralValue::Utf8(s)) => s, - _ => invalid_operation_err!("to_datetime format must be a string"), - }; - let tz = match args.get(2).and_then(|e| e.as_ref().as_literal()) { - Some(LiteralValue::Utf8(s)) => Some(s.as_str()), - _ => invalid_operation_err!("to_datetime timezone must be a string"), - }; - - Ok(to_datetime(args[0].clone(), fmt, tz)) + } + + fn docstrings(&self, _alias: &str) -> String { + "Extracts all substrings that match the specified regular expression pattern".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["string_input", "pattern"] + } +} + +pub struct SQLUtf8ToDate; + +impl SQLFunction for SQLUtf8ToDate { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input, fmt] => { + let input = planner.plan_function_arg(input)?; + let fmt = planner.plan_function_arg(fmt)?; + let fmt = fmt + .as_literal() + .and_then(|lit| lit.as_str()) + .ok_or_else(|| { + PlannerError::invalid_operation("to_date format must be a string") + })?; + Ok(daft_functions::utf8::to_date(input, fmt)) + } + _ => invalid_operation_err!("to_date takes exactly two arguments"), } - Normalize(_) => { - unsupported_sql_err!("normalize") + 
} + + fn docstrings(&self, _alias: &str) -> String { + "Parses the string as a date using the specified format.".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["string_input", "format"] + } +} + +pub struct SQLUtf8ToDatetime; + +impl SQLFunction for SQLUtf8ToDatetime { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input, fmt] => { + let input = planner.plan_function_arg(input)?; + let fmt = planner.plan_function_arg(fmt)?; + let fmt = fmt + .as_literal() + .and_then(|lit| lit.as_str()) + .ok_or_else(|| { + PlannerError::invalid_operation("to_datetime format must be a string") + })?; + Ok(daft_functions::utf8::to_datetime(input, fmt, None)) + } + [input, fmt, tz] => { + let input = planner.plan_function_arg(input)?; + let fmt = planner.plan_function_arg(fmt)?; + let fmt = fmt + .as_literal() + .and_then(|lit| lit.as_str()) + .ok_or_else(|| { + PlannerError::invalid_operation("to_datetime format must be a string") + })?; + let tz = planner.plan_function_arg(tz)?; + let tz = tz.as_literal().and_then(|lit| lit.as_str()); + Ok(daft_functions::utf8::to_datetime(input, fmt, tz)) + } + _ => invalid_operation_err!("to_datetime takes either two or three arguments"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Parses the string as a datetime using the specified format.".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["string_input", "format"] + } } pub struct SQLCountMatches; @@ -404,7 +601,10 @@ impl SQLFunction for SQLNormalize { match inputs { [input] => { let input = planner.plan_function_arg(input)?; - Ok(normalize(input, Utf8NormalizeOptions::default())) + Ok(daft_functions::utf8::normalize( + input, + Utf8NormalizeOptions::default(), + )) } [input, args @ ..] 
=> { let input = planner.plan_function_arg(input)?; @@ -413,7 +613,7 @@ impl SQLFunction for SQLNormalize { &["remove_punct", "lowercase", "nfd_unicode", "white_space"], 0, )?; - Ok(normalize(input, args)) + Ok(daft_functions::utf8::normalize(input, args)) } _ => invalid_operation_err!("Invalid arguments for normalize"), } diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index d241fc7fbb..6f73cedd04 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -7,12 +7,13 @@ use std::{ use common_error::{DaftError, DaftResult}; use daft_core::prelude::*; use daft_dsl::{ - col, - functions::utf8::{ilike, like, to_date, to_datetime}, - has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator, - Subquery, + col, has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, + Operator, Subquery, +}; +use daft_functions::{ + numeric::{ceil::ceil, floor::floor}, + utf8::{ilike, like, to_date, to_datetime}, }; -use daft_functions::numeric::{ceil::ceil, floor::floor}; use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef}; use sqlparser::{ ast::{ @@ -1261,7 +1262,7 @@ impl SQLPlanner { let start = self.plan_expr(substring_from)?; let length = self.plan_expr(substring_for)?; - Ok(daft_dsl::functions::utf8::substr(expr, start, length)) + Ok(daft_functions::utf8::substr(expr, start, length)) } SQLExpr::Substring { special: false, .. } => { unsupported_sql_err!("`SUBSTRING(expr [FROM start] [FOR len])` syntax")