Skip to content

Commit

Permalink
Add str_head and str_tail
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Feb 12, 2024
1 parent 41b33c5 commit 496f4de
Show file tree
Hide file tree
Showing 9 changed files with 375 additions and 8 deletions.
23 changes: 23 additions & 0 deletions crates/polars-ops/src/chunked_array/strings/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,29 @@ pub trait StringNameSpaceImpl: AsString {

Ok(substring::substring(ca, offset.i64()?, length.u64()?))
}

/// Take the leading part of every string in the array.
///
/// `n` is cast to `Int64` and forwarded, together with the string array, to
/// `substring::head`. A positive `n` keeps the first `n` positions of each element; a
/// negative `n` keeps everything except the last `|n|` positions.
///
/// # Errors
///
/// Returns an error if `n` cannot be cast to `Int64`.
fn str_head(&self, n: &Series) -> PolarsResult<StringChunked> {
    let strings = self.as_string();
    let n_as_i64 = n.cast(&DataType::Int64)?;
    let counts = n_as_i64.i64()?;

    Ok(substring::head(strings, counts))
}

/// Take the trailing part of every string in the array.
///
/// `n` is cast to `Int64` and forwarded, together with the string array, to
/// `substring::tail`. A positive `n` keeps the last `n` positions of each element; a
/// negative `n` skips the first `|n|` positions and keeps the rest.
///
/// # Errors
///
/// Returns an error if `n` cannot be cast to `Int64`.
fn str_tail(&self, n: &Series) -> PolarsResult<StringChunked> {
    let strings = self.as_string();
    let n_as_i64 = n.cast(&DataType::Int64)?;
    let counts = n_as_i64.i64()?;

    Ok(substring::tail(strings, counts))
}
}

impl StringNameSpaceImpl for StringChunked {}
75 changes: 75 additions & 0 deletions crates/polars-ops/src/chunked_array/strings/substring.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,49 @@
use polars_core::prelude::arity::{binary_elementwise, ternary_elementwise, unary_elementwise};
use polars_core::prelude::{Int64Chunked, StringChunked, UInt64Chunked};

/// Return the first `n` characters of `opt_str_val`.
///
/// * `n > 0`: keep the first `n` characters (the whole string if it is shorter).
/// * `n == 0`: return the empty string.
/// * `n < 0`: keep everything except the last `|n|` characters (empty if the string has
///   `|n|` characters or fewer).
///
/// Returns `None` when either input is `None`.
///
/// Positions are counted in Unicode scalar values (`char`s), not bytes, so the returned
/// slice always ends on a character boundary and multi-byte text never panics.
fn head_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {
    let (str_val, n) = (opt_str_val?, opt_n?);

    // Byte offset at which the returned prefix ends.
    let end = if n == 0 {
        0
    } else if n > 0 {
        let keep = usize::try_from(n).unwrap_or(usize::MAX);
        if keep >= str_val.len() {
            // Fast path: a string never has more characters than bytes.
            str_val.len()
        } else {
            // Start of character number `keep` is the end of the first `keep` characters.
            str_val
                .char_indices()
                .nth(keep)
                .map_or(str_val.len(), |(idx, _)| idx)
        }
    } else {
        // `unsigned_abs` is well-defined even for `i64::MIN`.
        let drop = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
        // The prefix ends where the `drop`-th character from the end begins.
        // `drop >= 1` here, so `drop - 1` cannot underflow.
        str_val
            .char_indices()
            .rev()
            .nth(drop - 1)
            .map_or(0, |(idx, _)| idx)
    };

    Some(&str_val[..end])
}

/// Return the last `n` characters of `opt_str_val`.
///
/// * `n > 0`: keep the last `n` characters (the whole string if it is shorter).
/// * `n == 0`: return the empty string.
/// * `n < 0`: skip the first `|n|` characters and keep the rest (empty if the string has
///   `|n|` characters or fewer).
///
/// Returns `None` when either input is `None`.
///
/// Positions are counted in Unicode scalar values (`char`s), not bytes, so the returned
/// slice always starts on a character boundary and multi-byte text never panics.
fn tail_binary(opt_str_val: Option<&str>, opt_n: Option<i64>) -> Option<&str> {
    let (str_val, n) = (opt_str_val?, opt_n?);

    // Byte offset at which the returned suffix starts; it always runs to the end.
    let start = if n == 0 {
        str_val.len()
    } else if n > 0 {
        let keep = usize::try_from(n).unwrap_or(usize::MAX);
        if keep >= str_val.len() {
            // Fast path: a string never has more characters than bytes.
            0
        } else {
            // The suffix starts where the `keep`-th character from the end begins.
            // `keep >= 1` here, so `keep - 1` cannot underflow.
            str_val
                .char_indices()
                .rev()
                .nth(keep - 1)
                .map_or(0, |(idx, _)| idx)
        }
    } else {
        // `unsigned_abs` is well-defined even for `i64::MIN`.
        let skip = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
        if skip >= str_val.len() {
            str_val.len()
        } else {
            // Start of character number `skip`, i.e. just past the skipped prefix.
            str_val
                .char_indices()
                .nth(skip)
                .map_or(str_val.len(), |(idx, _)| idx)
        }
    };

    Some(&str_val[start..])
}

