Skip to content

Commit

Permalink
feat: expression plugins (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Sep 18, 2023
1 parent a46d98c commit 7d5e384
Show file tree
Hide file tree
Showing 36 changed files with 624 additions and 89 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
working-directory: pyo3-polars

- run: make install
working-directory: example
working-directory: example/extend_polars_python_dispatch

- run: venv/bin/python run.py
working-directory: example
working-directory: example/extend_polars_python_dispatch
15 changes: 15 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Cargo workspace tying together the library crates and the example crates.
[workspace]
# Use the version 2 feature resolver for every member.
resolver = "2"
# Crates that are built as part of this workspace (paths relative to this file).
members = [
"example/derive_expression/expression_lib",
"example/extend_polars_python_dispatch/extend_polars",
"pyo3-polars",
"pyo3-polars-derive",
]

# Polars crates pinned once here; members opt in with `workspace = true`.
# Every entry must use the exact keys `version` and `default-features`,
# otherwise Cargo rejects (or silently ignores) the misspelled key.
[workspace.dependencies]
polars = {version = "0.33.2", default-features=false}
polars-core = {version = "0.33.2", default-features=false}
polars-ffi = {version = "0.33.2", default-features=false}
polars-plan = {version = "0.33.2", default-features=false}
polars-lazy = {version = "0.33.2", default-features=false}
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
# Pyo3 extensions for Polars
## 1. Shared library plugins for Polars
This is new functionality and not entirely stable, but should be preferred over `2.` as this
will circumvent the GIL and will be the way we want to support extending polars.

See more in `example/derive_expression`.

## 2. Pyo3 extensions for Polars
<a href="https://crates.io/crates/pyo3-polars">
<img src="https://img.shields.io/crates/v/pyo3-polars.svg"/>
</a>
Expand Down
25 changes: 25 additions & 0 deletions example/derive_expression/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

# Build/run helpers for the derive_expression example.
# bash is required because the recipes use `source`.
SHELL=/bin/bash

# These targets do not produce files of the same name.
.PHONY: install install-release clean run run-release

venv:  ## Set up virtual environment
	python3 -m venv venv
	venv/bin/pip install -r requirements.txt

# Build the plugin in debug mode and install it into the venv.
# CONDA_PREFIX is unset before activating the venv and running maturin.
install: venv
	unset CONDA_PREFIX && \
	source venv/bin/activate && maturin develop -m expression_lib/Cargo.toml

# Same as `install`, but with an optimized (release) build.
install-release: venv
	unset CONDA_PREFIX && \
	source venv/bin/activate && maturin develop --release -m expression_lib/Cargo.toml

# Remove the venv and the cargo build artifacts; errors are ignored (`-`)
# so the target also succeeds on an already clean tree.
clean:
	-@rm -r venv
	-@cd expression_lib && cargo clean


run: install
	source venv/bin/activate && python run.py

run-release: install-release
	source venv/bin/activate && python run.py
File renamed without changes.
15 changes: 15 additions & 0 deletions example/derive_expression/expression_lib/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Manifest of the example expression-plugin crate.
[package]
name = "expression_lib"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "expression_lib"
# Built as a C-compatible dynamic library.
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.19.0", features = ["extension-module"] }
# Local checkout of pyo3-polars; the "derive" feature is enabled for this crate.
pyo3-polars = { version = "*", path = "../../../pyo3-polars", features=["derive"] }
# Versions of the polars crates come from the workspace root manifest.
polars = { workspace = true, features = ["fmt"], default-features=false }
polars-plan = { workspace = true, default-features=false }
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import polars as pl
from polars.type_aliases import IntoExpr
from polars.utils.udfs import _get_shared_lib_location

lib = _get_shared_lib_location(__file__)


@pl.api.register_expr_namespace("language")
class Language:
    """Expression namespace ``language`` backed by the compiled plugin library."""

    def __init__(self, expr: pl.Expr):
        # Expression the namespace methods operate on.
        self._expr = expr

    def pig_latinnify(self) -> pl.Expr:
        """Return an expression that calls the ``pig_latinnify`` plugin symbol.

        The plugin is registered as elementwise.
        """
        plugin_options = {
            "lib": lib,
            "symbol": "pig_latinnify",
            "is_elementwise": True,
        }
        return self._expr._register_plugin(**plugin_options)

@pl.api.register_expr_namespace("dist")
class Distance:
    """Expression namespace ``dist`` exposing the distance plugins of the library."""

    def __init__(self, expr: pl.Expr):
        # Expression the namespace methods operate on.
        self._expr = expr

    def _elementwise_plugin(self, symbol, args, **extra) -> pl.Expr:
        """Register the elementwise plugin ``symbol`` with ``args`` on the wrapped expression."""
        return self._expr._register_plugin(
            lib=lib,
            args=args,
            symbol=symbol,
            is_elementwise=True,
            **extra,
        )

    def hamming_distance(self, other: IntoExpr) -> pl.Expr:
        """Return an expression calling the ``hamming_distance`` plugin with ``other``."""
        return self._elementwise_plugin("hamming_distance", [other])

    def jaccard_similarity(self, other: IntoExpr) -> pl.Expr:
        """Return an expression calling the ``jaccard_similarity`` plugin with ``other``."""
        return self._elementwise_plugin("jaccard_similarity", [other])

    def haversine(
        self,
        start_lat: IntoExpr,
        start_long: IntoExpr,
        end_lat: IntoExpr,
        end_long: IntoExpr,
    ) -> pl.Expr:
        """Return an expression calling the ``haversine`` plugin.

        The four coordinates are passed as plugin arguments and the inputs are
        cast to their supertypes.

        NOTE(review): the wrapped expression is presumably passed to the plugin
        as an additional leading input ahead of ``args`` -- confirm that this
        matches the input order the ``haversine`` symbol reads.
        """
        coordinates = [start_lat, start_long, end_lat, end_long]
        return self._elementwise_plugin(
            "haversine",
            coordinates,
            cast_to_supertypes=True,
        )
14 changes: 14 additions & 0 deletions example/derive_expression/expression_lib/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# The package is built with maturin (any 1.x release).
[build-system]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"

# Python-side metadata of the plugin package.
[project]
name = "expression_lib"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]


