diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 4107a498957b9a..cae508e773d8bd 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -241,11 +241,20 @@ core: allowed_deserialization_classes: description: | What classes can be imported during deserialization. This is a multi line value. - The individual items will be parsed as regexp. Python built-in classes (like dict) - are always allowed. Bare "." will be replaced so you can set airflow.* . + The individual items will be parsed as a pattern to a glob function. + Python built-in classes (like dict) are always allowed. version_added: 2.5.0 type: string - default: 'airflow\..*' + default: 'airflow.*' + example: ~ + allowed_deserialization_classes_regexp: + description: | + What classes can be imported during deserialization. This is a multi line value. + The individual items will be parsed as regexp patterns. + This is a secondary option to ``allowed_deserialization_classes``. + version_added: 2.8.1 + type: string + default: '' example: ~ killed_task_cleanup_time: description: | diff --git a/airflow/config_templates/unit_tests.cfg b/airflow/config_templates/unit_tests.cfg index 69c2d65bba0f42..42055b9d9c7d01 100644 --- a/airflow/config_templates/unit_tests.cfg +++ b/airflow/config_templates/unit_tests.cfg @@ -58,7 +58,7 @@ unit_test_mode = True # We want to use a shorter timeout for task cleanup killed_task_cleanup_time = 5 # We only allow our own classes to be deserialized in tests -allowed_deserialization_classes = airflow\..* tests\..* +allowed_deserialization_classes = airflow.* tests.* [database] diff --git a/airflow/serialization/serde.py b/airflow/serialization/serde.py index 23d67e6162fcd1..a214acc9a6677d 100644 --- a/airflow/serialization/serde.py +++ b/airflow/serialization/serde.py @@ -22,6 +22,7 @@ import functools import logging import sys +from fnmatch import fnmatch from importlib import import_module from typing import TYPE_CHECKING, Any, Pattern, TypeVar, Union, cast @@ -241,7 +242,6 @@ def deserialize(o: T | None, full=True, type_hint: Any = None) -> object: # only return string representation if not full: return _stringify(classname, version, value) - if not _match(classname) and classname not in _extra_allowed: raise ImportError( f"{classname} was not found in allow list for deserialization imports. " @@ -288,7 +288,22 @@ def _convert(old: dict) -> dict: def _match(classname: str) -> bool: - return any(p.match(classname) is not None for p in _get_patterns()) + """Checks if the given classname matches a path pattern either using glob format or regexp format.""" + return _match_glob(classname) or _match_regexp(classname) + + +@functools.lru_cache(maxsize=None) +def _match_glob(classname: str): + """Checks if the given classname matches a pattern from allowed_deserialization_classes using glob syntax.""" + patterns = _get_patterns() + return any(fnmatch(classname, p.pattern) for p in patterns) + + +@functools.lru_cache(maxsize=None) +def _match_regexp(classname: str): + """Checks if the given classname matches a pattern from allowed_deserialization_classes_regexp using regexp.""" + patterns = _get_regexp_patterns() + return any(p.match(classname) is not None for p in patterns) def _stringify(classname: str, version: int, value: T | None) -> str: @@ -359,8 +374,12 @@ def _register(): @functools.lru_cache(maxsize=None) def _get_patterns() -> list[Pattern]: - patterns = conf.get("core", "allowed_deserialization_classes").split() - return [re2.compile(re2.sub(r"(\w)\.", r"\1\..", p)) for p in patterns] + return [re2.compile(p) for p in conf.get("core", "allowed_deserialization_classes").split()] + + +@functools.lru_cache(maxsize=None) +def _get_regexp_patterns() -> list[Pattern]: + return [re2.compile(p) for p in conf.get("core", "allowed_deserialization_classes_regexp").split()] _register() diff --git a/newsfragments/36147.significant.rst b/newsfragments/36147.significant.rst new file mode 100644 index 00000000000000..105e1b74d454aa --- /dev/null +++ b/newsfragments/36147.significant.rst @@ -0,0 +1,11 @@ +The ``allowed_deserialization_classes`` flag now follows a glob pattern. + +For example if one wants to add the class ``airflow.tests.custom_class`` to the +``allowed_deserialization_classes`` list, it can be done by writing the full class +name (``airflow.tests.custom_class``) or a pattern such as the ones used in glob +search (e.g., ``airflow.*``, ``airflow.tests.*``). + +If you currently use a custom regexp path make sure to rewrite it as a glob pattern. + +Alternatively, if you still wish to match it as a regexp pattern, add it under the new +list ``allowed_deserialization_classes_regexp`` instead. diff --git a/tests/serialization/test_serde.py b/tests/serialization/test_serde.py index dc3d3faf1ebacd..bd39ac71663946 100644 --- a/tests/serialization/test_serde.py +++ b/tests/serialization/test_serde.py @@ -33,7 +33,10 @@ SCHEMA_ID, VERSION, _get_patterns, + _get_regexp_patterns, _match, + _match_glob, + _match_regexp, deserialize, serialize, ) @@ -44,10 +47,16 @@ @pytest.fixture() def recalculate_patterns(): _get_patterns.cache_clear() + _get_regexp_patterns.cache_clear() + _match_glob.cache_clear() + _match_regexp.cache_clear() try: yield finally: _get_patterns.cache_clear() + _get_regexp_patterns.cache_clear() + _match_glob.cache_clear() + _match_regexp.cache_clear() class Z: @@ -218,7 +227,7 @@ def test_serder_dataclass(self): @conf_vars( { - ("core", "allowed_deserialization_classes"): "airflow[.].*", + ("core", "allowed_deserialization_classes"): "airflow.*", } ) @pytest.mark.usefixtures("recalculate_patterns") @@ -232,13 +241,54 @@ def test_allow_list_for_imports(self): @conf_vars( { - ("core", "allowed_deserialization_classes"): "tests.*", + ("core", "allowed_deserialization_classes"): "tests.airflow.*", } ) @pytest.mark.usefixtures("recalculate_patterns") - def test_allow_list_replace(self): + def test_allow_list_match(self): assert _match("tests.airflow.deep") - assert _match("testsfault") is False + assert _match("tests.wrongpath") is False + + @conf_vars( + { + ("core", "allowed_deserialization_classes"): "tests.airflow.deep", + } + ) + @pytest.mark.usefixtures("recalculate_patterns") + def test_allow_list_match_class(self): + """Test the match function when passing a full classname as + allowed_deserialization_classes + """ + assert _match("tests.airflow.deep") + assert _match("tests.airflow.FALSE") is False + + @conf_vars( + { + ("core", "allowed_deserialization_classes"): "", + ("core", "allowed_deserialization_classes_regexp"): "tests\.airflow\..", + } + ) + @pytest.mark.usefixtures("recalculate_patterns") + def test_allow_list_match_regexp(self): + """Test the match function when passing a path as + allowed_deserialization_classes_regexp with no glob pattern defined + """ + assert _match("tests.airflow.deep") + assert _match("tests.wrongpath") is False + + @conf_vars( + { + ("core", "allowed_deserialization_classes"): "", + ("core", "allowed_deserialization_classes_regexp"): "tests\.airflow\.deep", + } + ) + @pytest.mark.usefixtures("recalculate_patterns") + def test_allow_list_match_class_regexp(self): + """Test the match function when passing a full classname as + allowed_deserialization_classes_regexp with no glob pattern defined + """ + assert _match("tests.airflow.deep") + assert _match("tests.airflow.FALSE") is False def test_incompatible_version(self): data = dict(