fn substring_ternary(
opt_str_val: Option<&str>,
opt_offset: Option<i64>,
Expand Down Expand Up @@ -115,3 +158,35 @@ pub(super) fn substring(
_ => ternary_elementwise(ca, offset, length, substring_ternary),
}
}

/// Apply `head_binary` across a string array and an `n` array, broadcasting a
/// length-1 operand against the other one.
///
/// When either side is broadcast, the result takes the name of `ca`.
pub(super) fn head(ca: &StringChunked, n: &Int64Chunked) -> StringChunked {
    if n.len() == 1 {
        // A single `n` applies to every string.
        // SAFETY: `n` has exactly one element, so index `0` is in bounds.
        let n_scalar = unsafe { n.get_unchecked(0) };
        unary_elementwise(ca, |str_val| head_binary(str_val, n_scalar)).with_name(ca.name())
    } else if ca.len() == 1 {
        // A single string is sliced once per `n`.
        // SAFETY: `ca` has exactly one element, so index `0` is in bounds.
        let str_scalar = unsafe { ca.get_unchecked(0) };
        unary_elementwise(n, |count| head_binary(str_scalar, count)).with_name(ca.name())
    } else {
        binary_elementwise(ca, n, head_binary)
    }
}

/// Apply `tail_binary` across a string array and an `n` array, broadcasting a
/// length-1 operand against the other one.
///
/// When either side is broadcast, the result takes the name of `ca`.
pub(super) fn tail(ca: &StringChunked, n: &Int64Chunked) -> StringChunked {
    if n.len() == 1 {
        // A single `n` applies to every string.
        // SAFETY: `n` has exactly one element, so index `0` is in bounds.
        let n_scalar = unsafe { n.get_unchecked(0) };
        unary_elementwise(ca, |str_val| tail_binary(str_val, n_scalar)).with_name(ca.name())
    } else if ca.len() == 1 {
        // A single string is sliced once per `n`.
        // SAFETY: `ca` has exactly one element, so index `0` is in bounds.
        let str_scalar = unsafe { ca.get_unchecked(0) };
        unary_elementwise(n, |count| tail_binary(str_scalar, count)).with_name(ca.name())
    } else {
        binary_elementwise(ca, n, tail_binary)
    }
}
40 changes: 36 additions & 4 deletions crates/polars-plan/src/dsl/function_expr/strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ pub enum StringFunction {
fill_char: char,
},
Slice,
Head,
Tail,
#[cfg(feature = "string_encoding")]
HexEncode,
#[cfg(feature = "binary_encoding")]
Expand Down Expand Up @@ -166,7 +168,7 @@ impl StringFunction {
#[cfg(feature = "binary_encoding")]
Base64Decode(_) => mapper.with_dtype(DataType::Binary),
Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix
| StripSuffix | Slice => mapper.with_same_dtype(),
| StripSuffix | Slice | Head | Tail => mapper.with_same_dtype(),
#[cfg(feature = "string_pad")]
PadStart { .. } | PadEnd { .. } | ZFill => mapper.with_same_dtype(),
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -210,6 +212,8 @@ impl Display for StringFunction {
ToInteger { .. } => "to_integer",
#[cfg(feature = "regex")]
Find { .. } => "find",
Head { .. } => "head",
Tail { .. } => "tail",
#[cfg(feature = "extract_jsonpath")]
JsonDecode { .. } => "json_decode",
LenBytes => "len_bytes",
Expand Down Expand Up @@ -345,6 +349,8 @@ impl From<StringFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
#[cfg(feature = "string_to_integer")]
ToInteger(base, strict) => map!(strings::to_integer, base, strict),
Slice => map_as_slice!(strings::str_slice),
Head => map_as_slice!(strings::str_head),
Tail => map_as_slice!(strings::str_tail),
#[cfg(feature = "string_encoding")]
HexEncode => map!(strings::hex_encode),
#[cfg(feature = "binary_encoding")]
Expand Down Expand Up @@ -892,24 +898,50 @@ pub(super) fn to_integer(s: &Series, base: u32, strict: bool) -> PolarsResult<Se
let ca = s.str()?;
ca.to_integer(base, strict).map(|ok| ok.into_series())
}
pub(super) fn str_slice(s: &[Series]) -> PolarsResult<Series> {

/// Check that every series can be broadcast to one common length.
///
/// The target length is the largest length that is not `1` (or `1` when every series has
/// length `1`). Returns `true` when each series either has length `1` or has exactly the
/// target length.
fn _ensure_lengths(s: &[Series]) -> bool {
    // Largest non-unit length seen so far; `None` until one is found.
    let mut widest: Option<usize> = None;
    for series in s {
        let current = series.len();
        if current != 1 {
            widest = Some(widest.map_or(current, |seen| seen.max(current)));
        }
    }
    let target = widest.unwrap_or(1);

    for series in s {
        let current = series.len();
        if current != 1 && current != target {
            return false;
        }
    }
    true
}

/// Slice every string in `s[0]` using the offsets in `s[1]` and the lengths in `s[2]`.
///
/// # Errors
///
/// Returns a `ComputeError` when the inputs do not all have equal or unit length, and
/// propagates any error from reading `s[0]` as a string column or from the slice itself.
pub(super) fn str_slice(s: &[Series]) -> PolarsResult<Series> {
    // One broadcast-compatibility check; the stale inline condition that referenced an
    // undefined `len` is gone.
    polars_ensure!(
        _ensure_lengths(s),
        ComputeError: "all series in `str_slice` should have equal or unit length",
    );
    let ca = s[0].str()?;
    let offset = &s[1];
    let length = &s[2];
    Ok(ca.str_slice(offset, length)?.into_series())
}

/// Take the head of every string in `s[0]`, with the per-row count taken from `s[1]`.
///
/// # Errors
///
/// Returns a `ComputeError` when the inputs do not all have equal or unit length, and
/// propagates any error from reading `s[0]` as a string column or from `str_head`.
pub(super) fn str_head(s: &[Series]) -> PolarsResult<Series> {
    polars_ensure!(
        _ensure_lengths(s),
        ComputeError: "all series in `str_head` should have equal or unit length",
    );
    let strings = s[0].str()?;
    strings.str_head(&s[1]).map(|out| out.into_series())
}

/// Take the tail of every string in `s[0]`, with the per-row count taken from `s[1]`.
///
/// # Errors
///
/// Returns a `ComputeError` when the inputs do not all have equal or unit length, and
/// propagates any error from reading `s[0]` as a string column or from `str_tail`.
pub(super) fn str_tail(s: &[Series]) -> PolarsResult<Series> {
    polars_ensure!(
        _ensure_lengths(s),
        ComputeError: "all series in `str_tail` should have equal or unit length",
    );
    let strings = s[0].str()?;
    strings.str_tail(&s[1]).map(|out| out.into_series())
}

#[cfg(feature = "string_encoding")]
pub(super) fn hex_encode(s: &Series) -> PolarsResult<Series> {
Ok(s.str()?.hex_encode().into_series())
Expand Down
20 changes: 20 additions & 0 deletions crates/polars-plan/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,26 @@ impl StringNameSpace {
)
}

/// Keep the first `n` characters of each string value.
///
/// Builds a `StringFunction::Head` expression with `n` as its only extra input.
pub fn head(self, n: Expr) -> Expr {
    let function = FunctionExpr::StringExpr(StringFunction::Head);
    // NOTE(review): the two trailing `false` flags are forwarded as-is; their meaning is
    // defined by `map_many_private` — confirm there before changing them.
    self.0.map_many_private(function, &[n], false, false)
}

/// Keep the last `n` characters of each string value.
///
/// Builds a `StringFunction::Tail` expression with `n` as its only extra input.
pub fn tail(self, n: Expr) -> Expr {
    let function = FunctionExpr::StringExpr(StringFunction::Tail);
    // NOTE(review): the two trailing `false` flags are forwarded as-is; their meaning is
    // defined by `map_many_private` — confirm there before changing them.
    self.0.map_many_private(function, &[n], false, false)
}

pub fn explode(self) -> Expr {
self.0
.apply_private(FunctionExpr::StringExpr(StringFunction::Explode))
Expand Down
76 changes: 76 additions & 0 deletions py-polars/polars/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -2180,6 +2180,82 @@ def slice(
length = parse_as_expression(length)
return wrap_expr(self._pyexpr.str_slice(offset, length))

def head(self, n: int | IntoExprColumn = 10) -> Expr:
    """
    Return the first n characters of each string in a Utf8 Series.

    Parameters
    ----------
    n
        Length of the slice. A negative value keeps everything except the
        last ``abs(n)`` characters.

    Returns
    -------
    Expr
        Expression of data type :class:`Utf8`.

    Notes
    -----
    A "character" is a valid (non-surrogate) UTF-8 codepoint, which is a single byte
    when working with ASCII text, and a maximum of 4 bytes otherwise.

    Examples
    --------
    >>> df = pl.DataFrame({"s": ["pear", None, "papaya", "dragonfruit"]})
    >>> df.with_columns(pl.col("s").str.head(3).alias("s_head3"))
    shape: (4, 2)
    ┌─────────────┬─────────┐
    │ s           ┆ s_head3 │
    │ ---         ┆ ---     │
    │ str         ┆ str     │
    ╞═════════════╪═════════╡
    │ pear        ┆ pea     │
    │ null        ┆ null    │
    │ papaya      ┆ pap     │
    │ dragonfruit ┆ dra     │
    └─────────────┴─────────┘
    """
    return wrap_expr(self._pyexpr.str_head(parse_as_expression(n)))

def tail(self, n: int | IntoExprColumn = 10) -> Expr:
    """
    Return the last n characters of each string in a Utf8 Series.

    Parameters
    ----------
    n
        Length of the slice. A negative value skips the first ``abs(n)``
        characters and keeps the rest.

    Returns
    -------
    Expr
        Expression of data type :class:`Utf8`.

    Notes
    -----
    A "character" is a valid (non-surrogate) UTF-8 codepoint, which is a single byte
    when working with ASCII text, and a maximum of 4 bytes otherwise.

    Examples
    --------
    >>> df = pl.DataFrame({"s": ["pear", None, "papaya", "dragonfruit"]})
    >>> df.with_columns(pl.col("s").str.tail(3).alias("s_tail3"))
    shape: (4, 2)
    ┌─────────────┬─────────┐
    │ s           ┆ s_tail3 │
    │ ---         ┆ ---     │
    │ str         ┆ str     │
    ╞═════════════╪═════════╡
    │ pear        ┆ ear     │
    │ null        ┆ null    │
    │ papaya      ┆ aya     │
    │ dragonfruit ┆ uit     │
    └─────────────┴─────────┘
    """
    return wrap_expr(self._pyexpr.str_tail(parse_as_expression(n)))

def explode(self) -> Expr:
"""
Returns a column with a separate row for every string character.
Expand Down
66 changes: 66 additions & 0 deletions py-polars/polars/series/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -1636,6 +1636,72 @@ def slice(
]
"""

def head(self, n: int | IntoExprColumn = 10) -> Series:
    """
    Return the first n characters of each string in a Utf8 Series.

    Parameters
    ----------
    n
        Length of the slice.

    Returns
    -------
    Series
        Series of data type :class:`Utf8`.

    Notes
    -----
    A "character" is a valid (non-surrogate) UTF-8 codepoint, which is a single byte
    when working with ASCII text, and a maximum of 4 bytes otherwise.

    Examples
    --------
    >>> s = pl.Series("s", ["pear", None, "papaya", "dragonfruit"])
    >>> s.str.head(3)
    shape: (4,)
    Series: 's' [str]
    [
        "pea"
        null
        "pap"
        "dra"
    ]
    """

def tail(self, n: int | IntoExprColumn = 10) -> Series:
    """
    Return the last n characters of each string in a Utf8 Series.

    Parameters
    ----------
    n
        Length of the slice.

    Returns
    -------
    Series
        Series of data type :class:`Utf8`.

    Notes
    -----
    A "character" is a valid (non-surrogate) UTF-8 codepoint, which is a single byte
    when working with ASCII text, and a maximum of 4 bytes otherwise.

    Examples
    --------
    >>> s = pl.Series("s", ["pear", None, "papaya", "dragonfruit"])
    >>> s.str.tail(3)
    shape: (4,)
    Series: 's' [str]
    [
        "ear"
        null
        "aya"
        "uit"
    ]
    """

def explode(self) -> Series:
"""
Returns a column with a separate row for every string character.
Expand Down
Loading

0 comments on commit 496f4de

Please sign in to comment.