diff --git a/crates/polars-plan/src/dsl/python_udf.rs b/crates/polars-plan/src/dsl/python_udf.rs index a813dbf64e87..c90b3d50d0af 100644 --- a/crates/polars-plan/src/dsl/python_udf.rs +++ b/crates/polars-plan/src/dsl/python_udf.rs @@ -1,6 +1,7 @@ use std::io::Cursor; use std::sync::Arc; +use once_cell::sync::Lazy; use polars_core::datatypes::{DataType, Field}; use polars_core::error::*; use polars_core::frame::column::Column; @@ -25,7 +26,9 @@ pub static mut CALL_COLUMNS_UDF_PYTHON: Option< pub static mut CALL_DF_UDF_PYTHON: Option< fn(s: DataFrame, lambda: &PyObject) -> PolarsResult, > = None; +#[cfg(feature = "serde")] pub(super) const MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes(); +static PYTHON_VERSION_MINOR: Lazy = Lazy::new(get_python_minor_version); #[derive(Clone, Debug)] pub struct PythonFunction(pub PyObject); @@ -60,7 +63,7 @@ impl Serialize for PythonFunction { Python::with_gil(|py| { let pickle = PyModule::import_bound(py, "cloudpickle") .or_else(|_| PyModule::import_bound(py, "pickle")) - .expect("Unable to import 'cloudpickle' or 'pickle'") + .expect("unable to import 'cloudpickle' or 'pickle'") .getattr("dumps") .unwrap(); @@ -86,9 +89,8 @@ impl<'a> Deserialize<'a> for PythonFunction { let bytes = Vec::::deserialize(deserializer)?; Python::with_gil(|py| { - let pickle = PyModule::import_bound(py, "cloudpickle") - .or_else(|_| PyModule::import_bound(py, "pickle")) - .expect("Unable to import 'pickle'") + let pickle = PyModule::import_bound(py, "pickle") + .expect("unable to import 'pickle'") .getattr("loads") .unwrap(); let arg = (PyBytes::new_bound(py, &bytes),); @@ -125,19 +127,36 @@ impl PythonUdfExpression { #[cfg(feature = "serde")] pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult> { + // Handle byte mark debug_assert!(buf.starts_with(MAGIC_BYTE_MARK)); - // skip header let buf = &buf[MAGIC_BYTE_MARK.len()..]; + + // Handle pickle metadata + let use_cloudpickle = buf[0]; + if use_cloudpickle != 0 { + let ser_py_version = buf[1]; + let cur_py_version = *PYTHON_VERSION_MINOR; + polars_ensure!( + ser_py_version == cur_py_version, + InvalidOperation: + "current Python version (3.{}) does not match the Python version used to serialize the UDF (3.{})", + cur_py_version, + ser_py_version + ); + } + let buf = &buf[2..]; + + // Load UDF metadata let mut reader = Cursor::new(buf); let (output_type, is_elementwise, returns_scalar): (Option, bool, bool) = ciborium::de::from_reader(&mut reader).map_err(map_err)?; let remainder = &buf[reader.position() as usize..]; + // Load UDF Python::with_gil(|py| { - let pickle = PyModule::import_bound(py, "cloudpickle") - .or_else(|_| PyModule::import_bound(py, "pickle")) - .expect("Unable to import 'pickle'") + let pickle = PyModule::import_bound(py, "pickle") + .expect("unable to import 'pickle'") .getattr("loads") .unwrap(); let arg = (PyBytes::new_bound(py, remainder),); @@ -189,26 +208,45 @@ impl ColumnsUdf for PythonUdfExpression { #[cfg(feature = "serde")] fn try_serialize(&self, buf: &mut Vec) -> PolarsResult<()> { + // Write byte marks buf.extend_from_slice(MAGIC_BYTE_MARK); - ciborium::ser::into_writer( - &( - self.output_type.clone(), - self.is_elementwise, - self.returns_scalar, - ), - &mut *buf, - ) - .unwrap(); Python::with_gil(|py| { - let pickle = PyModule::import_bound(py, "cloudpickle") - .or_else(|_| PyModule::import_bound(py, "pickle")) - .expect("Unable to import 'pickle'") + // Try pickle to serialize the UDF, otherwise fall back to cloudpickle. + let pickle = PyModule::import_bound(py, "pickle") + .expect("unable to import 'pickle'") .getattr("dumps") .unwrap(); - let dumped = pickle - .call1((self.python_function.clone(),)) - .map_err(from_pyerr)?; + let pickle_result = pickle.call1((self.python_function.clone(),)); + let (dumped, use_cloudpickle, py_version) = match pickle_result { + Ok(dumped) => (dumped, false, 0), + Err(_) => { + let cloudpickle = PyModule::import_bound(py, "cloudpickle") + .map_err(from_pyerr)? + .getattr("dumps") + .unwrap(); + let dumped = cloudpickle + .call1((self.python_function.clone(),)) + .map_err(from_pyerr)?; + (dumped, true, *PYTHON_VERSION_MINOR) + }, + }; + + // Write pickle metadata + buf.extend_from_slice(&[use_cloudpickle as u8, py_version]); + + // Write UDF metadata + ciborium::ser::into_writer( + &( + self.output_type.clone(), + self.is_elementwise, + self.returns_scalar, + ), + &mut *buf, + ) + .unwrap(); + + // Write UDF let dumped = dumped.extract::().unwrap(); buf.extend_from_slice(&dumped); Ok(()) @@ -298,3 +336,17 @@ impl Expr { } } } + +/// Get the minor Python version from the `sys` module. +fn get_python_minor_version() -> u8 { + Python::with_gil(|py| { + PyModule::import_bound(py, "sys") + .unwrap() + .getattr("version_info") + .unwrap() + .getattr("minor") + .unwrap() + .extract() + .unwrap() + }) +} diff --git a/py-polars/tests/unit/lazyframe/test_serde.py b/py-polars/tests/unit/lazyframe/test_serde.py index a82e389b4583..8ddcbfafd6f6 100644 --- a/py-polars/tests/unit/lazyframe/test_serde.py +++ b/py-polars/tests/unit/lazyframe/test_serde.py @@ -116,3 +116,33 @@ def test_lf_serde_scan(tmp_path: Path) -> None: result = pl.LazyFrame.deserialize(io.BytesIO(ser)) assert_frame_equal(result, lf) assert_frame_equal(result.collect(), df) + + +@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") +def test_lf_serde_version_specific_lambda(monkeypatch: pytest.MonkeyPatch) -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}).select( + pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64) + ) + ser = lf.serialize() + + result = pl.LazyFrame.deserialize(io.BytesIO(ser)) + expected = pl.LazyFrame({"a": [2, 3, 4]}) + assert_frame_equal(result, expected) + + +def custom_function(x: pl.Series) -> pl.Series: + return x + 1 + + +@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning") +def test_lf_serde_version_specific_named_function( + monkeypatch: pytest.MonkeyPatch, +) -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}).select( + pl.col("a").map_batches(custom_function, return_dtype=pl.Int64) + ) + ser = lf.serialize() + + result = pl.LazyFrame.deserialize(io.BytesIO(ser)) + expected = pl.LazyFrame({"a": [2, 3, 4]}) + assert_frame_equal(result, expected)