-
Notifications
You must be signed in to change notification settings - Fork 163
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(clip): add binary_min, binary_max and clip function
- Loading branch information
1 parent
b942d44
commit 54e79d8
Showing
10 changed files
with
452 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
use common_error::DaftResult; | ||
|
||
use crate::{ | ||
array::DataArray, | ||
datatypes::{DaftNumericType, Float32Type, Float64Type}, | ||
prelude::DaftIntegerType, | ||
}; | ||
|
||
impl<T> DataArray<T> | ||
where | ||
T: DaftNumericType + DaftIntegerType, // Need the DaftIntegerType to tell the compiler that this doesn't apply for Float32/Float64, so we can specialize the implementation | ||
T::Native: Ord, | ||
{ | ||
pub fn min(&self, rhs: &Self) -> DaftResult<Self> { | ||
self.binary_apply(rhs, |l, r| l.min(r)) | ||
} | ||
pub fn max(&self, rhs: &Self) -> DaftResult<Self> { | ||
self.binary_apply(rhs, |l, r| l.max(r)) | ||
} | ||
} | ||
|
||
// Ideally, I'd like to further specialize the template to use the float's min/max version, | ||
// but I am too dumb for now to figure out Rust's type system, so let's just do this for now. | ||
|
||
impl DataArray<Float32Type> { | ||
pub fn min(&self, rhs: &Self) -> DaftResult<Self> { | ||
self.binary_apply(rhs, |l, r| l.min(r)) | ||
} | ||
pub fn max(&self, rhs: &Self) -> DaftResult<Self> { | ||
self.binary_apply(rhs, |l, r| l.max(r)) | ||
} | ||
} | ||
|
||
impl DataArray<Float64Type> { | ||
pub fn min(&self, rhs: &Self) -> DaftResult<Self> { | ||
self.binary_apply(rhs, |l, r| l.min(r)) | ||
} | ||
pub fn max(&self, rhs: &Self) -> DaftResult<Self> { | ||
self.binary_apply(rhs, |l, r| l.max(r)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
use common_error::{DaftError, DaftResult}; | ||
use daft_schema::prelude::*; | ||
|
||
use crate::{ | ||
datatypes::InferDataType, | ||
series::{IntoSeries, Series}, | ||
}; | ||
|
||
impl Series { | ||
pub fn binary_min(&self, rhs: &Self) -> DaftResult<Self> { | ||
let (_, _, output_type) = InferDataType::from(self.data_type()) | ||
.comparison_op(&InferDataType::from(rhs.data_type()))?; | ||
|
||
match &output_type { | ||
DataType::Int8 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.i8()?.min(rhs_casted.i8()?)?.into_series()) | ||
} | ||
DataType::Int16 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.i16()?.min(rhs_casted.i16()?)?.into_series()) | ||
} | ||
DataType::Int32 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.i32()?.min(rhs_casted.i32()?)?.into_series()) | ||
} | ||
DataType::Int64 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.i64()?.min(rhs_casted.i64()?)?.into_series()) | ||
} | ||
DataType::UInt8 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.u8()?.min(rhs_casted.u8()?)?.into_series()) | ||
} | ||
DataType::UInt16 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.u16()?.min(rhs_casted.u16()?)?.into_series()) | ||
} | ||
DataType::UInt32 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.u32()?.min(rhs_casted.u32()?)?.into_series()) | ||
} | ||
DataType::UInt64 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.u64()?.min(rhs_casted.u64()?)?.into_series()) | ||
} | ||
DataType::Float32 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.f32()?.min(rhs_casted.f32()?)?.into_series()) | ||
} | ||
DataType::Float64 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.f64()?.min(rhs_casted.f64()?)?.into_series()) | ||
} | ||
dt => Err(DaftError::TypeError(format!( | ||
"min not implemented for {}", | ||
dt | ||
))), | ||
} | ||
} | ||
|
||
pub fn binary_max(&self, rhs: &Self) -> DaftResult<Self> { | ||
let (_, _, output_type) = InferDataType::from(self.data_type()) | ||
.comparison_op(&InferDataType::from(rhs.data_type()))?; | ||
|
||
match &output_type { | ||
DataType::Int8 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.i8()?.max(rhs_casted.i8()?)?.into_series()) | ||
} | ||
DataType::Int16 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.i16()?.max(rhs_casted.i16()?)?.into_series()) | ||
} | ||
DataType::Int32 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.i32()?.max(rhs_casted.i32()?)?.into_series()) | ||
} | ||
DataType::Int64 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.i64()?.max(rhs_casted.i64()?)?.into_series()) | ||
} | ||
DataType::UInt8 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.u8()?.max(rhs_casted.u8()?)?.into_series()) | ||
} | ||
DataType::UInt16 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.u16()?.max(rhs_casted.u16()?)?.into_series()) | ||
} | ||
DataType::UInt32 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.u32()?.max(rhs_casted.u32()?)?.into_series()) | ||
} | ||
DataType::UInt64 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.u64()?.max(rhs_casted.u64()?)?.into_series()) | ||
} | ||
DataType::Float32 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.f32()?.max(rhs_casted.f32()?)?.into_series()) | ||
} | ||
DataType::Float64 => { | ||
let lhs_casted = self.cast(&output_type)?; | ||
let rhs_casted = rhs.cast(&output_type)?; | ||
Ok(lhs_casted.f64()?.max(rhs_casted.f64()?)?.into_series()) | ||
} | ||
dt => Err(DaftError::TypeError(format!( | ||
"max not implemented for {}", | ||
dt | ||
))), | ||
} | ||
} | ||
|
||
pub fn clip(&self, min: &Self, max: &Self) -> DaftResult<Self> { | ||
// We follow numpy's semantics in defining clip (equivalent to np.minimum(a_max, np.maximum(a, a_min)). | ||
// NOTE: As per numpy, this **doesn't** throw an error if max < min unlike the std::clamp function, it just returns an array that's entirely a_max. | ||
self.binary_max(min)?.binary_min(max) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.