From f84281e536a38a74af2ea8ee2204bf7a2c44e4e8 Mon Sep 17 00:00:00 2001 From: Andy Date: Fri, 6 Oct 2023 14:31:18 +0200 Subject: [PATCH] fix: type_func --- example/derive_expression/run.py | 8 +++++--- pyo3-polars-derive/src/lib.rs | 9 ++++++--- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/example/derive_expression/run.py b/example/derive_expression/run.py index 392a56d..45c24bd 100644 --- a/example/derive_expression/run.py +++ b/example/derive_expression/run.py @@ -5,15 +5,17 @@ "names": ["Richard", "Alice", "Bob"], "moons": ["full", "half", "red"], "dist_a": [[12, 32, 1], [], [1, -2]], - "dist_b": [[-12, 1], [43], [876, -45, 9]] + "dist_b": [[-12, 1], [43], [876, -45, 9]], + "floats": [5.6, -1245.8, 242.224] }) out = df.with_columns( - pig_latin = pl.col("names").language.pig_latinnify() + pig_latin = pl.col("names").language.pig_latinnify(), ).with_columns( hamming_dist = pl.col("names").dist.hamming_distance("pig_latin"), - jaccard_sim = pl.col("dist_a").dist.jaccard_similarity("dist_b") + jaccard_sim = pl.col("dist_a").dist.jaccard_similarity("dist_b"), + haversine = pl.col("floats").dist.haversine("floats", "floats", "floats", "floats"), ) print(out) diff --git a/pyo3-polars-derive/src/lib.rs b/pyo3-polars-derive/src/lib.rs index f0de3d5..d4b1acc 100644 --- a/pyo3-polars-derive/src/lib.rs +++ b/pyo3-polars-derive/src/lib.rs @@ -41,7 +41,10 @@ fn get_inputs() -> proc_macro2::TokenStream { ) } -fn create_field_function(fn_name: &syn::Ident) -> proc_macro2::TokenStream { +fn create_field_function( + fn_name: &syn::Ident, + dtype_fn_name: &syn::Ident +) -> proc_macro2::TokenStream { let map_field_name = get_field_name(fn_name); let inputs = get_inputs(); @@ -49,7 +52,7 @@ fn create_field_function(fn_name: &syn::Ident) -> proc_macro2::TokenStream { #[no_mangle] pub unsafe extern "C" fn #map_field_name(field: *mut polars_core::export::arrow::ffi::ArrowSchema, len: usize) -> polars_core::export::arrow::ffi::ArrowSchema { #inputs; - let out = #fn_name(&inputs).unwrap(); + let out = #dtype_fn_name(&inputs).unwrap(); polars_core::export::arrow::ffi::export_field_to_c(&out.to_arrow()) } ) @@ -81,7 +84,7 @@ pub fn polars_expr(attr: TokenStream, input: TokenStream) -> TokenStream { let options = parse_macro_input!(attr as attr::ExprsFunctionOptions); let expanded_field_fn = if let Some(fn_name) = options.output_type_fn { - create_field_function(&fn_name) + create_field_function(&ast.sig.ident, &fn_name) } else if let Some(dtype) = options.output_dtype { create_field_function_from_with_dtype(&ast.sig.ident, dtype) } else {