diff --git a/example/derive_expression/expression_lib/Cargo.toml b/example/derive_expression/expression_lib/Cargo.toml index 8c01ac8..23dc2e7 100644 --- a/example/derive_expression/expression_lib/Cargo.toml +++ b/example/derive_expression/expression_lib/Cargo.toml @@ -14,6 +14,3 @@ pyo3 = { version = "0.21", features = ["abi3-py38"] } pyo3-polars = { version = "*", path = "../../../pyo3-polars", features = ["derive"] } rayon = "1.7.0" serde = { version = "1", features = ["derive"] } - -[target.'cfg(target_os = "linux")'.dependencies] -jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] } diff --git a/example/derive_expression/expression_lib/src/lib.rs b/example/derive_expression/expression_lib/src/lib.rs index d5c6766..0e201d3 100644 --- a/example/derive_expression/expression_lib/src/lib.rs +++ b/example/derive_expression/expression_lib/src/lib.rs @@ -1,9 +1,7 @@ +use pyo3_polars::PolarsAllocator; + mod distances; mod expressions; -#[cfg(target_os = "linux")] -use jemallocator::Jemalloc; - #[global_allocator] -#[cfg(target_os = "linux")] -static ALLOC: Jemalloc = Jemalloc; +static ALLOC: PolarsAllocator = PolarsAllocator::new(); diff --git a/example/extend_polars_python_dispatch/extend_polars/src/lib.rs b/example/extend_polars_python_dispatch/extend_polars/src/lib.rs index fac8510..306cfcb 100644 --- a/example/extend_polars_python_dispatch/extend_polars/src/lib.rs +++ b/example/extend_polars_python_dispatch/extend_polars/src/lib.rs @@ -5,7 +5,10 @@ use polars_lazy::frame::IntoLazy; use polars_lazy::prelude::LazyFrame; use pyo3::prelude::*; use pyo3_polars::error::PyPolarsErr; -use pyo3_polars::{PyDataFrame, PyLazyFrame}; +use pyo3_polars::{PolarsAllocator, PyDataFrame, PyLazyFrame}; + +#[global_allocator] +static ALLOC: PolarsAllocator = PolarsAllocator::new(); #[pyfunction] fn parallel_jaccard(pydf: PyDataFrame, col_a: &str, col_b: &str) -> PyResult { diff --git a/pyo3-polars/Cargo.toml b/pyo3-polars/Cargo.toml index 6a9476a..be7754e 100644 
--- a/pyo3-polars/Cargo.toml +++ b/pyo3-polars/Cargo.toml @@ -11,6 +11,8 @@ description = "Expression plugins and PyO3 types for polars" [dependencies] ciborium = { version = "0.2.1", optional = true } +libc = "0.2" # pyo3 depends on libc already, so this does not introduce an extra dependency. +once_cell = "1" polars = { workspace = true, default-features = false } polars-core = { workspace = true, default-features = false } polars-ffi = { workspace = true, optional = true } diff --git a/pyo3-polars/src/alloc.rs b/pyo3-polars/src/alloc.rs new file mode 100644 index 0000000..b2b9898 --- /dev/null +++ b/pyo3-polars/src/alloc.rs @@ -0,0 +1,115 @@ +use std::alloc::{GlobalAlloc, Layout, System}; +use std::ffi::c_char; + +use once_cell::race::OnceRef; +use pyo3::ffi::{PyCapsule_Import, Py_IsInitialized}; +use pyo3::Python; + +unsafe extern "C" fn fallback_alloc(size: usize, align: usize) -> *mut u8 { + System.alloc(Layout::from_size_align_unchecked(size, align)) +} + +unsafe extern "C" fn fallback_dealloc(ptr: *mut u8, size: usize, align: usize) { + System.dealloc(ptr, Layout::from_size_align_unchecked(size, align)) +} + +unsafe extern "C" fn fallback_alloc_zeroed(size: usize, align: usize) -> *mut u8 { + System.alloc_zeroed(Layout::from_size_align_unchecked(size, align)) +} + +unsafe extern "C" fn fallback_realloc( + ptr: *mut u8, + size: usize, + align: usize, + new_size: usize, +) -> *mut u8 { + System.realloc( + ptr, + Layout::from_size_align_unchecked(size, align), + new_size, + ) +} + +#[repr(C)] +struct AllocatorCapsule { + alloc: unsafe extern "C" fn(usize, usize) -> *mut u8, + dealloc: unsafe extern "C" fn(*mut u8, usize, usize), + alloc_zeroed: unsafe extern "C" fn(usize, usize) -> *mut u8, + realloc: unsafe extern "C" fn(*mut u8, usize, usize, usize) -> *mut u8, +} + +static FALLBACK_ALLOCATOR_CAPSULE: AllocatorCapsule = AllocatorCapsule { + alloc: fallback_alloc, + alloc_zeroed: fallback_alloc_zeroed, + dealloc: fallback_dealloc, + realloc: fallback_realloc, 
+}; + +static ALLOCATOR_CAPSULE_NAME: &[u8] = b"polars.polars._allocator\0"; + +/// A memory allocator that relays allocations to the allocator used by Polars. +/// +/// You can use it as the global memory allocator: +/// +/// ```rust +/// use pyo3_polars::PolarsAllocator; +/// +/// #[global_allocator] +/// static ALLOC: PolarsAllocator = PolarsAllocator::new(); +/// ``` +/// +/// If the allocator capsule (`polars.polars._allocator`) is not available, +/// this allocator falls back to [`std::alloc::System`]. +pub struct PolarsAllocator(OnceRef<'static, AllocatorCapsule>); + +impl PolarsAllocator { + fn get_allocator(&self) -> &'static AllocatorCapsule { + // Do not allocate in this function, + // otherwise it will cause infinite recursion. + self.0.get_or_init(|| { + let r = (unsafe { Py_IsInitialized() } != 0) + .then(|| { + Python::with_gil(|_| unsafe { + (PyCapsule_Import(ALLOCATOR_CAPSULE_NAME.as_ptr() as *const c_char, 0) + as *const AllocatorCapsule) + .as_ref() + }) + }) + .flatten(); + #[cfg(debug_assertions)] + if r.is_none() { + // Do not use eprintln; it may alloc. + let msg = b"failed to get allocator capsule\n"; + unsafe { libc::write(2, msg.as_ptr() as *const libc::c_void, msg.len()) }; + } + r.unwrap_or(&FALLBACK_ALLOCATOR_CAPSULE) + }) + } + + /// Create a `PolarsAllocator`. 
+ pub const fn new() -> Self { + PolarsAllocator(OnceRef::new()) + } +} + +unsafe impl GlobalAlloc for PolarsAllocator { + #[inline] + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + (self.get_allocator().alloc)(layout.size(), layout.align()) + } + + #[inline] + unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { + (self.get_allocator().dealloc)(ptr, layout.size(), layout.align()); + } + + #[inline] + unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 { + (self.get_allocator().alloc_zeroed)(layout.size(), layout.align()) + } + + #[inline] + unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 { + (self.get_allocator().realloc)(ptr, layout.size(), layout.align(), new_size) + } +} diff --git a/pyo3-polars/src/lib.rs b/pyo3-polars/src/lib.rs index 206fd90..4c9f6c1 100644 --- a/pyo3-polars/src/lib.rs +++ b/pyo3-polars/src/lib.rs @@ -41,6 +41,7 @@ //! }) //! out_df = my_cool_function(df) //! ``` +mod alloc; #[cfg(feature = "derive")] pub mod derive; pub mod error; @@ -48,6 +49,7 @@ pub mod error; pub mod export; mod ffi; +pub use crate::alloc::PolarsAllocator; use crate::error::PyPolarsErr; use crate::ffi::to_py::to_py_array; use polars::export::arrow;