Skip to content

Commit

Permalink
feat: Check Python version when deserializing UDFs (#19175)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Oct 10, 2024
1 parent 10bd047 commit 9162f67
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 23 deletions.
98 changes: 75 additions & 23 deletions crates/polars-plan/src/dsl/python_udf.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::io::Cursor;
use std::sync::Arc;

use once_cell::sync::Lazy;
use polars_core::datatypes::{DataType, Field};
use polars_core::error::*;
use polars_core::frame::column::Column;
Expand All @@ -25,7 +26,9 @@ pub static mut CALL_COLUMNS_UDF_PYTHON: Option<
pub static mut CALL_DF_UDF_PYTHON: Option<
fn(s: DataFrame, lambda: &PyObject) -> PolarsResult<DataFrame>,
> = None;
#[cfg(feature = "serde")]
pub(super) const MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes();
static PYTHON_VERSION_MINOR: Lazy<u8> = Lazy::new(get_python_minor_version);

#[derive(Clone, Debug)]
pub struct PythonFunction(pub PyObject);
Expand Down Expand Up @@ -60,7 +63,7 @@ impl Serialize for PythonFunction {
Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "cloudpickle")
.or_else(|_| PyModule::import_bound(py, "pickle"))
.expect("Unable to import 'cloudpickle' or 'pickle'")
.expect("unable to import 'cloudpickle' or 'pickle'")
.getattr("dumps")
.unwrap();

Expand All @@ -86,9 +89,8 @@ impl<'a> Deserialize<'a> for PythonFunction {
let bytes = Vec::<u8>::deserialize(deserializer)?;

Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "cloudpickle")
.or_else(|_| PyModule::import_bound(py, "pickle"))
.expect("Unable to import 'pickle'")
let pickle = PyModule::import_bound(py, "pickle")
.expect("unable to import 'pickle'")
.getattr("loads")
.unwrap();
let arg = (PyBytes::new_bound(py, &bytes),);
Expand Down Expand Up @@ -125,19 +127,36 @@ impl PythonUdfExpression {

#[cfg(feature = "serde")]
pub(crate) fn try_deserialize(buf: &[u8]) -> PolarsResult<Arc<dyn ColumnsUdf>> {
// Handle byte mark
debug_assert!(buf.starts_with(MAGIC_BYTE_MARK));
// skip header
let buf = &buf[MAGIC_BYTE_MARK.len()..];

// Handle pickle metadata
let use_cloudpickle = buf[0];
if use_cloudpickle != 0 {
let ser_py_version = buf[1];
let cur_py_version = *PYTHON_VERSION_MINOR;
polars_ensure!(
ser_py_version == cur_py_version,
InvalidOperation:
"current Python version (3.{}) does not match the Python version used to serialize the UDF (3.{})",
cur_py_version,
ser_py_version
);
}
let buf = &buf[2..];

// Load UDF metadata
let mut reader = Cursor::new(buf);
let (output_type, is_elementwise, returns_scalar): (Option<DataType>, bool, bool) =
ciborium::de::from_reader(&mut reader).map_err(map_err)?;

let remainder = &buf[reader.position() as usize..];

// Load UDF
Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "cloudpickle")
.or_else(|_| PyModule::import_bound(py, "pickle"))
.expect("Unable to import 'pickle'")
let pickle = PyModule::import_bound(py, "pickle")
.expect("unable to import 'pickle'")
.getattr("loads")
.unwrap();
let arg = (PyBytes::new_bound(py, remainder),);
Expand Down Expand Up @@ -189,26 +208,45 @@ impl ColumnsUdf for PythonUdfExpression {

#[cfg(feature = "serde")]
fn try_serialize(&self, buf: &mut Vec<u8>) -> PolarsResult<()> {
// Write byte marks
buf.extend_from_slice(MAGIC_BYTE_MARK);
ciborium::ser::into_writer(
&(
self.output_type.clone(),
self.is_elementwise,
self.returns_scalar,
),
&mut *buf,
)
.unwrap();

Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "cloudpickle")
.or_else(|_| PyModule::import_bound(py, "pickle"))
.expect("Unable to import 'pickle'")
// Try pickle to serialize the UDF, otherwise fall back to cloudpickle.
let pickle = PyModule::import_bound(py, "pickle")
.expect("unable to import 'pickle'")
.getattr("dumps")
.unwrap();
let dumped = pickle
.call1((self.python_function.clone(),))
.map_err(from_pyerr)?;
let pickle_result = pickle.call1((self.python_function.clone(),));
let (dumped, use_cloudpickle, py_version) = match pickle_result {
Ok(dumped) => (dumped, false, 0),
Err(_) => {
let cloudpickle = PyModule::import_bound(py, "cloudpickle")
.map_err(from_pyerr)?
.getattr("dumps")
.unwrap();
let dumped = cloudpickle
.call1((self.python_function.clone(),))
.map_err(from_pyerr)?;
(dumped, true, *PYTHON_VERSION_MINOR)
},
};

// Write pickle metadata
buf.extend_from_slice(&[use_cloudpickle as u8, py_version]);

// Write UDF metadata
ciborium::ser::into_writer(
&(
self.output_type.clone(),
self.is_elementwise,
self.returns_scalar,
),
&mut *buf,
)
.unwrap();

// Write UDF
let dumped = dumped.extract::<PyBackedBytes>().unwrap();
buf.extend_from_slice(&dumped);
Ok(())
Expand Down Expand Up @@ -298,3 +336,17 @@ impl Expr {
}
}
}

/// Get the minor Python version from the `sys` module.
fn get_python_minor_version() -> u8 {
Python::with_gil(|py| {
PyModule::import_bound(py, "sys")
.unwrap()
.getattr("version_info")
.unwrap()
.getattr("minor")
.unwrap()
.extract()
.unwrap()
})
}
30 changes: 30 additions & 0 deletions py-polars/tests/unit/lazyframe/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,33 @@ def test_lf_serde_scan(tmp_path: Path) -> None:
result = pl.LazyFrame.deserialize(io.BytesIO(ser))
assert_frame_equal(result, lf)
assert_frame_equal(result.collect(), df)


@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
def test_lf_serde_version_specific_lambda(monkeypatch: pytest.MonkeyPatch) -> None:
lf = pl.LazyFrame({"a": [1, 2, 3]}).select(
pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64)
)
ser = lf.serialize()

result = pl.LazyFrame.deserialize(io.BytesIO(ser))
expected = pl.LazyFrame({"a": [2, 3, 4]})
assert_frame_equal(result, expected)


def custom_function(x: pl.Series) -> pl.Series:
return x + 1


@pytest.mark.filterwarnings("ignore::polars.exceptions.PolarsInefficientMapWarning")
def test_lf_serde_version_specific_named_function(
monkeypatch: pytest.MonkeyPatch,
) -> None:
lf = pl.LazyFrame({"a": [1, 2, 3]}).select(
pl.col("a").map_batches(custom_function, return_dtype=pl.Int64)
)
ser = lf.serialize()

result = pl.LazyFrame.deserialize(io.BytesIO(ser))
expected = pl.LazyFrame({"a": [2, 3, 4]})
assert_frame_equal(result, expected)

0 comments on commit 9162f67

Please sign in to comment.