95 changes: 95 additions & 0 deletions example/derive_expression/expression_lib/src/distances.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use polars::datatypes::PlHashSet;
use polars::export::arrow::array::PrimitiveArray;
use polars::export::num::Float;
use polars::prelude::*;
use pyo3_polars::export::polars_core::utils::arrow::types::NativeType;
use pyo3_polars::export::polars_core::with_match_physical_integer_type;
use std::hash::Hash;

/// Bit-level Hamming distance between the UTF-8 bytes of `a` and `b`.
///
/// Bytes are compared pairwise and the number of differing bits is summed.
/// Pairing stops at the end of the shorter string, so trailing bytes of the
/// longer input do not contribute to the result.
#[allow(clippy::all)]
pub(super) fn naive_hamming_dist(a: &str, b: &str) -> u32 {
    let mut differing_bits = 0u32;
    for (lhs, rhs) in a.bytes().zip(b.bytes()) {
        // XOR leaves a 1 in every bit position where the two bytes differ.
        differing_bits += (lhs ^ rhs).count_ones();
    }
    differing_bits
}

/// Jaccard similarity of two primitive arrays, each treated as a set.
///
/// The elements yielded by iterating each array (optional values) are
/// collected into hash sets; the result is `|A ∩ B| / |A ∪ B|`.
/// When both sets are empty the division is `0.0 / 0.0`, i.e. NaN.
fn jacc_helper<T: NativeType + Hash + Eq>(a: &PrimitiveArray<T>, b: &PrimitiveArray<T>) -> f64 {
    let left: PlHashSet<_> = a.into_iter().collect();
    let right: PlHashSet<_> = b.into_iter().collect();

    // |A ∪ B| = |A| + |B| - |A ∩ B|
    let shared = left.intersection(&right).count();
    let union = left.len() + right.len() - shared;

    shared as f64 / union as f64
}

