Skip to content

Commit

Permalink
Fix some type hinting to help with migrating Distribution (#7484)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasaarholt authored Oct 9, 2024
1 parent 1457626 commit 938aff4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
13 changes: 8 additions & 5 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@
)
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.variable import TensorConstant
from pytensor.tensor.variable import TensorConstant, TensorVariable

from pymc.logprob.abstract import _logprob_helper
from pymc.logprob.basic import icdf
from pymc.logprob.basic import TensorLike, icdf
from pymc.pytensorf import normalize_rng_param

try:
Expand Down Expand Up @@ -148,7 +148,7 @@ class BoundedContinuous(Continuous):
"""Base class for bounded continuous distributions."""

# Indices of the arguments that define the lower and upper bounds of the distribution
bound_args_indices: list[int] | None = None
bound_args_indices: tuple[int | None, int | None] | None = None


@_default_transform.register(PositiveContinuous)
Expand Down Expand Up @@ -210,7 +210,9 @@ def assert_negative_support(var, label, distname, value=-1e-6):
return Assert(msg)(var, pt.all(pt.ge(var, 0.0)))


def get_tau_sigma(tau=None, sigma=None):
def get_tau_sigma(
tau: TensorLike | None = None, sigma: TensorLike | None = None
) -> tuple[TensorVariable, TensorVariable]:
r"""
Find precision and standard deviation.
Expand Down Expand Up @@ -239,13 +241,14 @@ def get_tau_sigma(tau=None, sigma=None):
sigma = pt.as_tensor_variable(1.0)
tau = pt.as_tensor_variable(1.0)
elif tau is None:
assert sigma is not None # Just for type checker
sigma = pt.as_tensor_variable(sigma)
# Keep tau negative, if sigma was negative, so that it will
# fail when used
tau = (sigma**-2.0) * pt.sign(sigma)
else:
tau = pt.as_tensor_variable(tau)
# Keep tau negative, if sigma was negative, so that it will
# Keep sigma negative, if tau was negative, so that it will
# fail when used
sigma = pt.abs(tau) ** -0.5 * pt.sign(tau)

Expand Down
10 changes: 7 additions & 3 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from abc import ABCMeta
from collections.abc import Callable, Sequence
from functools import singledispatch
from typing import TypeAlias
from typing import Any, TypeAlias

import numpy as np

Expand Down Expand Up @@ -423,8 +423,12 @@ def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) ->
class Distribution(metaclass=DistributionMeta):
"""Statistical distribution."""

rv_op: [RandomVariable, SymbolicRandomVariable] = None
rv_type: MetaType = None
# rv_op and _type are set to None via the DistributionMeta.__new__
# if not specified as class attributes in subclasses of Distribution.
# rv_op can either be a class (see the Normal class) or a method
# (see the Censored class), both callable to return a TensorVariable.
rv_op: Any = None
rv_type: MetaType | None = None

def __new__(
cls,
Expand Down

0 comments on commit 938aff4

Please sign in to comment.