Skip to content

Commit

Permalink
feat(clip): add binary_min, binary_max and clip function
Browse files Browse the repository at this point in the history
  • Loading branch information
conradsoon committed Oct 28, 2024
1 parent b942d44 commit 54e79d8
Show file tree
Hide file tree
Showing 10 changed files with 452 additions and 4 deletions.
6 changes: 6 additions & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,9 @@ def to_struct(inputs: list[PyExpr]) -> PyExpr: ...
def abs(expr: PyExpr) -> PyExpr: ...
def cbrt(expr: PyExpr) -> PyExpr: ...
def ceil(expr: PyExpr) -> PyExpr: ...
# binary_min / binary_max / clip: native expression builders invoked (as `native.<name>`)
# by the same-named expression methods in daft/expressions/expressions.py.
# NOTE(review): presumably element-wise, mirroring the Series kernels — confirm against the binding.
def binary_min(expr: PyExpr, other: PyExpr) -> PyExpr: ...
def binary_max(expr: PyExpr, other: PyExpr) -> PyExpr: ...
def clip(expr: PyExpr, min: PyExpr, max: PyExpr) -> PyExpr: ...
def exp(expr: PyExpr) -> PyExpr: ...
def floor(expr: PyExpr) -> PyExpr: ...
def log2(expr: PyExpr) -> PyExpr: ...
Expand Down Expand Up @@ -1360,6 +1363,9 @@ class PySeries:
def floor(self) -> PySeries: ...
def sign(self) -> PySeries: ...
def round(self, decimal: int) -> PySeries: ...
# binary_min / binary_max / clip: wrapped by the same-named methods in daft/series.py and
# implemented by the PySeries methods in src/daft-core/src/python/series.rs.
def binary_min(self, other: PySeries) -> PySeries: ...
def binary_max(self, other: PySeries) -> PySeries: ...
def clip(self, min: PySeries, max: PySeries) -> PySeries: ...
def sqrt(self) -> PySeries: ...
def cbrt(self) -> PySeries: ...
def sin(self) -> PySeries: ...
Expand Down
16 changes: 12 additions & 4 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,10 +618,18 @@ def ceil(self) -> Expression:
expr = native.ceil(self._expr)
return Expression._from_pyexpr(expr)

def floor(self) -> Expression:
    """The floor of a numeric expression (``expr.floor()``)

    Returns:
        Expression: a new expression wrapping the result of ``native.floor``.
    """
    return Expression._from_pyexpr(native.floor(self._expr))
def binary_min(self, other: Expression) -> Expression:
    """The minimum of this expression and ``other`` (``expr.binary_min(other)``)

    Args:
        other: right-hand operand; it is passed through
            ``Expression._to_expression`` before being handed to the native call.

    Returns:
        Expression: a new expression wrapping ``native.binary_min(self, other)``.
    """
    rhs = Expression._to_expression(other)
    native_expr = native.binary_min(self._expr, rhs._expr)
    return Expression._from_pyexpr(native_expr)

def binary_max(self, other: Expression) -> Expression:
    """The maximum of this expression and ``other`` (``expr.binary_max(other)``)

    Args:
        other: right-hand operand; it is passed through
            ``Expression._to_expression`` before being handed to the native call.

    Returns:
        Expression: a new expression wrapping ``native.binary_max(self, other)``.
    """
    rhs = Expression._to_expression(other)
    native_expr = native.binary_max(self._expr, rhs._expr)
    return Expression._from_pyexpr(native_expr)

def clip(self, min: Expression, max: Expression) -> Expression:
    """Clips this expression between ``min`` and ``max`` (``expr.clip(min, max)``)

    Args:
        min: lower bound; converted with ``Expression._to_expression``.
        max: upper bound; converted with ``Expression._to_expression``.

    Returns:
        Expression: a new expression wrapping ``native.clip(self, min, max)``.
    """
    # NOTE(review): presumably shares the Series-level semantics
    # (minimum(max, maximum(x, min)), no error when max < min) — confirm.
    lower = Expression._to_expression(min)
    upper = Expression._to_expression(max)
    clipped = native.clip(self._expr, lower._expr, upper._expr)
    return Expression._from_pyexpr(clipped)

def sign(self) -> Expression:
"""The sign of a numeric expression (``expr.sign()``)"""
Expand Down
9 changes: 9 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,15 @@ def sign(self) -> Series:
def round(self, decimal: int) -> Series:
    """Return a new ``Series`` from calling ``round(decimal)`` on the wrapped native series."""
    rounded = self._series.round(decimal)
    return Series._from_pyseries(rounded)

