Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type hints #7532

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dependencies:
- watermark
- polyagamma
- sphinx-remove-toctrees
- mypy=1.5.1
- mypy=1.11.2
- types-cachetools
- pip:
- git+https://github.com/pymc-devs/pymc-sphinx-theme
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-jax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- pre-commit>=2.8.0
- pytest-cov>=2.5
- pytest>=3.0
- mypy=1.5.1
- mypy=1.11.2
- types-cachetools
- pip:
- numdifftools>=0.9.40
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies:
- pre-commit>=2.8.0
- pytest-cov>=2.5
- pytest>=3.0
- mypy=1.5.1
- mypy=1.11.2
- types-cachetools
- pip:
- numdifftools>=0.9.40
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies:
- sphinx>=1.5
- watermark
- sphinx-remove-toctrees
- mypy=1.5.1
- mypy=1.11.2
- types-cachetools
- pip:
- git+https://github.com/pymc-devs/pymc-sphinx-theme
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies:
- pre-commit>=2.8.0
- pytest-cov>=2.5
- pytest>=3.0
- mypy=1.5.1
- mypy=1.11.2
- types-cachetools
- pip:
- numdifftools>=0.9.40
Expand Down
5 changes: 2 additions & 3 deletions docs/source/contributing/implementing_distribution.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ The following snippet illustrates how to create a new `RandomVariable`:

from pytensor.tensor.var import TensorVariable
from pytensor.tensor.random.op import RandomVariable
from typing import List, Tuple

# Create your own `RandomVariable`...
class BlahRV(RandomVariable):
Expand All @@ -53,7 +52,7 @@ class BlahRV(RandomVariable):
dtype: str = "floatX"

# A pretty text and LaTeX representation for the RV
_print_name: Tuple[str, str] = ("blah", "\\operatorname{blah}")
_print_name: tuple[str, str] = ("blah", "\\operatorname{blah}")

# If you want to add a custom signature and default values for the
# parameters, do it like this. Otherwise this can be left out.
Expand All @@ -70,7 +69,7 @@ class BlahRV(RandomVariable):
rng: np.random.RandomState,
loc: np.ndarray,
scale: np.ndarray,
size: Tuple[int, ...],
size: tuple[int, ...],
) -> np.ndarray:
return scipy.stats.blah.rvs(loc, scale, random_state=rng, size=size)

Expand Down
58 changes: 29 additions & 29 deletions pymc/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
import re
import subprocess
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable
import functools


def get_keywords() -> Dict[str, str]:
def get_keywords() -> dict[str, str]:
"""Get the keywords needed to look up the version information."""
# these strings will be replaced by git during git-archive.
# setup.py/versioneer.py will grep for the variable names, so they must
Expand Down Expand Up @@ -75,8 +75,8 @@ class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""


LONG_VERSION_PY: Dict[str, str] = {}
HANDLERS: Dict[str, Dict[str, Callable]] = {}
LONG_VERSION_PY: dict[str, str] = {}
HANDLERS: dict[str, dict[str, Callable]] = {}


def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator
Expand All @@ -91,18 +91,18 @@ def decorate(f: Callable) -> Callable:


def run_command(
commands: List[str],
args: List[str],
cwd: Optional[str] = None,
commands: list[str],
args: list[str],
cwd: str | None = None,
verbose: bool = False,
hide_stderr: bool = False,
env: Optional[Dict[str, str]] = None,
) -> Tuple[Optional[str], Optional[int]]:
env: dict[str, str] | None = None,
) -> tuple[str | None, int | None]:
"""Call the given command(s)."""
assert isinstance(commands, list)
process = None

