Skip to content

Commit

Permalink
feat(python): Expose string expression nodes to python (#16221)
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-b-miller authored May 16, 2024
1 parent a743934 commit 98a2d9b
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 4 deletions.
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ pub(super) use self::rolling::RollingFunction;
#[cfg(feature = "rolling_window_by")]
pub(super) use self::rolling_by::RollingFunctionBy;
#[cfg(feature = "strings")]
pub(crate) use self::strings::StringFunction;
pub use self::strings::StringFunction;
#[cfg(feature = "dtype-struct")]
pub(crate) use self::struct_::StructFunction;
#[cfg(feature = "trigonometry")]
Expand Down
197 changes: 194 additions & 3 deletions py-polars/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use polars_core::series::IsSorted;
use polars_plan::dsl::function_expr::rolling::RollingFunction;
use polars_plan::dsl::function_expr::rolling_by::RollingFunctionBy;
use polars_plan::dsl::function_expr::trigonometry::TrigonometricFunction;
use polars_plan::dsl::BooleanFunction;
use polars_plan::dsl::{BooleanFunction, StringFunction};
use polars_plan::prelude::{
AAggExpr, AExpr, FunctionExpr, GroupbyOptions, LiteralValue, Operator, PowFunction,
WindowMapping, WindowType,
Expand Down Expand Up @@ -59,6 +59,53 @@ pub enum PyOperator {
LogicalOr,
}

#[pyclass(name = "StringFunction")]
pub enum PyStringFunction {
ConcatHorizontal,
ConcatVertical,
Contains,
CountMatches,
EndsWith,
Explode,
Extract,
ExtractAll,
ExtractGroups,
Find,
ToInteger,
LenBytes,
LenChars,
Lowercase,
JsonDecode,
JsonPathMatch,
Replace,
Reverse,
PadStart,
PadEnd,
Slice,
Head,
Tail,
HexEncode,
HexDecode,
Base64Encode,
Base64Decode,
StartsWith,
StripChars,
StripCharsStart,
StripCharsEnd,
StripPrefix,
StripSuffix,
SplitExact,
SplitN,
Strptime,
Split,
ToDecimal,
Titlecase,
Uppercase,
ZFill,
ContainsMany,
ReplaceMany,
}

#[pymethods]
impl PyOperator {
fn __hash__(&self) -> u64 {
Expand Down Expand Up @@ -571,8 +618,152 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
FunctionExpr::ListExpr(_) => {
return Err(PyNotImplementedError::new_err("list expr"))
},
FunctionExpr::StringExpr(_) => {
return Err(PyNotImplementedError::new_err("string expr"))
FunctionExpr::StringExpr(strfun) => match strfun {
StringFunction::ConcatHorizontal {
delimiter,
ignore_nulls,
} => (
PyStringFunction::ConcatHorizontal.into_py(py),
delimiter,
ignore_nulls,
)
.to_object(py),
StringFunction::ConcatVertical {
delimiter,
ignore_nulls,
} => (
PyStringFunction::ConcatVertical.into_py(py),
delimiter,
ignore_nulls,
)
.to_object(py),
StringFunction::Contains { literal, strict } => {
(PyStringFunction::Contains.into_py(py), literal, strict).to_object(py)
},
StringFunction::CountMatches(_) => {
(PyStringFunction::CountMatches.into_py(py),).to_object(py)
},
StringFunction::EndsWith => {
(PyStringFunction::EndsWith.into_py(py),).to_object(py)
},
StringFunction::Explode => {
(PyStringFunction::Explode.into_py(py),).to_object(py)
},
StringFunction::Extract(_) => {
(PyStringFunction::Extract.into_py(py),).to_object(py)
},
StringFunction::ExtractAll => {
(PyStringFunction::ExtractAll.into_py(py),).to_object(py)
},
StringFunction::ExtractGroups { dtype, pat } => (
PyStringFunction::ExtractGroups.into_py(py),
Wrap(dtype.clone()).to_object(py),
pat,
)
.to_object(py),
StringFunction::Find { literal, strict } => {
(PyStringFunction::Find.into_py(py), literal, strict).to_object(py)
},
StringFunction::ToInteger(_) => {
(PyStringFunction::ToInteger.into_py(py),).to_object(py)
},
StringFunction::LenBytes => {
(PyStringFunction::LenBytes.into_py(py),).to_object(py)
},
StringFunction::LenChars => {
(PyStringFunction::LenChars.into_py(py),).to_object(py)
},
StringFunction::Lowercase => {
(PyStringFunction::Lowercase.into_py(py),).to_object(py)
},
StringFunction::JsonDecode {
dtype: _,
infer_schema_len,
} => (PyStringFunction::JsonDecode.into_py(py), infer_schema_len).to_object(py),
StringFunction::JsonPathMatch => {
(PyStringFunction::JsonPathMatch.into_py(py),).to_object(py)
},
StringFunction::Replace { n, literal } => {
(PyStringFunction::Replace.into_py(py), n, literal).to_object(py)
},
StringFunction::Reverse => {
(PyStringFunction::Reverse.into_py(py),).to_object(py)
},
StringFunction::PadStart { length, fill_char } => {
(PyStringFunction::PadStart.into_py(py), length, fill_char).to_object(py)
},
StringFunction::PadEnd { length, fill_char } => {
(PyStringFunction::PadEnd.into_py(py), length, fill_char).to_object(py)
},
StringFunction::Slice => (PyStringFunction::Slice.into_py(py),).to_object(py),
StringFunction::Head => (PyStringFunction::Head.into_py(py),).to_object(py),
StringFunction::Tail => (PyStringFunction::Tail.into_py(py),).to_object(py),
StringFunction::HexEncode => {
(PyStringFunction::HexEncode.into_py(py),).to_object(py)
},
StringFunction::HexDecode(_) => {
(PyStringFunction::HexDecode.into_py(py),).to_object(py)
},
StringFunction::Base64Encode => {
(PyStringFunction::Base64Encode.into_py(py),).to_object(py)
},
StringFunction::Base64Decode(_) => {
(PyStringFunction::Base64Decode.into_py(py),).to_object(py)
},
StringFunction::StartsWith => {
(PyStringFunction::StartsWith.into_py(py),).to_object(py)
},
StringFunction::StripChars => {
(PyStringFunction::StripChars.into_py(py),).to_object(py)
},
StringFunction::StripCharsStart => {
(PyStringFunction::StripCharsStart.into_py(py),).to_object(py)
},
StringFunction::StripCharsEnd => {
(PyStringFunction::StripCharsEnd.into_py(py),).to_object(py)
},
StringFunction::StripPrefix => {
(PyStringFunction::StripPrefix.into_py(py),).to_object(py)
},
StringFunction::StripSuffix => {
(PyStringFunction::StripSuffix.into_py(py),).to_object(py)
},
StringFunction::SplitExact { n, inclusive } => {
(PyStringFunction::SplitExact.into_py(py), n, inclusive).to_object(py)
},
StringFunction::SplitN(_) => {
(PyStringFunction::SplitN.into_py(py),).to_object(py)
},
StringFunction::Strptime(_, _) => {
(PyStringFunction::Strptime.into_py(py),).to_object(py)
},
StringFunction::Split(_) => {
(PyStringFunction::Split.into_py(py),).to_object(py)
},
StringFunction::ToDecimal(_) => {
(PyStringFunction::ToDecimal.into_py(py),).to_object(py)
},
StringFunction::Titlecase => {
(PyStringFunction::Titlecase.into_py(py),).to_object(py)
},
StringFunction::Uppercase => {
(PyStringFunction::Uppercase.into_py(py),).to_object(py)
},
StringFunction::ZFill => (PyStringFunction::ZFill.into_py(py),).to_object(py),
StringFunction::ContainsMany {
ascii_case_insensitive,
} => (
PyStringFunction::ContainsMany.into_py(py),
ascii_case_insensitive,
)
.to_object(py),
StringFunction::ReplaceMany {
ascii_case_insensitive,
} => (
PyStringFunction::ReplaceMany.into_py(py),
ascii_case_insensitive,
)
.to_object(py),
},
FunctionExpr::StructExpr(_) => {
return Err(PyNotImplementedError::new_err("struct expr"))
Expand Down
1 change: 1 addition & 0 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ fn _expr_nodes(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<Len>().unwrap();
m.add_class::<Window>().unwrap();
m.add_class::<PyOperator>().unwrap();
m.add_class::<PyStringFunction>().unwrap();
// Options
m.add_class::<PyWindowMapping>().unwrap();
m.add_class::<PyRollingGroupOptions>().unwrap();
Expand Down

0 comments on commit 98a2d9b

Please sign in to comment.