def binary_min(self, other: Series) -> Series:
    """Return the minimum of this series and ``other`` as a new ``Series``.

    Delegates to ``binary_min`` on the wrapped native ``_series`` objects.
    """
    smaller = self._series.binary_min(other._series)
    return Series._from_pyseries(smaller)

def binary_max(self, other: Series) -> Series:
    """Return the maximum of this series and ``other`` as a new ``Series``.

    Delegates to ``binary_max`` on the wrapped native ``_series`` objects.
    """
    larger = self._series.binary_max(other._series)
    return Series._from_pyseries(larger)

def clip(self, min: Series, max: Series) -> Series:
    """Return this series clipped between ``min`` and ``max`` as a new ``Series``.

    Delegates to ``clip`` on the wrapped native ``_series`` objects.
    """
    clipped = self._series.clip(min._series, max._series)
    return Series._from_pyseries(clipped)

def sqrt(self) -> Series:
    """Return a new ``Series`` from calling ``sqrt()`` on the wrapped native series."""
    roots = self._series.sqrt()
    return Series._from_pyseries(roots)

Expand Down
41 changes: 41 additions & 0 deletions src/daft-core/src/array/ops/clip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use common_error::DaftResult;

use crate::{
array::DataArray,
datatypes::{DaftNumericType, Float32Type, Float64Type},
prelude::DaftIntegerType,
};

/// `min` / `max` kernels for integer-backed arrays, built on `binary_apply`.
impl<T> DataArray<T>
where
    // The `DaftIntegerType` bound keeps this generic impl from also covering
    // Float32/Float64, which get their own dedicated impls.
    T: DaftNumericType + DaftIntegerType,
    T::Native: Ord,
{
    /// Combines `self` and `rhs` through `binary_apply`, keeping the smaller
    /// (by `Ord`) of each pair of values.
    pub fn min(&self, rhs: &Self) -> DaftResult<Self> {
        let pick_smaller = |a: T::Native, b: T::Native| a.min(b);
        self.binary_apply(rhs, pick_smaller)
    }

    /// Combines `self` and `rhs` through `binary_apply`, keeping the larger
    /// (by `Ord`) of each pair of values.
    pub fn max(&self, rhs: &Self) -> DaftResult<Self> {
        let pick_larger = |a: T::Native, b: T::Native| a.max(b);
        self.binary_apply(rhs, pick_larger)
    }
}

// Ideally the generic impl above would be specialized to use the floats' own `min`/`max`,
// but stable Rust has no impl specialization, so Float32 and Float64 each get a separate impl below.

/// `min` / `max` kernels for `Float32` arrays, using `f32`'s own `min`/`max`.
///
/// NOTE(review): `f32::min` / `f32::max` return the non-NaN operand when exactly
/// one side is NaN, whereas `numpy.minimum` / `numpy.maximum` propagate NaN —
/// confirm this is the intended behaviour for `clip`.
impl DataArray<Float32Type> {
    /// Combines `self` and `rhs` through `binary_apply` using `f32::min`.
    pub fn min(&self, rhs: &Self) -> DaftResult<Self> {
        self.binary_apply(rhs, f32::min)
    }

    /// Combines `self` and `rhs` through `binary_apply` using `f32::max`.
    pub fn max(&self, rhs: &Self) -> DaftResult<Self> {
        self.binary_apply(rhs, f32::max)
    }
}

/// `min` / `max` kernels for `Float64` arrays, using `f64`'s own `min`/`max`.
///
/// NOTE(review): `f64::min` / `f64::max` return the non-NaN operand when exactly
/// one side is NaN, whereas `numpy.minimum` / `numpy.maximum` propagate NaN —
/// confirm this is the intended behaviour for `clip`.
impl DataArray<Float64Type> {
    /// Combines `self` and `rhs` through `binary_apply` using `f64::min`.
    pub fn min(&self, rhs: &Self) -> DaftResult<Self> {
        self.binary_apply(rhs, f64::min)
    }

    /// Combines `self` and `rhs` through `binary_apply` using `f64::max`.
    pub fn max(&self, rhs: &Self) -> DaftResult<Self> {
        self.binary_apply(rhs, f64::max)
    }
}
1 change: 1 addition & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub(crate) mod broadcast;
pub(crate) mod cast;
mod cbrt;
mod ceil;
mod clip;
mod compare_agg;
mod comparison;
mod concat;
Expand Down
12 changes: 12 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@ impl PySeries {
Ok(self.series.round(decimal)?.into())
}

/// Python-facing wrapper: calls `binary_min` on the wrapped series with
/// `other`'s series and converts the result (or error) for Python.
pub fn binary_min(&self, other: &Self) -> PyResult<Self> {
    let smaller = self.series.binary_min(&other.series)?;
    Ok(smaller.into())
}