/// Row-wise Jaccard similarity between two list columns.
///
/// Both columns must have the same inner dtype and that dtype must be an
/// integer type; otherwise a `ComputeError` is returned. A row where either
/// side is null yields a null result.
pub(super) fn naive_jaccard_sim(a: &ListChunked, b: &ListChunked) -> PolarsResult<Float64Chunked> {
polars_ensure!(
a.inner_dtype() == b.inner_dtype(),
ComputeError: "inner data types don't match"
);
polars_ensure!(
a.inner_dtype().is_integer(),
ComputeError: "inner data types must be integer"
);
// The macro binds `$T` according to the inner integer dtype, so the
// sub-arrays can be downcast to `PrimitiveArray<$T>`.
Ok(with_match_physical_integer_type!(a.inner_dtype(), |$T| {
polars::prelude::arity::binary_elementwise(a, b, |a, b| {
match (a, b) {
(Some(a), Some(b)) => {
// The downcasts are unwrapped: both inner dtypes were checked to be equal
// and integer above.
let a = a.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
let b = b.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
Some(jacc_helper(a, b))
},
// Null on either side gives a null similarity.
_ => None
}
})
}))
}

/// Haversine distance between two points whose coordinates are given in degrees.
///
/// The result is the central angle multiplied by the constant `6371.0`
/// (named `r_in_km` in the original code).
fn haversine_elementwise<T: Float>(start_lat: T, start_long: T, end_lat: T, end_long: T) -> T {
    let two = T::from(2.0).unwrap();
    let one = T::one();
    let radius_km = T::from(6371.0).unwrap();

    let delta_lat = (end_lat - start_lat).to_radians();
    let delta_lon = (end_long - start_long).to_radians();
    let phi1 = start_lat.to_radians();
    let phi2 = end_lat.to_radians();

    // Sines of the half-angle differences, each squared below.
    let sin_half_lat = (delta_lat / two).sin();
    let sin_half_lon = (delta_lon / two).sin();

    let a = sin_half_lat * sin_half_lat + sin_half_lon * sin_half_lon * phi1.cos() * phi2.cos();
    let central_angle = two * a.sqrt().atan2((one - a).sqrt());

    radius_km * central_angle
}

/// Element-wise haversine distance over four float columns.
///
/// The columns are zipped together; a row yields a null as soon as any of
/// its four coordinates is null. The output takes the name of `start_lat`.
pub(super) fn naive_haversine<T>(
    start_lat: &ChunkedArray<T>,
    start_long: &ChunkedArray<T>,
    end_lat: &ChunkedArray<T>,
    end_long: &ChunkedArray<T>,
) -> PolarsResult<ChunkedArray<T>>
where
    T: PolarsFloatType,
    T::Native: Float,
{
    let distances: ChunkedArray<T> = start_lat
        .into_iter()
        .zip(start_long)
        .zip(end_lat)
        .zip(end_long)
        .map(|(((lat1, lon1), lat2), lon2)| match (lat1, lon1, lat2, lon2) {
            (Some(lat1), Some(lon1), Some(lat2), Some(lon2)) => {
                Some(haversine_elementwise(lat1, lon1, lat2, lon2))
            }
            // Any missing coordinate makes the distance undefined for this row.
            _ => None,
        })
        .collect();

    Ok(distances.with_name(start_lat.name()))
}
61 changes: 61 additions & 0 deletions example/derive_expression/expression_lib/src/expressions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use polars::prelude::*;
use polars_plan::dsl::FieldsMapper;
use pyo3_polars::derive::polars_expr;
use std::fmt::Write;

/// Appends the pig-latin form of `value` to `output`.
///
/// The first character is moved to the end and `"ay"` is appended
/// (`"hello"` becomes `"ellohay"`). An empty `value` leaves `output` untouched.
fn pig_latin_str(value: &str, output: &mut String) {
    let mut chars = value.chars();
    if let Some(first_char) = chars.next() {
        // `chars.as_str()` is the remainder after the first *character*;
        // slicing with `&value[1..]` would panic on a multi-byte first char.
        write!(output, "{}{}ay", chars.as_str(), first_char)
            .expect("writing to a String cannot fail")
    }
}

