From 7d5e3842b80b4742ac1b1f559c36faeafe6eb6ef Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 18 Sep 2023 11:26:15 +0200 Subject: [PATCH] feat: expression plugins (#26) --- .github/workflows/CI.yml | 4 +- Cargo.toml | 15 +++ README.md | 8 +- example/derive_expression/Makefile | 25 +++++ .../expression_lib}/.gitignore | 0 .../expression_lib/Cargo.toml | 15 +++ .../expression_lib/expression_lib/__init__.py | 48 +++++++++ .../expression_lib/pyproject.toml | 14 +++ .../expression_lib/src/distances.rs | 95 ++++++++++++++++++ .../expression_lib/src/expressions.rs | 61 ++++++++++++ .../expression_lib/src/lib.rs | 2 + example/derive_expression/requirements.txt | 1 + example/derive_expression/run.py | 19 ++++ .../extend_polars/.github/workflows/CI.yml | 70 ------------- .../Makefile | 9 +- .../extend_polars/.gitignore | 72 ++++++++++++++ .../extend_polars/Cargo.toml | 8 +- .../extend_polars/pyproject.toml | 0 .../extend_polars/src/lib.rs | 0 .../extend_polars/src/parallel_jaccard_mod.rs | 0 .../requirements.txt | 0 .../run.py | 0 pyo3-polars-derive/Cargo.toml | 24 +++++ pyo3-polars-derive/src/attr.rs | 50 ++++++++++ pyo3-polars-derive/src/keywords.rs | 2 + pyo3-polars-derive/src/lib.rs | 98 +++++++++++++++++++ pyo3-polars-derive/tests/01.rs | 19 ++++ pyo3-polars-derive/tests/02.rs | 14 +++ pyo3-polars-derive/tests/run.rs | 6 ++ pyo3-polars/Cargo.toml | 15 +-- pyo3-polars/src/derive.rs | 1 + pyo3-polars/src/error.rs | 3 +- pyo3-polars/src/export.rs | 3 + pyo3-polars/src/ffi/to_py.rs | 2 +- pyo3-polars/src/ffi/to_rust.rs | 2 +- pyo3-polars/src/lib.rs | 8 +- 36 files changed, 624 insertions(+), 89 deletions(-) create mode 100644 Cargo.toml create mode 100644 example/derive_expression/Makefile rename example/{extend_polars => derive_expression/expression_lib}/.gitignore (100%) create mode 100644 example/derive_expression/expression_lib/Cargo.toml create mode 100644 example/derive_expression/expression_lib/expression_lib/__init__.py create mode 100644 example/derive_expression/expression_lib/pyproject.toml create mode 100644 example/derive_expression/expression_lib/src/distances.rs create mode 100644 example/derive_expression/expression_lib/src/expressions.rs create mode 100644 example/derive_expression/expression_lib/src/lib.rs create mode 100644 example/derive_expression/requirements.txt create mode 100644 example/derive_expression/run.py delete mode 100644 example/extend_polars/.github/workflows/CI.yml rename example/{ => extend_polars_python_dispatch}/Makefile (69%) create mode 100644 example/extend_polars_python_dispatch/extend_polars/.gitignore rename example/{ => extend_polars_python_dispatch}/extend_polars/Cargo.toml (60%) rename example/{ => extend_polars_python_dispatch}/extend_polars/pyproject.toml (100%) rename example/{ => extend_polars_python_dispatch}/extend_polars/src/lib.rs (100%) rename example/{ => extend_polars_python_dispatch}/extend_polars/src/parallel_jaccard_mod.rs (100%) rename example/{ => extend_polars_python_dispatch}/requirements.txt (100%) rename example/{ => extend_polars_python_dispatch}/run.py (100%) create mode 100644 pyo3-polars-derive/Cargo.toml create mode 100644 pyo3-polars-derive/src/attr.rs create mode 100644 pyo3-polars-derive/src/keywords.rs create mode 100644 pyo3-polars-derive/src/lib.rs create mode 100644 pyo3-polars-derive/tests/01.rs create mode 100644 pyo3-polars-derive/tests/02.rs create mode 100644 pyo3-polars-derive/tests/run.rs create mode 100644 pyo3-polars/src/derive.rs create mode 100644 pyo3-polars/src/export.rs diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 1eb0ee6..da9d92f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -37,7 +37,7 @@ jobs: working-directory: pyo3-polars - run: make install - working-directory: example + working-directory: example/extend_polars_python_dispatch - run: venv/bin/python run.py - working-directory: example + working-directory: example/extend_polars_python_dispatch diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..0b842d9 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,15 @@ +[workspace] +resolver = "2" +members = [ + "example/derive_expression/expression_lib", + "example/extend_polars_python_dispatch/extend_polars", + "pyo3-polars", + "pyo3-polars-derive", +] + +[workspace.dependencies] +polars = {version = "0.33.2", default-features=false} +polars-core = {version = "0.33.2", default-features=false} +polars-ffi = {ersion = "0.33.2", default-features=false} +polars-plan = {version = "0.33.2", default-feautres=false} +polars-lazy = {version = "0.33.2", default-features=false} diff --git a/README.md b/README.md index e1a41af..66e87ea 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,10 @@ -# Pyo3 extensions for Polars +## 1. Shared library plugins for Polars +This is new functionality and not entirely stable, but should be preferred over `2.` as this +will circumvent the GIL and will be the way we want to support extending polars. + +See more in `examples/derive_expression`. + +## 2. Pyo3 extensions for Polars diff --git a/example/derive_expression/Makefile b/example/derive_expression/Makefile new file mode 100644 index 0000000..7aef8de --- /dev/null +++ b/example/derive_expression/Makefile @@ -0,0 +1,25 @@ + +SHELL=/bin/bash + +venv: ## Set up virtual environment + python3 -m venv venv + venv/bin/pip install -r requirements.txt + +install: venv + unset CONDA_PREFIX && \ + source venv/bin/activate && maturin develop -m expression_lib/Cargo.toml + +install-release: venv + unset CONDA_PREFIX && \ + source venv/bin/activate && maturin develop --release -m expression_lib/Cargo.toml + +clean: + -@rm -r venv + -@cd experssion_lib && cargo clean + + +run: install + source venv/bin/activate && python run.py + +run-release: install-release + source venv/bin/activate && python run.py diff --git a/example/extend_polars/.gitignore b/example/derive_expression/expression_lib/.gitignore similarity index 100% rename from example/extend_polars/.gitignore rename to example/derive_expression/expression_lib/.gitignore diff --git a/example/derive_expression/expression_lib/Cargo.toml b/example/derive_expression/expression_lib/Cargo.toml new file mode 100644 index 0000000..b767c1f --- /dev/null +++ b/example/derive_expression/expression_lib/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "expression_lib" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "expression_lib" +crate-type = ["cdylib"] + +[dependencies] +pyo3 = { version = "0.19.0", features = ["extension-module"] } +pyo3-polars = { version = "*", path = "../../../pyo3-polars", features=["derive"] } +polars = { workspace = true, features = ["fmt"], default-features=false } +polars-plan = { workspace = true, default-features=false } diff --git a/example/derive_expression/expression_lib/expression_lib/__init__.py b/example/derive_expression/expression_lib/expression_lib/__init__.py new file mode 100644 index 0000000..b0b8570 --- /dev/null +++ b/example/derive_expression/expression_lib/expression_lib/__init__.py @@ -0,0 +1,48 @@ +import polars as pl +from polars.type_aliases import IntoExpr +from polars.utils.udfs import _get_shared_lib_location + +lib = _get_shared_lib_location(__file__) + + +@pl.api.register_expr_namespace("language") +class Language: + def __init__(self, expr: pl.Expr): + self._expr = expr + + def pig_latinnify(self) -> pl.Expr: + return self._expr._register_plugin( + lib=lib, + symbol="pig_latinnify", + is_elementwise=True, + ) + +@pl.api.register_expr_namespace("dist") +class Distance: + def __init__(self, expr: pl.Expr): + self._expr = expr + + def hamming_distance(self, other: IntoExpr) -> pl.Expr: + return self._expr._register_plugin( + lib=lib, + args=[other], + symbol="hamming_distance", + is_elementwise=True, + ) + + def jaccard_similarity(self, other: IntoExpr) -> pl.Expr: + return self._expr._register_plugin( + lib=lib, + args=[other], + symbol="jaccard_similarity", + is_elementwise=True, + ) + + def haversine(self, start_lat: IntoExpr, start_long: IntoExpr, end_lat: IntoExpr, end_long: IntoExpr) -> pl.Expr: + return self._expr._register_plugin( + lib=lib, + args=[start_lat, start_long, end_lat, end_long], + symbol="haversine", + is_elementwise=True, + cast_to_supertypes=True + ) diff --git a/example/derive_expression/expression_lib/pyproject.toml b/example/derive_expression/expression_lib/pyproject.toml new file mode 100644 index 0000000..851e5d7 --- /dev/null +++ b/example/derive_expression/expression_lib/pyproject.toml @@ -0,0 +1,14 @@ +[build-system] +requires = ["maturin>=1.0,<2.0"] +build-backend = "maturin" + +[project] +name = "expression_lib" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] + + diff --git a/example/derive_expression/expression_lib/src/distances.rs b/example/derive_expression/expression_lib/src/distances.rs new file mode 100644 index 0000000..943aa4a --- /dev/null +++ b/example/derive_expression/expression_lib/src/distances.rs @@ -0,0 +1,95 @@ +use polars::datatypes::PlHashSet; +use polars::export::arrow::array::PrimitiveArray; +use polars::export::num::Float; +use polars::prelude::*; +use pyo3_polars::export::polars_core::utils::arrow::types::NativeType; +use pyo3_polars::export::polars_core::with_match_physical_integer_type; +use std::hash::Hash; + +#[allow(clippy::all)] +pub(super) fn naive_hamming_dist(a: &str, b: &str) -> u32 { + let x = a.as_bytes(); + let y = b.as_bytes(); + x.iter() + .zip(y) + .fold(0, |a, (b, c)| a + (*b ^ *c).count_ones() as u32) +} + +fn jacc_helper(a: &PrimitiveArray, b: &PrimitiveArray) -> f64 { + // convert to hashsets over Option + let s1 = a.into_iter().collect::>(); + let s2 = b.into_iter().collect::>(); + + // count the number of intersections + let s3_len = s1.intersection(&s2).count(); + // return similarity + s3_len as f64 / (s1.len() + s2.len() - s3_len) as f64 +} + +pub(super) fn naive_jaccard_sim(a: &ListChunked, b: &ListChunked) -> PolarsResult { + polars_ensure!( + a.inner_dtype() == b.inner_dtype(), + ComputeError: "inner data types don't match" + ); + polars_ensure!( + a.inner_dtype().is_integer(), + ComputeError: "inner data types must be integer" + ); + Ok(with_match_physical_integer_type!(a.inner_dtype(), |$T| { + polars::prelude::arity::binary_elementwise(a, b, |a, b| { + match (a, b) { + (Some(a), Some(b)) => { + let a = a.as_any().downcast_ref::>().unwrap(); + let b = b.as_any().downcast_ref::>().unwrap(); + Some(jacc_helper(a, b)) + }, + _ => None + } + }) + })) +} + +fn haversine_elementwise(start_lat: T, start_long: T, end_lat: T, end_long: T) -> T { + let r_in_km = T::from(6371.0).unwrap(); + let two = T::from(2.0).unwrap(); + let one = T::one(); + + let d_lat = (end_lat - start_lat).to_radians(); + let d_lon = (end_long - start_long).to_radians(); + let lat1 = (start_lat).to_radians(); + let lat2 = (end_lat).to_radians(); + + let a = ((d_lat / two).sin()) * ((d_lat / two).sin()) + + ((d_lon / two).sin()) * ((d_lon / two).sin()) * (lat1.cos()) * (lat2.cos()); + let c = two * ((a.sqrt()).atan2((one - a).sqrt())); + r_in_km * c +} + +pub(super) fn naive_haversine( + start_lat: &ChunkedArray, + start_long: &ChunkedArray, + end_lat: &ChunkedArray, + end_long: &ChunkedArray, +) -> PolarsResult> +where + T: PolarsFloatType, + T::Native: Float, +{ + let out: ChunkedArray = start_lat + .into_iter() + .zip(start_long.into_iter()) + .zip(end_lat.into_iter()) + .zip(end_long.into_iter()) + .map(|(((start_lat, start_long), end_lat), end_long)| { + let start_lat = start_lat?; + let start_long = start_long?; + let end_lat = end_lat?; + let end_long = end_long?; + Some(haversine_elementwise( + start_lat, start_long, end_lat, end_long, + )) + }) + .collect(); + + Ok(out.with_name(start_lat.name())) +} diff --git a/example/derive_expression/expression_lib/src/expressions.rs b/example/derive_expression/expression_lib/src/expressions.rs new file mode 100644 index 0000000..273acc5 --- /dev/null +++ b/example/derive_expression/expression_lib/src/expressions.rs @@ -0,0 +1,61 @@ +use polars::prelude::*; +use polars_plan::dsl::FieldsMapper; +use pyo3_polars::derive::polars_expr; +use std::fmt::Write; + +fn pig_latin_str(value: &str, output: &mut String) { + if let Some(first_char) = value.chars().next() { + write!(output, "{}{}ay", &value[1..], first_char).unwrap() + } +} + +#[polars_expr(output_type=Utf8)] +fn pig_latinnify(inputs: &[Series]) -> PolarsResult { + let ca = inputs[0].utf8()?; + let out: Utf8Chunked = ca.apply_to_buffer(pig_latin_str); + Ok(out.into_series()) +} + +#[polars_expr(output_type=Float64)] +fn jaccard_similarity(inputs: &[Series]) -> PolarsResult { + let a = inputs[0].list()?; + let b = inputs[1].list()?; + crate::distances::naive_jaccard_sim(a, b).map(|ca| ca.into_series()) +} + +#[polars_expr(output_type=Float64)] +fn hamming_distance(inputs: &[Series]) -> PolarsResult { + let a = inputs[0].utf8()?; + let b = inputs[1].utf8()?; + let out: UInt32Chunked = + arity::binary_elementwise_values(a, b, crate::distances::naive_hamming_dist); + Ok(out.into_series()) +} + +fn haversine_output(input_fields: &[Field]) -> PolarsResult { + FieldsMapper::new(input_fields).map_to_float_dtype() +} + +#[polars_expr(type_func=haversine_output)] +fn haversine(inputs: &[Series]) -> PolarsResult { + let out = match inputs[0].dtype() { + DataType::Float32 => { + let start_lat = inputs[0].f32().unwrap(); + let start_long = inputs[1].f32().unwrap(); + let end_lat = inputs[2].f32().unwrap(); + let end_long = inputs[3].f32().unwrap(); + crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)? + .into_series() + } + DataType::Float64 => { + let start_lat = inputs[0].f64().unwrap(); + let start_long = inputs[1].f64().unwrap(); + let end_lat = inputs[2].f64().unwrap(); + let end_long = inputs[3].f64().unwrap(); + crate::distances::naive_haversine(start_lat, start_long, end_lat, end_long)? + .into_series() + } + _ => unimplemented!(), + }; + Ok(out) +} diff --git a/example/derive_expression/expression_lib/src/lib.rs b/example/derive_expression/expression_lib/src/lib.rs new file mode 100644 index 0000000..3fd00a4 --- /dev/null +++ b/example/derive_expression/expression_lib/src/lib.rs @@ -0,0 +1,2 @@ +mod distances; +mod expressions; diff --git a/example/derive_expression/requirements.txt b/example/derive_expression/requirements.txt new file mode 100644 index 0000000..dbf962f --- /dev/null +++ b/example/derive_expression/requirements.txt @@ -0,0 +1 @@ +maturin diff --git a/example/derive_expression/run.py b/example/derive_expression/run.py new file mode 100644 index 0000000..392a56d --- /dev/null +++ b/example/derive_expression/run.py @@ -0,0 +1,19 @@ +import polars as pl +from expression_lib import Language, Distance + +df = pl.DataFrame({ + "names": ["Richard", "Alice", "Bob"], + "moons": ["full", "half", "red"], + "dist_a": [[12, 32, 1], [], [1, -2]], + "dist_b": [[-12, 1], [43], [876, -45, 9]] +}) + + +out = df.with_columns( + pig_latin = pl.col("names").language.pig_latinnify() +).with_columns( + hamming_dist = pl.col("names").dist.hamming_distance("pig_latin"), + jaccard_sim = pl.col("dist_a").dist.jaccard_similarity("dist_b") +) + +print(out) diff --git a/example/extend_polars/.github/workflows/CI.yml b/example/extend_polars/.github/workflows/CI.yml deleted file mode 100644 index 074743e..0000000 --- a/example/extend_polars/.github/workflows/CI.yml +++ /dev/null @@ -1,70 +0,0 @@ -name: CI - -on: - push: - branches: - - main - - master - pull_request: - workflow_dispatch: - -jobs: - linux: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - uses: PyO3/maturin-action@v1 - with: - manylinux: auto - command: build - args: --release --sdist -o dist --find-interpreter - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - windows: - runs-on: windows-latest - steps: - - uses: actions/checkout@v3 - - uses: PyO3/maturin-action@v1 - with: - command: build - args: --release -o dist --find-interpreter - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - macos: - runs-on: macos-latest - steps: - - uses: actions/checkout@v3 - - uses: PyO3/maturin-action@v1 - with: - command: build - args: --release -o dist --universal2 --find-interpreter - - name: Upload wheels - uses: actions/upload-artifact@v3 - with: - name: wheels - path: dist - - release: - name: Release - runs-on: ubuntu-latest - if: "startsWith(github.ref, 'refs/tags/')" - needs: [ macos, windows, linux ] - steps: - - uses: actions/download-artifact@v3 - with: - name: wheels - - name: Publish to PyPI - uses: PyO3/maturin-action@v1 - env: - MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} - with: - command: upload - args: --skip-existing * \ No newline at end of file diff --git a/example/Makefile b/example/extend_polars_python_dispatch/Makefile similarity index 69% rename from example/Makefile rename to example/extend_polars_python_dispatch/Makefile index 431360f..9a99eb7 100644 --- a/example/Makefile +++ b/example/extend_polars_python_dispatch/Makefile @@ -15,4 +15,11 @@ install-release: venv clean: -@rm -r venv - -@cd extend_polars && cargo clean \ No newline at end of file + -@cd extend_polars && cargo clean + + +run: install + source venv/bin/activate && python run.py + +run-release: install-release + source venv/bin/activate && python run.py diff --git a/example/extend_polars_python_dispatch/extend_polars/.gitignore b/example/extend_polars_python_dispatch/extend_polars/.gitignore new file mode 100644 index 0000000..af3ca5e --- /dev/null +++ b/example/extend_polars_python_dispatch/extend_polars/.gitignore @@ -0,0 +1,72 @@ +/target + +# Byte-compiled / optimized / DLL files +__pycache__/ +.pytest_cache/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +.venv/ +env/ +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +include/ +man/ +venv/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt +pip-selfcheck.json + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +.DS_Store + +# Sphinx documentation +docs/_build/ + +# PyCharm +.idea/ + +# VSCode +.vscode/ + +# Pyenv +.python-version \ No newline at end of file diff --git a/example/extend_polars/Cargo.toml b/example/extend_polars_python_dispatch/extend_polars/Cargo.toml similarity index 60% rename from example/extend_polars/Cargo.toml rename to example/extend_polars_python_dispatch/extend_polars/Cargo.toml index b69f42d..406299c 100644 --- a/example/extend_polars/Cargo.toml +++ b/example/extend_polars_python_dispatch/extend_polars/Cargo.toml @@ -10,8 +10,8 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "0.19", features = ["extension-module"] } -pyo3-polars = { version = "*", path = "../../pyo3-polars", features = ["lazy"] } -polars-core = { version = "0.32.0" } -polars-lazy = "*" -polars = { version = "0.32.0", features = ["fmt"] } +pyo3-polars = { version = "*", path = "../../../pyo3-polars", features = ["lazy"] } +polars-core = { workspace = true } +polars-lazy = {workspace = true} +polars = { workspace = true, features = ["fmt"] } rayon = "1.6" diff --git a/example/extend_polars/pyproject.toml b/example/extend_polars_python_dispatch/extend_polars/pyproject.toml similarity index 100% rename from example/extend_polars/pyproject.toml rename to example/extend_polars_python_dispatch/extend_polars/pyproject.toml diff --git a/example/extend_polars/src/lib.rs b/example/extend_polars_python_dispatch/extend_polars/src/lib.rs similarity index 100% rename from example/extend_polars/src/lib.rs rename to example/extend_polars_python_dispatch/extend_polars/src/lib.rs diff --git a/example/extend_polars/src/parallel_jaccard_mod.rs b/example/extend_polars_python_dispatch/extend_polars/src/parallel_jaccard_mod.rs similarity index 100% rename from example/extend_polars/src/parallel_jaccard_mod.rs rename to example/extend_polars_python_dispatch/extend_polars/src/parallel_jaccard_mod.rs diff --git a/example/requirements.txt b/example/extend_polars_python_dispatch/requirements.txt similarity index 100% rename from example/requirements.txt rename to example/extend_polars_python_dispatch/requirements.txt diff --git a/example/run.py b/example/extend_polars_python_dispatch/run.py similarity index 100% rename from example/run.py rename to example/extend_polars_python_dispatch/run.py diff --git a/pyo3-polars-derive/Cargo.toml b/pyo3-polars-derive/Cargo.toml new file mode 100644 index 0000000..6de85ab --- /dev/null +++ b/pyo3-polars-derive/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "pyo3-polars-derive" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lib] +proc-macro = true + +[[test]] +name = "tests" +path = "tests/run.rs" + +[dependencies] +polars-core = {workspace = true} +polars-ffi = {workspace = true} +polars-plan = {workspace = true} +syn = {version = "2", features= ["full", "extra-traits"]} +quote = "1.0" +proc-macro2 = "1.0" + +[dev-dependencies] +trybuild = { version = "1", features = ["diff"] } diff --git a/pyo3-polars-derive/src/attr.rs b/pyo3-polars-derive/src/attr.rs new file mode 100644 index 0000000..a5fecc3 --- /dev/null +++ b/pyo3-polars-derive/src/attr.rs @@ -0,0 +1,50 @@ +use crate::keywords; +use proc_macro2::Ident; +use std::fmt::Debug; +use syn::parse::{Parse, ParseStream}; +use syn::Token; + +#[derive(Clone, Debug)] +pub struct KeyWordAttribute { + pub kw: K, + pub value: V, +} + +impl Parse for KeyWordAttribute { + fn parse(input: ParseStream) -> syn::Result { + let kw = input.parse()?; + let _: Token![=] = input.parse()?; + let value = input.parse()?; + Ok(KeyWordAttribute { kw, value }) + } +} + +pub type OutputAttribute = KeyWordAttribute; +pub type OutputFuncAttribute = KeyWordAttribute; + +#[derive(Default, Debug)] +pub struct ExprsFunctionOptions { + pub output_dtype: Option, + pub output_type_fn: Option, +} + +impl Parse for ExprsFunctionOptions { + fn parse(input: ParseStream<'_>) -> syn::Result { + let mut options = ExprsFunctionOptions::default(); + + while !input.is_empty() { + let lookahead = input.lookahead1(); + + if lookahead.peek(keywords::output_type) { + let attr = input.parse::()?; + options.output_dtype = Some(attr.value) + } else if lookahead.peek(keywords::type_func) { + let attr = input.parse::()?; + options.output_type_fn = Some(attr.value) + } else { + panic!("didn't recognize attribute") + } + } + Ok(options) + } +} diff --git a/pyo3-polars-derive/src/keywords.rs b/pyo3-polars-derive/src/keywords.rs new file mode 100644 index 0000000..062baba --- /dev/null +++ b/pyo3-polars-derive/src/keywords.rs @@ -0,0 +1,2 @@ +syn::custom_keyword!(output_type); +syn::custom_keyword!(type_func); diff --git a/pyo3-polars-derive/src/lib.rs b/pyo3-polars-derive/src/lib.rs new file mode 100644 index 0000000..f0de3d5 --- /dev/null +++ b/pyo3-polars-derive/src/lib.rs @@ -0,0 +1,98 @@ +mod attr; +mod keywords; + +use proc_macro::TokenStream; +use quote::quote; +use syn::parse_macro_input; + +fn create_expression_function(ast: syn::ItemFn) -> proc_macro2::TokenStream { + let fn_name = &ast.sig.ident; + + quote!( + use pyo3_polars::export::*; + // create the outer public function + #[no_mangle] + pub unsafe extern "C" fn #fn_name (e: *mut polars_ffi::SeriesExport, len: usize) -> polars_ffi::SeriesExport { + let inputs = polars_ffi::import_series_buffer(e, len).unwrap(); + + // define the function + #ast + + // call the function + let output: polars_core::prelude::Series = #fn_name(&inputs).unwrap(); + let out = polars_ffi::export_series(&output); + out + } + ) +} + +fn get_field_name(fn_name: &syn::Ident) -> syn::Ident { + syn::Ident::new(&format!("__polars_field_{}", fn_name), fn_name.span()) +} + +fn get_inputs() -> proc_macro2::TokenStream { + quote!( + let inputs = std::slice::from_raw_parts(field, len); + let inputs = inputs.iter().map(|field| { + let field = polars_core::export::arrow::ffi::import_field_from_c(field).unwrap(); + let out = polars_core::prelude::Field::from(&field); + out + }).collect::>(); + ) +} + +fn create_field_function(fn_name: &syn::Ident) -> proc_macro2::TokenStream { + let map_field_name = get_field_name(fn_name); + let inputs = get_inputs(); + + quote! ( + #[no_mangle] + pub unsafe extern "C" fn #map_field_name(field: *mut polars_core::export::arrow::ffi::ArrowSchema, len: usize) -> polars_core::export::arrow::ffi::ArrowSchema { + #inputs; + let out = #fn_name(&inputs).unwrap(); + polars_core::export::arrow::ffi::export_field_to_c(&out.to_arrow()) + } + ) +} + +fn create_field_function_from_with_dtype( + fn_name: &syn::Ident, + dtype: syn::Ident, +) -> proc_macro2::TokenStream { + let map_field_name = get_field_name(fn_name); + let inputs = get_inputs(); + + quote! ( + #[no_mangle] + pub unsafe extern "C" fn #map_field_name(field: *mut polars_core::export::arrow::ffi::ArrowSchema, len: usize) -> polars_core::export::arrow::ffi::ArrowSchema { + #inputs + + let mapper = polars_plan::dsl::FieldsMapper::new(&inputs); + let dtype = polars_core::datatypes::DataType::#dtype; + let out = mapper.with_dtype(dtype).unwrap(); + polars_core::export::arrow::ffi::export_field_to_c(&out.to_arrow()) + } + ) +} + +#[proc_macro_attribute] +pub fn polars_expr(attr: TokenStream, input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as syn::ItemFn); + + let options = parse_macro_input!(attr as attr::ExprsFunctionOptions); + let expanded_field_fn = if let Some(fn_name) = options.output_type_fn { + create_field_function(&fn_name) + } else if let Some(dtype) = options.output_dtype { + create_field_function_from_with_dtype(&ast.sig.ident, dtype) + } else { + panic!("didn't understand polars_expr attribute") + }; + + let expanded_expr = create_expression_function(ast); + let expanded = quote!( + #expanded_field_fn + + #expanded_expr + ); + TokenStream::from(expanded) +} diff --git a/pyo3-polars-derive/tests/01.rs b/pyo3-polars-derive/tests/01.rs new file mode 100644 index 0000000..cb0082b --- /dev/null +++ b/pyo3-polars-derive/tests/01.rs @@ -0,0 +1,19 @@ +use polars_core::error::PolarsResult; +use polars_core::prelude::{Field, Series}; +use polars_plan::dsl::FieldsMapper; +use pyo3_polars_derive::polars_expr; + +fn horizontal_product_output(input_fields: &[Field]) -> PolarsResult { + FieldsMapper::new(input_fields).map_to_supertype() +} + +#[polars_expr(type_func=horizontal_product_output)] +fn horizontal_product(series: &[Series]) -> PolarsResult { + let mut acc = series[0].clone(); + for s in &series[1..] { + acc = &acc * s + } + Ok(acc) +} + +fn main() {} diff --git a/pyo3-polars-derive/tests/02.rs b/pyo3-polars-derive/tests/02.rs new file mode 100644 index 0000000..fda4347 --- /dev/null +++ b/pyo3-polars-derive/tests/02.rs @@ -0,0 +1,14 @@ +use polars_core::error::PolarsResult; +use polars_core::prelude::Series; +use pyo3_polars_derive::polars_expr; + +#[polars_expr(output_type=Int32)] +fn horizontal_product(series: &[Series]) -> PolarsResult { + let mut acc = series[0].clone(); + for s in &series[1..] { + acc = &acc * s + } + Ok(acc) +} + +fn main() {} diff --git a/pyo3-polars-derive/tests/run.rs b/pyo3-polars-derive/tests/run.rs new file mode 100644 index 0000000..dc9bca6 --- /dev/null +++ b/pyo3-polars-derive/tests/run.rs @@ -0,0 +1,6 @@ +#[test] +fn tests() { + let t = trybuild::TestCases::new(); + t.pass("tests/01.rs"); + t.pass("tests/02.rs"); +} diff --git a/pyo3-polars/Cargo.toml b/pyo3-polars/Cargo.toml index 37713d4..d7e1bc3 100644 --- a/pyo3-polars/Cargo.toml +++ b/pyo3-polars/Cargo.toml @@ -10,15 +10,18 @@ description = "PyO3 bindings to polars" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -polars = { version = "0.32.0", default_features = false } -polars-core = { version = "0.32.0", default_features = false } -polars-plan = { version = "0.32.0", default_features = false, optional = true } -polars-lazy = { version = "0.32.0", default_features = false, optional = true } +polars = { workspace = true, default-features = false } +polars-core = { workspace = true, default-features = false } +polars-ffi = { workspace = true, optional = true } +polars-plan = { workspace = true, optional = true } +polars-lazy = { workspace = true, optional = true } pyo3 = "0.19.0" thiserror = "1" -arrow2 = "0.17.4" ciborium = { version = "0.2.1", optional = true } +pyo3-polars-derive = { version = "*", path = "../pyo3-polars-derive", optional=true } + [features] -lazy = ["polars/serde-lazy", "polars-plan", "polars-lazy", "ciborium"] +lazy = ["polars/serde-lazy", "polars-plan", "polars-lazy/serde", "ciborium"] +derive = ["pyo3-polars-derive", "polars-plan", "polars-ffi"] diff --git a/pyo3-polars/src/derive.rs b/pyo3-polars/src/derive.rs new file mode 100644 index 0000000..12d3f6b --- /dev/null +++ b/pyo3-polars/src/derive.rs @@ -0,0 +1 @@ +pub use pyo3_polars_derive::polars_expr; diff --git a/pyo3-polars/src/error.rs b/pyo3-polars/src/error.rs index 16df9e2..7dea7ba 100644 --- a/pyo3-polars/src/error.rs +++ b/pyo3-polars/src/error.rs @@ -3,7 +3,7 @@ use std::fmt::{Debug, Formatter}; use polars::prelude::PolarsError; use polars_core::error::ArrowError; use pyo3::create_exception; -use pyo3::exceptions::{PyException, PyIOError, PyRuntimeError, PyValueError}; +use pyo3::exceptions::{PyException, PyIOError, PyIndexError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; use thiserror::Error; @@ -29,6 +29,7 @@ impl std::convert::From for PyErr { PolarsError::ShapeMismatch(err) => ShapeError::new_err(err.to_string()), PolarsError::SchemaMismatch(err) => SchemaError::new_err(err.to_string()), PolarsError::Io(err) => PyIOError::new_err(err.to_string()), + PolarsError::OutOfBounds(err) => PyIndexError::new_err(err.to_string()), PolarsError::InvalidOperation(err) => PyValueError::new_err(err.to_string()), PolarsError::ArrowError(err) => ArrowErrorException::new_err(format!("{:?}", err)), PolarsError::Duplicate(err) => DuplicateError::new_err(err.to_string()), diff --git a/pyo3-polars/src/export.rs b/pyo3-polars/src/export.rs new file mode 100644 index 0000000..76551b3 --- /dev/null +++ b/pyo3-polars/src/export.rs @@ -0,0 +1,3 @@ +pub use polars_core; +pub use polars_ffi; +pub use polars_plan; diff --git a/pyo3-polars/src/ffi/to_py.rs b/pyo3-polars/src/ffi/to_py.rs index b01125d..9b4fa3a 100644 --- a/pyo3-polars/src/ffi/to_py.rs +++ b/pyo3-polars/src/ffi/to_py.rs @@ -1,4 +1,4 @@ -use arrow2::ffi; +use polars::export::arrow::ffi; use polars::prelude::{ArrayRef, ArrowField}; use pyo3::ffi::Py_uintptr_t; use pyo3::prelude::*; diff --git a/pyo3-polars/src/ffi/to_rust.rs b/pyo3-polars/src/ffi/to_rust.rs index 2a12b58..27f5037 100644 --- a/pyo3-polars/src/ffi/to_rust.rs +++ b/pyo3-polars/src/ffi/to_rust.rs @@ -1,5 +1,5 @@ use crate::error::PyPolarsErr; -use arrow2::ffi; +use polars::export::arrow::ffi; use polars::prelude::*; use pyo3::ffi::Py_uintptr_t; use pyo3::prelude::*; diff --git a/pyo3-polars/src/lib.rs b/pyo3-polars/src/lib.rs index 3edd9df..237b796 100644 --- a/pyo3-polars/src/lib.rs +++ b/pyo3-polars/src/lib.rs @@ -23,7 +23,7 @@ //! //! /// A Python module implemented in Rust. //! #[pymodule] -//! fn extend_polars(_py: Python, m: &PyModule) -> PyResult<()> { +//! fn expression_lib(_py: Python, m: &PyModule) -> PyResult<()> { //! m.add_function(wrap_pyfunction!(my_cool_function, m)?)?; //! Ok(()) //! } @@ -33,7 +33,7 @@ //! //! From `my_python_file.py`. //! ```python -//! from extend_polars import my_cool_function +//! from expression_lib import my_cool_function //! //! df = pl.DataFrame({ //! "foo": [1, 2, None], @@ -41,7 +41,11 @@ //! }) //! out_df = my_cool_function(df) //! ``` +#[cfg(feature = "derive")] +pub mod derive; pub mod error; +#[cfg(feature = "derive")] +pub mod export; mod ffi; use crate::error::PyPolarsErr;