/// Python-facing wrapper: calls `binary_max` on the wrapped series with
/// `other`'s series and converts the result (or error) for Python.
pub fn binary_max(&self, other: &Self) -> PyResult<Self> {
    let larger = self.series.binary_max(&other.series)?;
    Ok(larger.into())
}

/// Python-facing wrapper: calls `clip` on the wrapped series with the
/// `min` and `max` series and converts the result (or error) for Python.
pub fn clip(&self, min: &Self, max: &Self) -> PyResult<Self> {
    let clipped = self.series.clip(&min.series, &max.series)?;
    Ok(clipped.into())
}

/// Python-facing wrapper: calls `sqrt` on the wrapped series and converts
/// the result (or error) for Python.
pub fn sqrt(&self) -> PyResult<Self> {
    let roots = self.series.sqrt()?;
    Ok(roots.into())
}
Expand Down
139 changes: 139 additions & 0 deletions src/daft-core/src/series/ops/clip.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
use common_error::{DaftError, DaftResult};
use daft_schema::prelude::*;

use crate::{
datatypes::InferDataType,
series::{IntoSeries, Series},
};

impl Series {
pub fn binary_min(&self, rhs: &Self) -> DaftResult<Self> {
let (_, _, output_type) = InferDataType::from(self.data_type())
.comparison_op(&InferDataType::from(rhs.data_type()))?;

match &output_type {
DataType::Int8 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.i8()?.min(rhs_casted.i8()?)?.into_series())
}
DataType::Int16 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.i16()?.min(rhs_casted.i16()?)?.into_series())
}
DataType::Int32 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.i32()?.min(rhs_casted.i32()?)?.into_series())
}
DataType::Int64 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.i64()?.min(rhs_casted.i64()?)?.into_series())
}
DataType::UInt8 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.u8()?.min(rhs_casted.u8()?)?.into_series())
}
DataType::UInt16 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.u16()?.min(rhs_casted.u16()?)?.into_series())
}
DataType::UInt32 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.u32()?.min(rhs_casted.u32()?)?.into_series())
}
DataType::UInt64 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.u64()?.min(rhs_casted.u64()?)?.into_series())
}
DataType::Float32 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.f32()?.min(rhs_casted.f32()?)?.into_series())
}
DataType::Float64 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.f64()?.min(rhs_casted.f64()?)?.into_series())
}
dt => Err(DaftError::TypeError(format!(
"min not implemented for {}",
dt
))),
}
}

pub fn binary_max(&self, rhs: &Self) -> DaftResult<Self> {
let (_, _, output_type) = InferDataType::from(self.data_type())
.comparison_op(&InferDataType::from(rhs.data_type()))?;

match &output_type {
DataType::Int8 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.i8()?.max(rhs_casted.i8()?)?.into_series())
}
DataType::Int16 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.i16()?.max(rhs_casted.i16()?)?.into_series())
}
DataType::Int32 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.i32()?.max(rhs_casted.i32()?)?.into_series())
}
DataType::Int64 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.i64()?.max(rhs_casted.i64()?)?.into_series())
}
DataType::UInt8 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.u8()?.max(rhs_casted.u8()?)?.into_series())
}
DataType::UInt16 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.u16()?.max(rhs_casted.u16()?)?.into_series())
}
DataType::UInt32 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.u32()?.max(rhs_casted.u32()?)?.into_series())
}
DataType::UInt64 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.u64()?.max(rhs_casted.u64()?)?.into_series())
}
DataType::Float32 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.f32()?.max(rhs_casted.f32()?)?.into_series())
}
DataType::Float64 => {
let lhs_casted = self.cast(&output_type)?;
let rhs_casted = rhs.cast(&output_type)?;
Ok(lhs_casted.f64()?.max(rhs_casted.f64()?)?.into_series())
}
dt => Err(DaftError::TypeError(format!(
"max not implemented for {}",
dt
))),
}
}

pub fn clip(&self, min: &Self, max: &Self) -> DaftResult<Self> {
// We follow numpy's semantics in defining clip (equivalent to np.minimum(a_max, np.maximum(a, a_min)).
// NOTE: As per numpy, this **doesn't** throw an error if max < min unlike the std::clamp function, it just returns an array that's entirely a_max.
self.binary_max(min)?.binary_min(max)
}
}
1 change: 1 addition & 0 deletions src/daft-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub mod broadcast;
pub mod cast;
pub mod cbrt;
pub mod ceil;
pub mod clip;
pub mod comparison;
pub mod concat;
pub mod downcast;
Expand Down
Loading

0 comments on commit 54e79d8

Please sign in to comment.