diff --git a/Cargo.lock b/Cargo.lock index 738ac004c6..ca0ef238de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1604,6 +1604,7 @@ dependencies = [ "serde", "serde_json", "sketches-ddsketch", + "unicode-normalization", "xxhash-rust", ] @@ -5176,9 +5177,9 @@ checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" [[package]] name = "unicode-normalization" -version = "0.1.22" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" dependencies = [ "tinyvec", ] diff --git a/daft/daft.pyi b/daft/daft.pyi index 63acd596a2..5d61c5d220 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -1126,6 +1126,7 @@ class PyExpr: def utf8_substr(self, start: PyExpr, length: PyExpr) -> PyExpr: ... def utf8_to_date(self, format: str) -> PyExpr: ... def utf8_to_datetime(self, format: str, timezone: str | None = None) -> PyExpr: ... + def utf8_normalize(self, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool) -> PyExpr: ... def image_decode(self, raise_error_on_failure: bool) -> PyExpr: ... def image_encode(self, image_format: ImageFormat) -> PyExpr: ... def image_resize(self, w: int, h: int) -> PyExpr: ... @@ -1261,6 +1262,7 @@ class PySeries: def utf8_substr(self, start: PySeries, length: PySeries | None = None) -> PySeries: ... def utf8_to_date(self, format: str) -> PySeries: ... def utf8_to_datetime(self, format: str, timezone: str | None = None) -> PySeries: ... + def utf8_normalize(self, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool) -> PySeries: ... def is_nan(self) -> PySeries: ... def is_inf(self) -> PySeries: ... def not_nan(self) -> PySeries: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index ecdb894bb3..4cd689abb4 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1664,6 +1664,47 @@ def to_datetime(self, format: str, timezone: str | None = None) -> Expression: """ return Expression._from_pyexpr(self._expr.utf8_to_datetime(format, timezone)) + def normalize( + self, + *, + remove_punct: bool = True, + lowercase: bool = True, + nfd_unicode: bool = True, + white_space: bool = True, + ): + """Normalizes a string for more useful deduplication. + + .. NOTE:: + All processing options are on by default. + + Example: + >>> df = daft.from_pydict({"x": ["hello world", "Hello, world!", "HELLO, \\nWORLD!!!!"]}) + >>> df = df.with_column("normalized", df["x"].str.normalize()) + >>> df.show() + ╭───────────────┬─────────────╮ + │ x ┆ normalized │ + │ --- ┆ --- │ + │ Utf8 ┆ Utf8 │ + ╞═══════════════╪═════════════╡ + │ hello world ┆ hello world │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ Hello, world! ┆ hello world │ + ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤ + │ HELLO, ┆ hello world │ + │ WORLD!!!! ┆ │ + ╰───────────────┴─────────────╯ + + Args: + remove_punct: Whether to remove all punctuation (ASCII). + lowercase: Whether to convert the string to lowercase. + nfd_unicode: Whether to normalize and decompose Unicode characters according to NFD. + white_space: Whether to normalize whitespace, replacing newlines etc with spaces and removing double spaces. + + Returns: + Expression: a String expression which is normalized. + """ + return Expression._from_pyexpr(self._expr.utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space)) + class ExpressionListNamespace(ExpressionNamespace): def join(self, delimiter: str | Expression) -> Expression: diff --git a/daft/series.py b/daft/series.py index b55286bf1e..854b92eeba 100644 --- a/daft/series.py +++ b/daft/series.py @@ -849,6 +849,25 @@ def substr(self, start: Series, length: Series | None = None) -> Series: assert self._series is not None and start._series is not None return Series._from_pyseries(self._series.utf8_substr(start._series, length._series)) + def normalize( + self, + *, + remove_punct: bool = True, + lowercase: bool = True, + nfd_unicode: bool = True, + white_space: bool = True, + ) -> Series: + if not isinstance(remove_punct, bool): + raise ValueError(f"expected bool for remove_punct but got {type(remove_punct)}") + if not isinstance(lowercase, bool): + raise ValueError(f"expected bool for lowercase but got {type(lowercase)}") + if not isinstance(nfd_unicode, bool): + raise ValueError(f"expected bool for nfd_unicode but got {type(nfd_unicode)}") + if not isinstance(white_space, bool): + raise ValueError(f"expected bool for white_space but got {type(white_space)}") + assert self._series is not None + return Series._from_pyseries(self._series.utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space)) + class SeriesDateNamespace(SeriesNamespace): def date(self) -> Series: diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index d54c748790..27f143d882 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -150,6 +150,7 @@ The following methods are available under the ``expr.str`` attribute. Expression.str.substr Expression.str.to_date Expression.str.to_datetime + Expression.str.normalize .. _api-float-expression-operations: diff --git a/src/daft-core/Cargo.toml b/src/daft-core/Cargo.toml index 84178b1018..bd003f993f 100644 --- a/src/daft-core/Cargo.toml +++ b/src/daft-core/Cargo.toml @@ -46,6 +46,7 @@ regex = {workspace = true} serde = {workspace = true} serde_json = {workspace = true} sketches-ddsketch = {workspace = true} +unicode-normalization = "0.1.23" [dependencies.image] default-features = false diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index 742e01a6aa..d58a2bf8c2 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -56,7 +56,7 @@ mod utf8; pub use sort::{build_multi_array_bicompare, build_multi_array_compare}; -pub use utf8::PadPlacement; +pub use utf8::{PadPlacement, Utf8NormalizeOptions}; use common_error::DaftResult; diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index e97c1ca656..a9c8661501 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -18,6 +18,8 @@ use chrono::Datelike; use common_error::{DaftError, DaftResult}; use itertools::Itertools; use num_traits::NumCast; +use serde::{Deserialize, Serialize}; +use unicode_normalization::UnicodeNormalization; use super::{as_arrow::AsArrow, full::FullNull}; @@ -348,6 +350,14 @@ pub enum PadPlacement { Right, } +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8NormalizeOptions { + pub remove_punct: bool, + pub lowercase: bool, + pub nfd_unicode: bool, + pub white_space: bool, +} + impl Utf8Array { pub fn endswith(&self, pattern: &Utf8Array) -> DaftResult { self.binary_broadcasted_compare( @@ -1328,6 +1338,44 @@ impl Utf8Array { Ok(result) } + pub fn normalize(&self, opts: Utf8NormalizeOptions) -> DaftResult { + let whitespace_regex = regex::Regex::new(r"\s+").unwrap(); + + let arrow_result = self + .as_arrow() + .iter() + .map(|maybe_s| { + if let Some(s) = maybe_s { + let mut s = s.to_string(); + + if opts.remove_punct { + s = s.chars().filter(|c| !c.is_ascii_punctuation()).collect(); + } + + if opts.lowercase { + s = s.to_lowercase(); + } + + if opts.white_space { + s = whitespace_regex + .replace_all(s.as_str().trim(), " ") + .to_string(); + } + + if opts.nfd_unicode { + s = s.nfd().collect(); + } + + Ok(Some(s)) + } else { + Ok(None) + } + }) + .collect::>>()?; + + Ok(Utf8Array::from((self.name(), Box::new(arrow_result)))) + } + fn unary_broadcasted_op(&self, operation: ScalarKernel) -> DaftResult where ScalarKernel: Fn(&str) -> Cow<'_, str>, diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index b8cc11fca9..39d616b38f 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -8,7 +8,11 @@ use pyo3::{ }; use crate::{ - array::{ops::DaftLogical, pseudo_arrow::PseudoArrowArray, DataArray}, + array::{ + ops::{DaftLogical, Utf8NormalizeOptions}, + pseudo_arrow::PseudoArrowArray, + DataArray, + }, count_mode::CountMode, datatypes::{DataType, Field, ImageFormat, ImageMode, PythonType}, ffi, @@ -488,6 +492,23 @@ impl PySeries { Ok(self.series.utf8_to_datetime(format, timezone)?.into()) } + pub fn utf8_normalize( + &self, + remove_punct: bool, + lowercase: bool, + nfd_unicode: bool, + white_space: bool, + ) -> PyResult { + let opts = Utf8NormalizeOptions { + remove_punct, + lowercase, + nfd_unicode, + white_space, + }; + + Ok(self.series.utf8_normalize(opts)?.into()) + } + pub fn is_nan(&self) -> PyResult { Ok(self.series.is_nan()?.into()) } diff --git a/src/daft-core/src/series/ops/utf8.rs b/src/daft-core/src/series/ops/utf8.rs index 6e862635dd..65fc738c51 100644 --- a/src/daft-core/src/series/ops/utf8.rs +++ b/src/daft-core/src/series/ops/utf8.rs @@ -1,4 +1,4 @@ -use crate::array::ops::PadPlacement; +use crate::array::ops::{PadPlacement, Utf8NormalizeOptions}; use crate::series::array_impl::IntoSeries; use crate::series::Series; use crate::{datatypes::*, with_match_integer_daft_types}; @@ -247,4 +247,8 @@ impl Series { pub fn utf8_to_datetime(&self, format: &str, timezone: Option<&str>) -> DaftResult { self.with_utf8_array(|arr| Ok(arr.to_datetime(format, timezone)?.into_series())) } + + pub fn utf8_normalize(&self, opts: Utf8NormalizeOptions) -> DaftResult { + self.with_utf8_array(|arr| Ok(arr.normalize(opts)?.into_series())) + } } diff --git a/src/daft-dsl/src/functions/utf8/mod.rs b/src/daft-dsl/src/functions/utf8/mod.rs index a3b8ca2e29..212678a6f9 100644 --- a/src/daft-dsl/src/functions/utf8/mod.rs +++ b/src/daft-dsl/src/functions/utf8/mod.rs @@ -12,6 +12,7 @@ mod lower; mod lpad; mod lstrip; mod match_; +mod normalize; mod repeat; mod replace; mod reverse; @@ -27,6 +28,7 @@ mod upper; use capitalize::CapitalizeEvaluator; use contains::ContainsEvaluator; +use daft_core::array::ops::Utf8NormalizeOptions; use endswith::EndswithEvaluator; use extract::ExtractEvaluator; use extract_all::ExtractAllEvaluator; @@ -38,6 +40,7 @@ use like::LikeEvaluator; use lower::LowerEvaluator; use lpad::LpadEvaluator; use lstrip::LstripEvaluator; +use normalize::NormalizeEvaluator; use repeat::RepeatEvaluator; use replace::ReplaceEvaluator; use reverse::ReverseEvaluator; @@ -84,6 +87,7 @@ pub enum Utf8Expr { Substr, ToDate(String), ToDatetime(String, Option), + Normalize(Utf8NormalizeOptions), } impl Utf8Expr { @@ -117,6 +121,7 @@ impl Utf8Expr { Substr => &SubstrEvaluator {}, ToDate(_) => &ToDateEvaluator {}, ToDatetime(_, _) => &ToDatetimeEvaluator {}, + Normalize(_) => &NormalizeEvaluator {}, } } } @@ -331,3 +336,11 @@ pub fn to_datetime(data: ExprRef, format: &str, timezone: Option<&str>) -> ExprR } .into() } + +pub fn normalize(data: ExprRef, opts: Utf8NormalizeOptions) -> ExprRef { + Expr::Function { + func: super::FunctionExpr::Utf8(Utf8Expr::Normalize(opts)), + inputs: vec![data], + } + .into() +} diff --git a/src/daft-dsl/src/functions/utf8/normalize.rs b/src/daft-dsl/src/functions/utf8/normalize.rs new file mode 100644 index 0000000000..0c3f23b5f5 --- /dev/null +++ b/src/daft-dsl/src/functions/utf8/normalize.rs @@ -0,0 +1,53 @@ +use daft_core::{ + datatypes::{DataType, Field}, + schema::Schema, + series::Series, +}; + +use crate::functions::FunctionExpr; +use crate::ExprRef; +use common_error::{DaftError, DaftResult}; + +use super::{super::FunctionEvaluator, Utf8Expr}; + +pub(super) struct NormalizeEvaluator {} + +impl FunctionEvaluator for NormalizeEvaluator { + fn fn_name(&self) -> &'static str { + "normalize" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to normalize to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { + match inputs { + [data] => { + let opts = match expr { + FunctionExpr::Utf8(Utf8Expr::Normalize(opts)) => opts, + _ => panic!("Expected Utf8 Normalize Expr, got {expr}"), + }; + data.utf8_normalize(*opts) + } + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 98cd19dcac..e086de0ec7 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -4,6 +4,7 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use common_error::DaftError; +use daft_core::array::ops::Utf8NormalizeOptions; use daft_core::python::datatype::PyTimeUnit; use daft_core::python::PySeries; use serde::{Deserialize, Serialize}; @@ -687,6 +688,24 @@ impl PyExpr { Ok(to_datetime(self.into(), format, timezone).into()) } + pub fn utf8_normalize( + &self, + remove_punct: bool, + lowercase: bool, + nfd_unicode: bool, + white_space: bool, + ) -> PyResult { + use crate::functions::utf8::normalize; + let opts = Utf8NormalizeOptions { + remove_punct, + lowercase, + nfd_unicode, + white_space, + }; + + Ok(normalize(self.into(), opts).into()) + } + pub fn image_decode(&self, raise_error_on_failure: bool) -> PyResult { use crate::functions::image::decode; Ok(decode(self.into(), raise_error_on_failure).into()) diff --git a/tests/expressions/typing/test_str.py b/tests/expressions/typing/test_str.py index 5b2e146076..0d7c27dc6f 100644 --- a/tests/expressions/typing/test_str.py +++ b/tests/expressions/typing/test_str.py @@ -189,3 +189,27 @@ def test_str_to_datetime(): run_kernel=lambda: s.str.to_datetime(format), resolvable=True, ) + + +@pytest.mark.parametrize("remove_punct", [False, True]) +@pytest.mark.parametrize("lowercase", [False, True]) +@pytest.mark.parametrize("nfd_unicode", [False, True]) +@pytest.mark.parametrize("white_space", [False, True]) +def test_str_normalize(remove_punct, lowercase, nfd_unicode, white_space): + s = Series.from_arrow(pa.array(["hello world", "Hello, world!", "Hêllø, \nworłd!!"]), name="col") + assert_typing_resolve_vs_runtime_behavior( + data=[s], + expr=col("col").str.normalize( + remove_punct=remove_punct, + lowercase=lowercase, + nfd_unicode=nfd_unicode, + white_space=white_space, + ), + run_kernel=lambda: s.str.normalize( + remove_punct=remove_punct, + lowercase=lowercase, + nfd_unicode=nfd_unicode, + white_space=white_space, + ), + resolvable=True, + ) diff --git a/tests/series/test_utf8_ops.py b/tests/series/test_utf8_ops.py index 705a20e671..e4c7bcb888 100644 --- a/tests/series/test_utf8_ops.py +++ b/tests/series/test_utf8_ops.py @@ -1,6 +1,9 @@ from __future__ import annotations import datetime +import re +import string +import unicodedata import pyarrow as pa import pytest @@ -1476,3 +1479,68 @@ def test_series_utf8_to_bad_datetime() -> None: s = Series.from_arrow(pa.array(["2021-100-20"])) with pytest.raises(ValueError): s.str.to_datetime("%Y-%m-%d %H:%M:%S", "UTC") + + +# source: RedPajama +def manual_normalize(text, remove_punct, lowercase, nfd_unicode, white_space): + if text is None: + return None + + if remove_punct: + text = text.translate(str.maketrans("", "", string.punctuation)) + + if lowercase: + text = text.lower() + + if white_space: + text = text.strip() + text = re.sub(r"\s+", " ", text) + + if nfd_unicode: + text = unicodedata.normalize("NFD", text) + + return text + + +NORMALIZE_TEST_DATA = [ + "regular text no changes", + "Regular texT WITH uPpErCaSe", + "ALL UPPERCASE TEXT", + "text, with... punctuation!!!", + "!&# #%*!()%*@# &*%#& @*( #*(@%()))", + "!@#$%^&*()+_-=~`[]\\/.,';?><\":|}{", + "UPPERCASE, AND, PUNCTUATION!?", + "füñķÿ úňìčõðė", + "füñķÿ, úňìčõðė!", + "FüÑķŸ úňÌčõÐė", + "FüÑķŸ, úňÌčõÐė!", + "way too much space", + " space all over the place ", + "other\ntypes\tof\r\nspace characters", + "too\n\n\t\r\n \n\t\tmuch\n\t\tspace\t \n\n \t\r\n \t", + None, + "TOO\n\n\t\r\n \n\t\tMUCH\n\t\tsPACe\t \n\n \t\r\n \t", + "too,\n\n?\t\r\n \n\t\tmuc%h!!%\n\t\t\\SPACE??!\t \n\n \t\r\n \t", + "FüÑķŸ, úňÌčõÐė! AND EVERYTHING else TOO \t\na\t\t\nbCDe 😃😌😝", + "", + "specialcase", + "SPECIALCASE", + "😃", + None, +] + + +@pytest.mark.parametrize("remove_punct", [False, True]) +@pytest.mark.parametrize("lowercase", [False, True]) +@pytest.mark.parametrize("nfd_unicode", [False, True]) +@pytest.mark.parametrize("white_space", [False, True]) +def test_series_utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space) -> None: + s = Series.from_pylist(NORMALIZE_TEST_DATA) + a = s.str.normalize( + remove_punct=remove_punct, + lowercase=lowercase, + nfd_unicode=nfd_unicode, + white_space=white_space, + ).to_pylist() + b = [manual_normalize(t, remove_punct, lowercase, nfd_unicode, white_space) for t in NORMALIZE_TEST_DATA] + assert a == b diff --git a/tests/table/utf8/test_normalize.py b/tests/table/utf8/test_normalize.py new file mode 100644 index 0000000000..ad204cee47 --- /dev/null +++ b/tests/table/utf8/test_normalize.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import re +import string +import unicodedata + +import pytest + +from daft.expressions import col +from daft.table import MicroPartition + + +# source: RedPajama +def manual_normalize(text, remove_punct, lowercase, nfd_unicode, white_space): + if text is None: + return None + + if remove_punct: + text = text.translate(str.maketrans("", "", string.punctuation)) + + if lowercase: + text = text.lower() + + if white_space: + text = text.strip() + text = re.sub(r"\s+", " ", text) + + if nfd_unicode: + text = unicodedata.normalize("NFD", text) + + return text + + +NORMALIZE_TEST_DATA = [ + "regular text no changes", + "Regular texT WITH uPpErCaSe", + "ALL UPPERCASE TEXT", + "text, with... punctuation!!!", + "!&# #%*!()%*@# &*%#& @*( #*(@%()))", + "!@#$%^&*()+_-=~`[]\\/.,';?><\":|}{", + "UPPERCASE, AND, PUNCTUATION!?", + "füñķÿ úňìčõðė", + "füñķÿ, úňìčõðė!", + "FüÑķŸ úňÌčõÐė", + "FüÑķŸ, úňÌčõÐė!", + "way too much space", + " space all over the place ", + "other\ntypes\tof\r\nspace characters", + "too\n\n\t\r\n \n\t\tmuch\n\t\tspace\t \n\n \t\r\n \t", + None, + "TOO\n\n\t\r\n \n\t\tMUCH\n\t\tsPACe\t \n\n \t\r\n \t", + "too,\n\n?\t\r\n \n\t\tmuc%h!!%\n\t\t\\SPACE??!\t \n\n \t\r\n \t", + "FüÑķŸ, úňÌčõÐė! AND EVERYTHING else TOO \t\na\t\t\nbCDe 😃😌😝", + "", + "specialcase", + "SPECIALCASE", + "😃", + None, +] + + +@pytest.mark.parametrize("remove_punct", [False, True]) +@pytest.mark.parametrize("lowercase", [False, True]) +@pytest.mark.parametrize("nfd_unicode", [False, True]) +@pytest.mark.parametrize("white_space", [False, True]) +def test_utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space): + table = MicroPartition.from_pydict({"col": NORMALIZE_TEST_DATA}) + result = table.eval_expression_list( + [ + col("col").str.normalize( + remove_punct=remove_punct, + lowercase=lowercase, + nfd_unicode=nfd_unicode, + white_space=white_space, + ) + ] + ) + expected = [manual_normalize(t, remove_punct, lowercase, nfd_unicode, white_space) for t in NORMALIZE_TEST_DATA] + assert result.to_pydict() == {"col": expected}