Skip to content

Commit

Permalink
[FEAT] String normalize expression (#2450)
Browse files Browse the repository at this point in the history
Adds an expression to normalize strings, for preprocessing for
deduplication. Offers four options: removing punctuation, lowercasing,
removing extra whitespace, and Unicode normalization.
  • Loading branch information
Vince7778 committed Jun 29, 2024
1 parent 86ded60 commit 29b02a1
Show file tree
Hide file tree
Showing 16 changed files with 399 additions and 5 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,7 @@ class PyExpr:
def utf8_substr(self, start: PyExpr, length: PyExpr) -> PyExpr: ...
def utf8_to_date(self, format: str) -> PyExpr: ...
def utf8_to_datetime(self, format: str, timezone: str | None = None) -> PyExpr: ...
def utf8_normalize(self, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool) -> PyExpr: ...
def image_decode(self, raise_error_on_failure: bool) -> PyExpr: ...
def image_encode(self, image_format: ImageFormat) -> PyExpr: ...
def image_resize(self, w: int, h: int) -> PyExpr: ...
Expand Down Expand Up @@ -1261,6 +1262,7 @@ class PySeries:
def utf8_substr(self, start: PySeries, length: PySeries | None = None) -> PySeries: ...
def utf8_to_date(self, format: str) -> PySeries: ...
def utf8_to_datetime(self, format: str, timezone: str | None = None) -> PySeries: ...
def utf8_normalize(self, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool) -> PySeries: ...
def is_nan(self) -> PySeries: ...
def is_inf(self) -> PySeries: ...
def not_nan(self) -> PySeries: ...
Expand Down
41 changes: 41 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,47 @@ def to_datetime(self, format: str, timezone: str | None = None) -> Expression:
"""
return Expression._from_pyexpr(self._expr.utf8_to_datetime(format, timezone))

def normalize(
self,
*,
remove_punct: bool = True,
lowercase: bool = True,
nfd_unicode: bool = True,
white_space: bool = True,
):
"""Normalizes a string for more useful deduplication.
.. NOTE::
All processing options are on by default.
Example:
>>> df = daft.from_pydict({"x": ["hello world", "Hello, world!", "HELLO, \\nWORLD!!!!"]})
>>> df = df.with_column("normalized", df["x"].str.normalize())
>>> df.show()
╭───────────────┬─────────────╮
│ x ┆ normalized │
│ --- ┆ --- │
│ Utf8 ┆ Utf8 │
╞═══════════════╪═════════════╡
│ hello world ┆ hello world │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ Hello, world! ┆ hello world │
├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌┤
│ HELLO, ┆ hello world │
│ WORLD!!!! ┆ │
╰───────────────┴─────────────╯
Args:
remove_punct: Whether to remove all punctuation (ASCII).
lowercase: Whether to convert the string to lowercase.
nfd_unicode: Whether to normalize and decompose Unicode characters according to NFD.
white_space: Whether to normalize whitespace, replacing newlines etc with spaces and removing double spaces.
Returns:
Expression: a String expression which is normalized.
"""
return Expression._from_pyexpr(self._expr.utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space))


class ExpressionListNamespace(ExpressionNamespace):
def join(self, delimiter: str | Expression) -> Expression:
Expand Down
19 changes: 19 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,25 @@ def substr(self, start: Series, length: Series | None = None) -> Series:
assert self._series is not None and start._series is not None
return Series._from_pyseries(self._series.utf8_substr(start._series, length._series))

def normalize(
self,
*,
remove_punct: bool = True,
lowercase: bool = True,
nfd_unicode: bool = True,
white_space: bool = True,
) -> Series:
if not isinstance(remove_punct, bool):
raise ValueError(f"expected bool for remove_punct but got {type(remove_punct)}")
if not isinstance(lowercase, bool):
raise ValueError(f"expected bool for lowercase but got {type(lowercase)}")
if not isinstance(nfd_unicode, bool):
raise ValueError(f"expected bool for nfd_unicode but got {type(nfd_unicode)}")
if not isinstance(white_space, bool):
raise ValueError(f"expected bool for white_space but got {type(white_space)}")
assert self._series is not None
return Series._from_pyseries(self._series.utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space))


class SeriesDateNamespace(SeriesNamespace):
def date(self) -> Series:
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ The following methods are available under the ``expr.str`` attribute.
Expression.str.substr
Expression.str.to_date
Expression.str.to_datetime
Expression.str.normalize

.. _api-float-expression-operations:

Expand Down
1 change: 1 addition & 0 deletions src/daft-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ regex = {workspace = true}
serde = {workspace = true}
serde_json = {workspace = true}
sketches-ddsketch = {workspace = true}
unicode-normalization = "0.1.23"

[dependencies.image]
default-features = false
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ mod utf8;

pub use sort::{build_multi_array_bicompare, build_multi_array_compare};

pub use utf8::PadPlacement;
pub use utf8::{PadPlacement, Utf8NormalizeOptions};

use common_error::DaftResult;

Expand Down
48 changes: 48 additions & 0 deletions src/daft-core/src/array/ops/utf8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use chrono::Datelike;
use common_error::{DaftError, DaftResult};
use itertools::Itertools;
use num_traits::NumCast;
use serde::{Deserialize, Serialize};
use unicode_normalization::UnicodeNormalization;

use super::{as_arrow::AsArrow, full::FullNull};

Expand Down Expand Up @@ -348,6 +350,14 @@ pub enum PadPlacement {
Right,
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Utf8NormalizeOptions {
pub remove_punct: bool,
pub lowercase: bool,
pub nfd_unicode: bool,
pub white_space: bool,
}

impl Utf8Array {
pub fn endswith(&self, pattern: &Utf8Array) -> DaftResult<BooleanArray> {
self.binary_broadcasted_compare(
Expand Down Expand Up @@ -1328,6 +1338,44 @@ impl Utf8Array {
Ok(result)
}

pub fn normalize(&self, opts: Utf8NormalizeOptions) -> DaftResult<Utf8Array> {
let whitespace_regex = regex::Regex::new(r"\s+").unwrap();

let arrow_result = self
.as_arrow()
.iter()
.map(|maybe_s| {
if let Some(s) = maybe_s {
let mut s = s.to_string();

if opts.remove_punct {
s = s.chars().filter(|c| !c.is_ascii_punctuation()).collect();
}

if opts.lowercase {
s = s.to_lowercase();
}

if opts.white_space {
s = whitespace_regex
.replace_all(s.as_str().trim(), " ")
.to_string();
}

if opts.nfd_unicode {
s = s.nfd().collect();
}

Ok(Some(s))
} else {
Ok(None)
}
})
.collect::<DaftResult<arrow2::array::Utf8Array<i64>>>()?;

Ok(Utf8Array::from((self.name(), Box::new(arrow_result))))
}

fn unary_broadcasted_op<ScalarKernel>(&self, operation: ScalarKernel) -> DaftResult<Utf8Array>
where
ScalarKernel: Fn(&str) -> Cow<'_, str>,
Expand Down
23 changes: 22 additions & 1 deletion src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ use pyo3::{
};

use crate::{
array::{ops::DaftLogical, pseudo_arrow::PseudoArrowArray, DataArray},
array::{
ops::{DaftLogical, Utf8NormalizeOptions},
pseudo_arrow::PseudoArrowArray,
DataArray,
},
count_mode::CountMode,
datatypes::{DataType, Field, ImageFormat, ImageMode, PythonType},
ffi,
Expand Down Expand Up @@ -488,6 +492,23 @@ impl PySeries {
Ok(self.series.utf8_to_datetime(format, timezone)?.into())
}

pub fn utf8_normalize(
&self,
remove_punct: bool,
lowercase: bool,
nfd_unicode: bool,
white_space: bool,
) -> PyResult<Self> {
let opts = Utf8NormalizeOptions {
remove_punct,
lowercase,
nfd_unicode,
white_space,
};

Ok(self.series.utf8_normalize(opts)?.into())
}

pub fn is_nan(&self) -> PyResult<Self> {
Ok(self.series.is_nan()?.into())
}
Expand Down
6 changes: 5 additions & 1 deletion src/daft-core/src/series/ops/utf8.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::array::ops::PadPlacement;
use crate::array::ops::{PadPlacement, Utf8NormalizeOptions};
use crate::series::array_impl::IntoSeries;
use crate::series::Series;
use crate::{datatypes::*, with_match_integer_daft_types};
Expand Down Expand Up @@ -247,4 +247,8 @@ impl Series {
pub fn utf8_to_datetime(&self, format: &str, timezone: Option<&str>) -> DaftResult<Series> {
self.with_utf8_array(|arr| Ok(arr.to_datetime(format, timezone)?.into_series()))
}

pub fn utf8_normalize(&self, opts: Utf8NormalizeOptions) -> DaftResult<Series> {
self.with_utf8_array(|arr| Ok(arr.normalize(opts)?.into_series()))
}
}
13 changes: 13 additions & 0 deletions src/daft-dsl/src/functions/utf8/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ mod lower;
mod lpad;
mod lstrip;
mod match_;
mod normalize;
mod repeat;
mod replace;
mod reverse;
Expand All @@ -27,6 +28,7 @@ mod upper;

use capitalize::CapitalizeEvaluator;
use contains::ContainsEvaluator;
use daft_core::array::ops::Utf8NormalizeOptions;
use endswith::EndswithEvaluator;
use extract::ExtractEvaluator;
use extract_all::ExtractAllEvaluator;
Expand All @@ -38,6 +40,7 @@ use like::LikeEvaluator;
use lower::LowerEvaluator;
use lpad::LpadEvaluator;
use lstrip::LstripEvaluator;
use normalize::NormalizeEvaluator;
use repeat::RepeatEvaluator;
use replace::ReplaceEvaluator;
use reverse::ReverseEvaluator;
Expand Down Expand Up @@ -84,6 +87,7 @@ pub enum Utf8Expr {
Substr,
ToDate(String),
ToDatetime(String, Option<String>),
Normalize(Utf8NormalizeOptions),
}

impl Utf8Expr {
Expand Down Expand Up @@ -117,6 +121,7 @@ impl Utf8Expr {
Substr => &SubstrEvaluator {},
ToDate(_) => &ToDateEvaluator {},
ToDatetime(_, _) => &ToDatetimeEvaluator {},
Normalize(_) => &NormalizeEvaluator {},
}
}
}
Expand Down Expand Up @@ -331,3 +336,11 @@ pub fn to_datetime(data: ExprRef, format: &str, timezone: Option<&str>) -> ExprR
}
.into()
}

pub fn normalize(data: ExprRef, opts: Utf8NormalizeOptions) -> ExprRef {
Expr::Function {
func: super::FunctionExpr::Utf8(Utf8Expr::Normalize(opts)),
inputs: vec![data],
}
.into()
}
53 changes: 53 additions & 0 deletions src/daft-dsl/src/functions/utf8/normalize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::functions::FunctionExpr;
use crate::ExprRef;
use common_error::{DaftError, DaftResult};

use super::{super::FunctionEvaluator, Utf8Expr};

pub(super) struct NormalizeEvaluator {}

impl FunctionEvaluator for NormalizeEvaluator {
fn fn_name(&self) -> &'static str {
"normalize"
}

fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
match inputs {
[data] => match data.to_field(schema) {
Ok(data_field) => match &data_field.dtype {
DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)),
_ => Err(DaftError::TypeError(format!(
"Expects input to normalize to be utf8, but received {data_field}",
))),
},
Err(e) => Err(e),
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}

fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult<Series> {
match inputs {
[data] => {
let opts = match expr {
FunctionExpr::Utf8(Utf8Expr::Normalize(opts)) => opts,
_ => panic!("Expected Utf8 Normalize Expr, got {expr}"),
};
data.utf8_normalize(*opts)
}
_ => Err(DaftError::ValueError(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}
}
19 changes: 19 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::hash::{Hash, Hasher};
use std::sync::Arc;

use common_error::DaftError;
use daft_core::array::ops::Utf8NormalizeOptions;
use daft_core::python::datatype::PyTimeUnit;
use daft_core::python::PySeries;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -687,6 +688,24 @@ impl PyExpr {
Ok(to_datetime(self.into(), format, timezone).into())
}

pub fn utf8_normalize(
&self,
remove_punct: bool,
lowercase: bool,
nfd_unicode: bool,
white_space: bool,
) -> PyResult<Self> {
use crate::functions::utf8::normalize;
let opts = Utf8NormalizeOptions {
remove_punct,
lowercase,
nfd_unicode,
white_space,
};

Ok(normalize(self.into(), opts).into())
}

pub fn image_decode(&self, raise_error_on_failure: bool) -> PyResult<Self> {
use crate::functions::image::decode;
Ok(decode(self.into(), raise_error_on_failure).into())
Expand Down
Loading

0 comments on commit 29b02a1

Please sign in to comment.