diff --git a/crates/polars-ops/src/chunked_array/strings/namespace.rs b/crates/polars-ops/src/chunked_array/strings/namespace.rs index 83713b7889521..9948b0f006b57 100644 --- a/crates/polars-ops/src/chunked_array/strings/namespace.rs +++ b/crates/polars-ops/src/chunked_array/strings/namespace.rs @@ -590,6 +590,29 @@ pub trait StringNameSpaceImpl: AsString { Ok(substring::substring(ca, offset.i64()?, length.u64()?)) } + + /// Slice the first `n` values of the string. + /// + /// Determines a substring starting at the beginning of the string up to offset `n` of each + /// element in `array`. `n` can be negative, in which case the slice ends `n` characters from + /// the end of the string. + fn str_head(&self, n: &Series) -> PolarsResult { + let ca = self.as_string(); + let n = n.cast(&DataType::Int64)?; + + Ok(substring::head(ca, n.i64()?)) + } + + /// Slice the last `n` values of the string. + /// + /// Takes the last `n` characters of each element in `array`. `n` can be + /// negative, in which case the slice begins `|n|` characters from the start of the string. 
+ fn str_tail(&self, n: &Series) -> PolarsResult { + let ca = self.as_string(); + let n = n.cast(&DataType::Int64)?; + + Ok(substring::tail(ca, n.i64()?)) + } } impl StringNameSpaceImpl for StringChunked {} diff --git a/crates/polars-ops/src/chunked_array/strings/substring.rs b/crates/polars-ops/src/chunked_array/strings/substring.rs index 690567396fb8a..fb01a1a6c826d 100644 --- a/crates/polars-ops/src/chunked_array/strings/substring.rs +++ b/crates/polars-ops/src/chunked_array/strings/substring.rs @@ -1,6 +1,49 @@ use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise}; use polars_core::prelude::{Int64Chunked, StringChunked, UInt64Chunked}; +fn head_binary(opt_str_val: Option<&str>, opt_n: Option) -> Option<&str> { + if let (Some(str_val), Some(mut n)) = (opt_str_val, opt_n) { + let str_len = str_val.len() as i64; + if n >= str_len { + Some(str_val) + } else if (n == 0) | (str_len == 0) | (n <= -str_len) { + Some("") + } else { + if n < 0 { + // If `n` is negative, it counts from the end of the string. + n += str_len; // adding negative value + } + Some(&str_val[0..n as usize]) + } + } else { + None + } +} + +fn tail_binary(opt_str_val: Option<&str>, opt_n: Option) -> Option<&str> { + if let (Some(str_val), Some(mut n)) = (opt_str_val, opt_n) { + let str_len = str_val.len() as i64; + if n >= str_len { + Some(str_val) + } else if (n == 0) | (str_len == 0) | (n <= -str_len) { + Some("") + } else { + // We re-assign `n` to be the start of the slice. + // The end of the slice is always the end of the string. + if n < 0 { + // If `n` is negative, we count from the beginning. + n = -n; + } else { + // If `n` is positive, we count from the end. 
+ n = str_len - n; + } + Some(&str_val[n as usize..str_len as usize]) + } + } else { + None + } +} + fn substring_ternary( opt_str_val: Option<&str>, opt_offset: Option, @@ -115,3 +158,35 @@ pub(super) fn substring( _ => ternary_elementwise(ca, offset, length, substring_ternary), } } + +pub(super) fn head(ca: &StringChunked, n: &Int64Chunked) -> StringChunked { + match (ca.len(), n.len()) { + (_, 1) => { + // SAFETY: index `0` is in bound. + let n = unsafe { n.get_unchecked(0) }; + unary_elementwise(ca, |str_val| head_binary(str_val, n)).with_name(ca.name()) + }, + (1, _) => { + // SAFETY: index `0` is in bound. + let str_val = unsafe { ca.get_unchecked(0) }; + unary_elementwise(n, |n| head_binary(str_val, n)).with_name(ca.name()) + }, + _ => binary_elementwise(ca, n, head_binary), + } +} + +pub(super) fn tail(ca: &StringChunked, n: &Int64Chunked) -> StringChunked { + match (ca.len(), n.len()) { + (_, 1) => { + // SAFETY: index `0` is in bound. + let n = unsafe { n.get_unchecked(0) }; + unary_elementwise(ca, |str_val| tail_binary(str_val, n)).with_name(ca.name()) + }, + (1, _) => { + // SAFETY: index `0` is in bound. 
+ let str_val = unsafe { ca.get_unchecked(0) }; + unary_elementwise(n, |n| tail_binary(str_val, n)).with_name(ca.name()) + }, + _ => binary_elementwise(ca, n, tail_binary), + } +} diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index 13d18d790c63b..6c90bafe617a0 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -82,6 +82,8 @@ pub enum StringFunction { fill_char: char, }, Slice, + Head, + Tail, #[cfg(feature = "string_encoding")] HexEncode, #[cfg(feature = "binary_encoding")] @@ -166,7 +168,7 @@ impl StringFunction { #[cfg(feature = "binary_encoding")] Base64Decode(_) => mapper.with_dtype(DataType::Binary), Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix - | StripSuffix | Slice => mapper.with_same_dtype(), + | StripSuffix | Slice | Head | Tail => mapper.with_same_dtype(), #[cfg(feature = "string_pad")] PadStart { .. } | PadEnd { .. } | ZFill => mapper.with_same_dtype(), #[cfg(feature = "dtype-struct")] @@ -210,6 +212,8 @@ impl Display for StringFunction { ToInteger { .. } => "to_integer", #[cfg(feature = "regex")] Find { .. } => "find", + Head { .. } => "head", + Tail { .. } => "tail", #[cfg(feature = "extract_jsonpath")] JsonDecode { .. 
} => "json_decode", LenBytes => "len_bytes", @@ -345,6 +349,8 @@ impl From for SpecialEq> { #[cfg(feature = "string_to_integer")] ToInteger(base, strict) => map!(strings::to_integer, base, strict), Slice => map_as_slice!(strings::str_slice), + Head => map_as_slice!(strings::str_head), + Tail => map_as_slice!(strings::str_tail), #[cfg(feature = "string_encoding")] HexEncode => map!(strings::hex_encode), #[cfg(feature = "binary_encoding")] @@ -892,7 +898,8 @@ pub(super) fn to_integer(s: &Series, base: u32, strict: bool) -> PolarsResult PolarsResult { + +fn _ensure_lengths(s: &[Series]) -> bool { // Calculate the post-broadcast length and ensure everything is consistent. let len = s .iter() @@ -900,9 +907,14 @@ pub(super) fn str_slice(s: &[Series]) -> PolarsResult { .filter(|l| *l != 1) .max() .unwrap_or(1); + s.iter() + .all(|series| series.len() == 1 || series.len() == len) +} + +pub(super) fn str_slice(s: &[Series]) -> PolarsResult { polars_ensure!( - s.iter().all(|series| series.len() == 1 || series.len() == len), - ComputeError: "all series in `str_slice` should have equal or unit length" + _ensure_lengths(s), + ComputeError: "all series in `str_slice` should have equal or unit length", ); let ca = s[0].str()?; let offset = &s[1]; @@ -910,6 +922,26 @@ pub(super) fn str_slice(s: &[Series]) -> PolarsResult { Ok(ca.str_slice(offset, length)?.into_series()) } +pub(super) fn str_head(s: &[Series]) -> PolarsResult { + polars_ensure!( + _ensure_lengths(s), + ComputeError: "all series in `str_head` should have equal or unit length", + ); + let ca = s[0].str()?; + let n = &s[1]; + Ok(ca.str_head(n)?.into_series()) +} + +pub(super) fn str_tail(s: &[Series]) -> PolarsResult { + polars_ensure!( + _ensure_lengths(s), + ComputeError: "all series in `str_tail` should have equal or unit length", + ); + let ca = s[0].str()?; + let n = &s[1]; + Ok(ca.str_tail(n)?.into_series()) +} + #[cfg(feature = "string_encoding")] pub(super) fn hex_encode(s: &Series) -> PolarsResult { 
Ok(s.str()?.hex_encode().into_series()) diff --git a/crates/polars-plan/src/dsl/string.rs b/crates/polars-plan/src/dsl/string.rs index 42a7cb2471fd1..581c5306a00c3 100644 --- a/crates/polars-plan/src/dsl/string.rs +++ b/crates/polars-plan/src/dsl/string.rs @@ -526,6 +526,26 @@ impl StringNameSpace { ) } + /// Take the first `n` characters of the string values. + pub fn head(self, n: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::Head), + &[n], + false, + false, + ) + } + + /// Take the last `n` characters of the string values. + pub fn tail(self, n: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::StringExpr(StringFunction::Tail), + &[n], + false, + false, + ) + } + pub fn explode(self) -> Expr { self.0 .apply_private(FunctionExpr::StringExpr(StringFunction::Explode)) diff --git a/py-polars/polars/expr/string.py b/py-polars/polars/expr/string.py index e046fa14eea3a..240a051500085 100644 --- a/py-polars/polars/expr/string.py +++ b/py-polars/polars/expr/string.py @@ -2180,6 +2180,82 @@ def slice( length = parse_as_expression(length) return wrap_expr(self._pyexpr.str_slice(offset, length)) + def head(self, n: int | IntoExprColumn = 10) -> Expr: + """ + Return the first n characters of each string in a Utf8 Series. + + Parameters + ---------- + n + Length of the slice. Negative indexing supported. + + Returns + ------- + Expr + Expression of data type :class:`Utf8`. + + Notes + ----- + A "character" is a valid (non-surrogate) UTF-8 codepoint, which is a single byte + when working with ASCII text, and a maximum of 4 bytes otherwise. 
+ + Examples + -------- + >>> df = pl.DataFrame({"s": ["pear", None, "papaya", "dragonfruit"]}) + >>> df.with_columns(pl.col("s").str.head(3).alias("s_head3")) + shape: (4, 2) + ┌─────────────┬─────────┐ + │ s ┆ s_head3 │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞═════════════╪═════════╡ + │ pear ┆ pea │ + │ null ┆ null │ + │ papaya ┆ pap │ + │ dragonfruit ┆ dra │ + └─────────────┴─────────┘ + """ + n = parse_as_expression(n) + return wrap_expr(self._pyexpr.str_head(n)) + + def tail(self, n: int | IntoExprColumn = 10) -> Expr: + """ + Return the last n characters of each string in a Utf8 Series. + + Parameters + ---------- + n + Length of the slice. Negative indexing is supported. + + Returns + ------- + Expr + Expression of data type :class:`Utf8`. + + Notes + ----- + A "character" is a valid (non-surrogate) UTF-8 codepoint, which is a single byte + when working with ASCII text, and a maximum of 4 bytes otherwise. + + Examples + -------- + >>> df = pl.DataFrame({"s": ["pear", None, "papaya", "dragonfruit"]}) + >>> df.with_columns(pl.col("s").str.tail(3).alias("s_tail3")) + shape: (4, 2) + ┌─────────────┬─────────┐ + │ s ┆ s_tail3 │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞═════════════╪═════════╡ + │ pear ┆ ear │ + │ null ┆ null │ + │ papaya ┆ aya │ + │ dragonfruit ┆ uit │ + └─────────────┴─────────┘ + """ + n = parse_as_expression(n) + return wrap_expr(self._pyexpr.str_tail(n)) + def explode(self) -> Expr: """ Returns a column with a separate row for every string character. diff --git a/py-polars/polars/series/string.py b/py-polars/polars/series/string.py index 881d3ce2b5c53..c101015828669 100644 --- a/py-polars/polars/series/string.py +++ b/py-polars/polars/series/string.py @@ -1636,6 +1636,72 @@ def slice( ] """ + def head(self, n: int | IntoExprColumn = 10) -> Series: + """ + Return the first n characters of each string in a Utf8 Series. 
+ + Parameters + ---------- + n + Length of the slice + + Returns + ------- + Series + Series of data type :class:`Utf8`. + + Notes + ----- + A "character" is a valid (non-surrogate) UTF-8 codepoint, which is a single byte + when working with ASCII text, and a maximum of 4 bytes otherwise. + + Examples + -------- + >>> s = pl.Series("s", ["pear", None, "papaya", "dragonfruit"]) + >>> s.str.head(3) + shape: (4,) + Series: 's' [str] + [ + "pea" + null + "pap" + "dra" + ] + """ + + def tail(self, n: int | IntoExprColumn = 10) -> Series: + """ + Return the last n characters of each string in a Utf8 Series. + + Parameters + ---------- + n + Length of the slice + + Returns + ------- + Series + Series of data type :class:`Utf8`. + + Notes + ----- + A "character" is a valid (non-surrogate) UTF-8 codepoint, which is a single byte + when working with ASCII text, and a maximum of 4 bytes otherwise. + + Examples + -------- + >>> s = pl.Series("s", ["pear", None, "papaya", "dragonfruit"]) + >>> s.str.tail(3) + shape: (4,) + Series: 's' [str] + [ + "ear" + null + "aya" + "uit" + ] + """ + def explode(self) -> Series: """ Returns a column with a separate row for every string character. 
diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 34668b79efe30..969defa74e499 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -422,16 +422,17 @@ impl PyExpr { fn gather_every(&self, n: usize, offset: usize) -> Self { self.inner.clone().gather_every(n, offset).into() } - fn tail(&self, n: usize) -> Self { - self.inner.clone().tail(Some(n)).into() + + fn slice(&self, offset: Self, length: Self) -> Self { + self.inner.clone().slice(offset.inner, length.inner).into() } fn head(&self, n: usize) -> Self { self.inner.clone().head(Some(n)).into() } - fn slice(&self, offset: Self, length: Self) -> Self { - self.inner.clone().slice(offset.inner, length.inner).into() + fn tail(&self, n: usize) -> Self { + self.inner.clone().tail(Some(n)).into() } fn append(&self, other: Self, upcast: bool) -> Self { diff --git a/py-polars/src/expr/string.rs b/py-polars/src/expr/string.rs index e4e8b7bcceb7d..e14375a4886b1 100644 --- a/py-polars/src/expr/string.rs +++ b/py-polars/src/expr/string.rs @@ -102,6 +102,14 @@ impl PyExpr { .into() } + fn str_head(&self, n: Self) -> Self { + self.inner.clone().str().head(n.inner).into() + } + + fn str_tail(&self, n: Self) -> Self { + self.inner.clone().str().tail(n.inner).into() + } + fn str_explode(&self) -> Self { self.inner.clone().str().explode().into() } diff --git a/py-polars/tests/unit/namespaces/string/test_string.py b/py-polars/tests/unit/namespaces/string/test_string.py index 1d34025f539a4..7323b66e51608 100644 --- a/py-polars/tests/unit/namespaces/string/test_string.py +++ b/py-polars/tests/unit/namespaces/string/test_string.py @@ -46,6 +46,72 @@ def test_str_slice_expr() -> None: df.select(pl.col("a").str.slice(0, -1)) +def test_str_head() -> None: + df = pl.DataFrame({"a": ["foobar", "barfoo", None]}) + assert df["a"].str.head(0).to_list() == ["", "", None] + assert df["a"].str.head(-3).to_list() == ["foo", "bar", None] + assert df["a"].str.head(100).to_list() == ["foobar", 
"barfoo", None] + + +def test_str_head_expr() -> None: + df = pl.DataFrame( + { + "a": ["abcdef", None, "abcdef", "abcd", ""], + "n": [1, 3, None, -3, 2], + } + ) + out = df.select( + n_expr=pl.col("a").str.head("n"), + n=pl.col("a").str.head(2), + str_lit=pl.col("a").str.head(pl.lit(2)), + lit_expr=pl.lit("abcdef").str.head("n"), + lit_n=pl.lit("abcdef").str.head(2), + ) + expected = pl.DataFrame( + { + "n_expr": ["a", None, None, "a", ""], + "n": ["ab", None, "ab", "ab", ""], + "str_lit": ["ab", None, "ab", "ab", ""], + "lit_expr": ["a", "abc", None, "abc", "ab"], + "lit_n": ["ab", "ab", "ab", "ab", "ab"], + } + ) + assert_frame_equal(out, expected) + + +def test_str_tail() -> None: + df = pl.DataFrame({"a": ["foobar", "barfoo", None]}) + assert df["a"].str.tail(0).to_list() == ["", "", None] + assert df["a"].str.tail(-3).to_list() == ["bar", "foo", None] + assert df["a"].str.tail(100).to_list() == ["foobar", "barfoo", None] + + +def test_str_tail_expr() -> None: + df = pl.DataFrame( + { + "a": ["abcdef", None, "abcdef", "abcdef", ""], + "n": [1, 3, None, -3, 2], + } + ) + out = df.select( + n_expr=pl.col("a").str.tail("n"), + n=pl.col("a").str.tail(2), + str_lit=pl.col("a").str.tail(pl.lit(2)), + lit_expr=pl.lit("abcdef").str.tail("n"), + lit_n=pl.lit("abcdef").str.tail(2), + ) + expected = pl.DataFrame( + { + "n_expr": ["f", None, None, "def", ""], + "n": ["ef", None, "ef", "ef", ""], + "str_lit": ["ef", None, "ef", "ef", ""], + "lit_expr": ["f", "def", None, "def", "ef"], + "lit_n": ["ef", "ef", "ef", "ef", "ef"], + } + ) + assert_frame_equal(out, expected) + + def test_str_len_bytes() -> None: s = pl.Series(["Café", None, "345", "東京"]) result = s.str.len_bytes()