diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 10f0ee54228b..4306220f0cea 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -383,6 +383,7 @@ def __getattr__(name: str) -> Any: + print("GETTING NAME", name) # Deprecate re-export of exceptions at top-level if name in dir(exceptions): from polars._utils.deprecation import issue_deprecation_warning @@ -414,3 +415,38 @@ def __getattr__(name: str) -> Any: msg = f"module {__name__!r} has no attribute {name!r}" raise AttributeError(msg) + + +# fork() breaks Polars thread pool. Instead of silently hanging when users do +# this, e.g. by using multiprocessing's footgun default setting on Linux, warn +# them instead: +def __install_postfork_hook() -> None: + def fail(*args: Any, **kwargs: Any) -> None: + message = """\ +Using fork() will cause Polars will result in deadlocks in the child process. +In addition, using fork() with Python in general is a recipe for mysterious +deadlocks and crashes. + +The most likely reason you are seeing this error is because you are using the +multiprocessing crate on Linux, which uses fork() by default. This will be +fixed in Python 3.14. Until then, you want to use the "spawn" context instead. + +See https://docs.pola.rs/user-guide/misc/multiprocessing/ for details. +""" + raise RuntimeError(message) + + def post_hook_child() -> None: + # Switch most public Polars API to fail when called. This won't catch + # _all_ edge cases, but does make it more likely users get told they + # tried to do something broken. + for name in __all__: + if callable(globals()[name]): + globals()[name] = fail + + import os + + if hasattr(os, "register_at_fork"): + os.register_at_fork(after_in_child=post_hook_child) + + +__install_postfork_hook() diff --git a/py-polars/tests/unit/test_polars_import.py b/py-polars/tests/unit/test_polars_import.py index fa1779de3478..1dae9742a135 100644 --- a/py-polars/tests/unit/test_polars_import.py +++ b/py-polars/tests/unit/test_polars_import.py @@ -1,6 +1,8 @@ from __future__ import annotations import compileall +import multiprocessing +import os import subprocess import sys from pathlib import Path @@ -97,3 +99,22 @@ def test_polars_import() -> None: import_time_ms = polars_import_time // 1_000 msg = f"Possible import speed regression; took {import_time_ms}ms\n{df_import}" raise AssertionError(msg) + + +def run_in_child() -> pl.Series: + return pl.Series([1, 2, 3]) + + +@pytest.mark.skipif(not hasattr(os, "fork"), reason="Requires fork()") +def test_fork_safety() -> None: + # Using fork()-based multiprocessing shouldn't work: + with ( + multiprocessing.get_context("fork").Pool(1) as pool, + pytest.raises(RuntimeError, match=r"Using fork\(\) will cause Polars"), + ): + pool.apply(run_in_child) + + # Using forkserver and spawn context should not error out: + for context in ["spawn", "forkserver"]: + with multiprocessing.get_context(context).Pool(1) as pool: + pool.apply(run_in_child)