popen_kwargs: Dict[str, Any] = {}
popen_kwargs: dict[str, Any] = {}
if sys.platform == "win32":
# This hides the console window if pythonw.exe is used
startupinfo = subprocess.STARTUPINFO()
Expand Down Expand Up @@ -142,7 +142,7 @@ def versions_from_parentdir(
parentdir_prefix: str,
root: str,
verbose: bool,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Try to determine the version from the parent directory name.

Source tarballs conventionally unpack into a directory that includes both
Expand All @@ -167,13 +167,13 @@ def versions_from_parentdir(


@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs: str) -> Dict[str, str]:
def git_get_keywords(versionfile_abs: str) -> dict[str, str]:
"""Extract version information from the given file."""
# the code embedded in _version.py can just fetch the value of these
# keywords. When used from setup.py, we don't want to import _version.py,
# so we do it with a regexp instead. This function is not used from
# _version.py.
keywords: Dict[str, str] = {}
keywords: dict[str, str] = {}
try:
with open(versionfile_abs, "r") as fobj:
for line in fobj:
Expand All @@ -196,10 +196,10 @@ def git_get_keywords(versionfile_abs: str) -> Dict[str, str]:

@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(
keywords: Dict[str, str],
keywords: dict[str, str],
tag_prefix: str,
verbose: bool,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Get version information from git keywords."""
if "refnames" not in keywords:
raise NotThisMethod("Short version file found")
Expand Down Expand Up @@ -268,7 +268,7 @@ def git_pieces_from_vcs(
root: str,
verbose: bool,
runner: Callable = run_command
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Get version from 'git describe' in the root of the source tree.

This only gets called if the git-archive 'subst' keywords were *not*
Expand Down Expand Up @@ -308,7 +308,7 @@ def git_pieces_from_vcs(
raise NotThisMethod("'git rev-parse' failed")
full_out = full_out.strip()

pieces: Dict[str, Any] = {}
pieces: dict[str, Any] = {}
pieces["long"] = full_out
pieces["short"] = full_out[:7] # maybe improved later
pieces["error"] = None
Expand Down Expand Up @@ -400,14 +400,14 @@ def git_pieces_from_vcs(
return pieces


def plus_or_dot(pieces: Dict[str, Any]) -> str:
def plus_or_dot(pieces: dict[str, Any]) -> str:
"""Return a + if we don't already have one, else return a ."""
if "+" in pieces.get("closest-tag", ""):
return "."
return "+"


def render_pep440(pieces: Dict[str, Any]) -> str:
def render_pep440(pieces: dict[str, Any]) -> str:
"""Build up version string, with post-release "local version identifier".

Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
Expand All @@ -432,7 +432,7 @@ def render_pep440(pieces: Dict[str, Any]) -> str:
return rendered


def render_pep440_branch(pieces: Dict[str, Any]) -> str:
def render_pep440_branch(pieces: dict[str, Any]) -> str:
"""TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .

The ".dev0" means not master branch. Note that .dev0 sorts backwards
Expand Down Expand Up @@ -462,7 +462,7 @@ def render_pep440_branch(pieces: Dict[str, Any]) -> str:
return rendered


def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]:
def pep440_split_post(ver: str) -> tuple[str, int | None]:
"""Split pep440 version string at the post-release segment.

Returns the release segments before the post-release and the
Expand All @@ -472,7 +472,7 @@ def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]:
return vc[0], int(vc[1] or 0) if len(vc) == 2 else None


def render_pep440_pre(pieces: Dict[str, Any]) -> str:
def render_pep440_pre(pieces: dict[str, Any]) -> str:
"""TAG[.postN.devDISTANCE] -- No -dirty.

Exceptions:
Expand All @@ -496,7 +496,7 @@ def render_pep440_pre(pieces: Dict[str, Any]) -> str:
return rendered


def render_pep440_post(pieces: Dict[str, Any]) -> str:
def render_pep440_post(pieces: dict[str, Any]) -> str:
"""TAG[.postDISTANCE[.dev0]+gHEX] .

The ".dev0" means dirty. Note that .dev0 sorts backwards
Expand All @@ -523,7 +523,7 @@ def render_pep440_post(pieces: Dict[str, Any]) -> str:
return rendered


def render_pep440_post_branch(pieces: Dict[str, Any]) -> str:
def render_pep440_post_branch(pieces: dict[str, Any]) -> str:
"""TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .

The ".dev0" means not master branch.
Expand Down Expand Up @@ -552,7 +552,7 @@ def render_pep440_post_branch(pieces: Dict[str, Any]) -> str:
return rendered


def render_pep440_old(pieces: Dict[str, Any]) -> str:
def render_pep440_old(pieces: dict[str, Any]) -> str:
"""TAG[.postDISTANCE[.dev0]] .

The ".dev0" means dirty.
Expand All @@ -574,7 +574,7 @@ def render_pep440_old(pieces: Dict[str, Any]) -> str:
return rendered


def render_git_describe(pieces: Dict[str, Any]) -> str:
def render_git_describe(pieces: dict[str, Any]) -> str:
"""TAG[-DISTANCE-gHEX][-dirty].

Like 'git describe --tags --dirty --always'.
Expand All @@ -594,7 +594,7 @@ def render_git_describe(pieces: Dict[str, Any]) -> str:
return rendered


def render_git_describe_long(pieces: Dict[str, Any]) -> str:
def render_git_describe_long(pieces: dict[str, Any]) -> str:
"""TAG-DISTANCE-gHEX[-dirty].

Like 'git describe --tags --dirty --always -long'.
Expand All @@ -614,7 +614,7 @@ def render_git_describe_long(pieces: Dict[str, Any]) -> str:
return rendered


def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]:
def render(pieces: dict[str, Any], style: str) -> dict[str, Any]:
"""Render the given version pieces into the requested style."""
if pieces["error"]:
return {"version": "unknown",
Expand Down Expand Up @@ -650,7 +650,7 @@ def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]:
"date": pieces.get("date")}


def get_versions() -> Dict[str, Any]:
def get_versions() -> dict[str, Any]:
"""Get version information or return default if unable to do so."""
# I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
# __file__, we can work backwards from there to the root. Some
Expand Down
4 changes: 2 additions & 2 deletions pymc/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@

from collections.abc import Mapping, Sequence
from copy import copy
from typing import Optional, TypeAlias, Union
from typing import TypeAlias

import numpy as np

Expand All @@ -85,7 +85,7 @@
RunType: TypeAlias = Run
HAS_MCB = True
except ImportError:
TraceOrBackend = BaseTrace # type: ignore[misc]
TraceOrBackend = BaseTrace # type: ignore[assignment, misc]

Check warning on line 88 in pymc/backends/__init__.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/__init__.py#L88

Added line #L88 was not covered by tests
RunType = type(None) # type: ignore[assignment, misc]


Expand Down
8 changes: 4 additions & 4 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class _DefaultTrace:

Attributes
----------
trace_dict : Dict[str, np.ndarray]
trace_dict : dict[str, np.ndarray]
A dictionary constituting a trace. Should be extracted
after a procedure has filled the `_DefaultTrace` using the
`insert()` method
Expand Down Expand Up @@ -548,7 +548,7 @@ def predictions_to_inference_data(

Parameters
----------
predictions: Dict[str, np.ndarray]
predictions: dict[str, np.ndarray]
The predictions are the return value of :func:`~pymc.sample_posterior_predictive`,
a dictionary of strings (variable names) to numpy ndarrays (draws).
Requires the arrays to follow the convention ``chain, draw, *shape``.
Expand All @@ -559,9 +559,9 @@ def predictions_to_inference_data(
variables must be *removed* from this trace.
model: Model
The pymc model. It can be omitted if within a model context.
coords: Dict[str, array-like[Any]]
coords: dict[str, array-like[Any]]
Coordinates for the variables. Map from coordinate names to coordinate values.
dims: Dict[str, array-like[str]]
dims: dict[str, array-like[str]]
Map from variable name to ordered set of coordinate names.
idata_orig: InferenceData, optional
If supplied, then modify this inference data in place, adding ``predictions`` and
Expand Down
4 changes: 2 additions & 2 deletions pymc/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,11 @@ class MultiTrace:
----------
nchains: int
Number of chains in the `MultiTrace`.
chains: `List[int]`
chains: list[int]
List of chain indices
report: str
Report on the sampling process.
varnames: `List[str]`
varnames: list[str]
List of variable names in the trace(s)
"""

Expand Down
Loading
Loading