From d4ce8563be135ea12cc1d38486792462a1ffdc06 Mon Sep 17 00:00:00 2001 From: Daniel Shields Date: Fri, 30 Aug 2024 11:29:52 -0500 Subject: [PATCH] Remove unnecessary _is_node optimization and clean up type hints. --- src/uberjob/_plan.py | 11 +++-------- src/uberjob/_registry.py | 6 +++--- src/uberjob/_rendering.py | 11 +++++------ src/uberjob/_testing/test_mounted_file_store.py | 4 ++-- src/uberjob/_util/validation.py | 6 +++--- src/uberjob/stores/_mounted_store.py | 4 ++-- 6 files changed, 18 insertions(+), 24 deletions(-) diff --git a/src/uberjob/_plan.py b/src/uberjob/_plan.py index 9280ea5..dca52ac 100644 --- a/src/uberjob/_plan.py +++ b/src/uberjob/_plan.py @@ -40,11 +40,6 @@ } -def _is_node(value) -> bool: - """Efficiently determines whether the given value is a :class:`~uberjob.graph.Node`.""" - return type(value) in (Call, Literal) - - class Plan: """Represents a symbolic call graph.""" @@ -90,7 +85,7 @@ def lit(self, value) -> Literal: :param value: The literal value. :return: The symbolic literal value. """ - if _is_node(value): + if isinstance(value, Node): raise TypeError(f"The value is already a {Node.__name__}.") literal = Literal(value, scope=self._scope) self.graph.add_node(literal) @@ -118,12 +113,12 @@ def recurse(root): if gather_fn is not None: items = root.items() if root_type is dict else root children = [recurse(item) for item in items] - if any(_is_node(child) for child in children): + if any(isinstance(child, Node) for child in children): return self._call(stack_frame, gather_fn, *children) return root value = recurse(value) - return value if _is_node(value) else self.lit(value) + return value if isinstance(value, Node) else self.lit(value) def gather(self, value) -> Node: """ diff --git a/src/uberjob/_registry.py b/src/uberjob/_registry.py index c024443..6908282 100644 --- a/src/uberjob/_registry.py +++ b/src/uberjob/_registry.py @@ -14,7 +14,7 @@ # limitations under the License. # import copy -import typing +from collections.abc import Iterable, KeysView from uberjob._builtins import source from uberjob._plan import Node, Plan @@ -99,7 +99,7 @@ def get(self, node: Node) -> ValueStore | None: v = self.mapping.get(node) return v.value_store if v else None - def keys(self) -> typing.KeysView[Node]: + def keys(self) -> KeysView[Node]: """ Get all registered :class:`~uberjob.graph.Node` instances. @@ -123,7 +123,7 @@ def items(self) -> list[tuple[Node, ValueStore]]: """ return [(k, v.value_store) for k, v in self.mapping.items()] - def __iter__(self) -> typing.Iterable[Node]: + def __iter__(self) -> Iterable[Node]: """ Get all registered :class:`~uberjob.graph.Node` instances. diff --git a/src/uberjob/_rendering.py b/src/uberjob/_rendering.py index d7a9a5d..46745f5 100644 --- a/src/uberjob/_rendering.py +++ b/src/uberjob/_rendering.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import typing -from collections import OrderedDict +from collections.abc import Callable from uberjob._plan import Plan from uberjob._registry import Registry @@ -35,7 +34,7 @@ GRAY = (0.4, 0.4, 0.4) -def default_style(registry: Registry = None): +def default_style(registry: Registry | None = None): import nxv if registry is None: @@ -132,8 +131,8 @@ class Scope: def render( plan: Plan | Graph | tuple[Plan, Node | None], *, - registry: Registry = None, - predicate: typing.Callable[[Node, dict], bool] = None, + registry: Registry | None = None, + predicate: None | Callable[[Node, dict], bool] = None, level: int | None = None, format: str | None = None ) -> bytes | None: @@ -165,7 +164,7 @@ def render( ) if level is not None: - scope_groups = OrderedDict() + scope_groups = {} for u in graph.nodes(): scope = u.scope if scope: diff --git a/src/uberjob/_testing/test_mounted_file_store.py b/src/uberjob/_testing/test_mounted_file_store.py index 36f3cbf..46f6dfb 100644 --- a/src/uberjob/_testing/test_mounted_file_store.py +++ b/src/uberjob/_testing/test_mounted_file_store.py @@ -14,7 +14,7 @@ # limitations under the License. # import datetime as dt -import typing +from collections.abc import Callable from uberjob._testing.test_store import TestStore from uberjob._util import repr_helper @@ -23,7 +23,7 @@ class TestMountedFileStore(MountedStore): - def __init__(self, create_file_store: typing.Callable[[str], FileStore]): + def __init__(self, create_file_store: Callable[[str], FileStore]): super().__init__(create_file_store) self.remote_store = TestStore() diff --git a/src/uberjob/_util/validation.py b/src/uberjob/_util/validation.py index 1848b11..baf40e1 100644 --- a/src/uberjob/_util/validation.py +++ b/src/uberjob/_util/validation.py @@ -14,7 +14,7 @@ # limitations under the License. # import inspect -import typing +from collections.abc import Callable from functools import lru_cache from uberjob._util import fully_qualified_name @@ -57,14 +57,14 @@ def assert_is_instance( @lru_cache(4096) -def try_get_signature(fn: typing.Callable): +def try_get_signature(fn: Callable): try: return inspect.signature(fn) except ValueError: return None -def assert_can_bind(fn: typing.Callable, *args, **kwargs): +def assert_can_bind(fn: Callable, *args, **kwargs): sig = try_get_signature(fn) if sig is None: return diff --git a/src/uberjob/stores/_mounted_store.py b/src/uberjob/stores/_mounted_store.py index 51f99b2..6b5e6ca 100644 --- a/src/uberjob/stores/_mounted_store.py +++ b/src/uberjob/stores/_mounted_store.py @@ -15,8 +15,8 @@ # import os import tempfile -import typing from abc import ABC, abstractmethod +from collections.abc import Callable from contextlib import contextmanager from uberjob._util import repr_helper @@ -38,7 +38,7 @@ class MountedStore(ValueStore, ABC): __slots__ = ("create_store",) - def __init__(self, create_store: typing.Callable[[str], ValueStore]): + def __init__(self, create_store: Callable[[str], ValueStore]): self.create_store = create_store """Creates an instance of the underlying :class:`~uberjob.ValueStore` for the given path."""