/// Plugin entry point: applies `pig_latin_str` to every string of the first input.
///
/// Returns an error if the first input is not a Utf8 series.
#[polars_expr(output_type=Utf8)]
fn pig_latinnify(inputs: &[Series]) -> PolarsResult<Series> {
    let names = inputs[0].utf8()?;
    let latinized: Utf8Chunked = names.apply_to_buffer(pig_latin_str);
    Ok(latinized.into_series())
}

/// Plugin entry point: row-wise Jaccard similarity of the first two inputs.
///
/// Both inputs must be list series; the computation itself is delegated to
/// `crate::distances::naive_jaccard_sim`.
#[polars_expr(output_type=Float64)]
fn jaccard_similarity(inputs: &[Series]) -> PolarsResult<Series> {
    let (lhs, rhs) = (inputs[0].list()?, inputs[1].list()?);
    let similarity = crate::distances::naive_jaccard_sim(lhs, rhs)?;
    Ok(similarity.into_series())
}

/// Plugin entry point: Hamming distance between the strings of the first two inputs.
///
/// Both inputs must be Utf8 series. The result is a `UInt32` series, so the
/// declared `output_type` must be `UInt32` as well; declaring another dtype
/// makes the schema reported to the query planner disagree with the data
/// actually returned.
#[polars_expr(output_type=UInt32)]
fn hamming_distance(inputs: &[Series]) -> PolarsResult<Series> {
    let a = inputs[0].utf8()?;
    let b = inputs[1].utf8()?;
    let out: UInt32Chunked =
        arity::binary_elementwise_values(a, b, crate::distances::naive_hamming_dist);
    Ok(out.into_series())
}

/// Output-type function for `haversine`: maps the input fields to a float dtype
/// via `FieldsMapper::map_to_float_dtype`.
fn haversine_output(input_fields: &[Field]) -> PolarsResult<Field> {
    let mapper = FieldsMapper::new(input_fields);
    mapper.map_to_float_dtype()
}

#[polars_expr(type_func=haversine_output)]
fn haversine(inputs: &[Series]) -> PolarsResult<Series> {
let out = match inputs[0].dtype() {
DataType::Float32 => {
let start_lat = inputs[0].f32().unwrap();
let start_long = inputs[1].f32().unwrap();
let end_lat = inputs[2].f32().unwrap();
let end_long = inputs[3].f32().unwrap();
crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)?
.into_series()
}
DataType::Float64 => {
let start_lat = inputs[0].f64().unwrap();
let start_long = inputs[1].f64().unwrap();
let end_lat = inputs[2].f64().unwrap();
let end_long = inputs[3].f64().unwrap();
crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)?
.into_series()
}
_ => unimplemented!(),
};
Ok(out)
}
2 changes: 2 additions & 0 deletions example/derive_expression/expression_lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
mod distances;
mod expressions;
1 change: 1 addition & 0 deletions example/derive_expression/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
maturin
19 changes: 19 additions & 0 deletions example/derive_expression/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Demo script: runs the plugin expressions on a small DataFrame and prints the result.
import polars as pl
# `Language` and `Distance` are not referenced below; the expressions use the
# `language` and `dist` namespaces, which are presumably registered as a side
# effect of importing expression_lib.
from expression_lib import Language, Distance

df = pl.DataFrame({
"names": ["Richard", "Alice", "Bob"],
"moons": ["full", "half", "red"],
"dist_a": [[12, 32, 1], [], [1, -2]],
"dist_b": [[-12, 1], [43], [876, -45, 9]]
})


# `pig_latin` is created in its own `with_columns` call because the second
# call refers to it by name.
out = df.with_columns(
pig_latin = pl.col("names").language.pig_latinnify()
).with_columns(
hamming_dist = pl.col("names").dist.hamming_distance("pig_latin"),
jaccard_sim = pl.col("dist_a").dist.jaccard_similarity("dist_b")
)

print(out)
70 changes: 0 additions & 70 deletions example/extend_polars/.github/workflows/CI.yml

This file was deleted.

Loading

0 comments on commit 7d5e384

Please sign in to comment.