From f86904091daa0fff0fd445d969605e0026cfe98a Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 17 Jul 2024 15:08:24 -0400 Subject: [PATCH] lint --- pyro/nn/module.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pyro/nn/module.py b/pyro/nn/module.py index 67810c95e2..2c691cc75c 100644 --- a/pyro/nn/module.py +++ b/pyro/nn/module.py @@ -56,7 +56,7 @@ def _copy_to_script_wrapper(fn): import pyro.params.param_store from pyro.ops.provenance import detach_provenance from pyro.poutine.messenger import Messenger -from pyro.poutine.runtime import _PYRO_PARAM_STORE +from pyro.poutine.runtime import _PYRO_PARAM_STORE, InferDict _MODULE_LOCAL_PARAMS: bool = False @@ -234,6 +234,10 @@ def __get__( return value +class _PyroSampleInferDict(InferDict): + _original_pyrosample_dist: pyro.distributions.Distribution + + class PyroSamplePlateScope(Messenger): """ Handler for executing PyroSample statements in a more intuitive plate context. @@ -243,7 +247,7 @@ def __init__(self, allowed_plates: Iterable[str] = ()): self._inner_allowed_plates = frozenset(allowed_plates) def __enter__(self): - self._plates: frozenset[str] = ( + self._plates = ( frozenset(p.name for p in pyro.poutine.runtime.get_plates()) | self._inner_allowed_plates ) @@ -255,7 +259,7 @@ def _is_local_plate(self, m: Messenger) -> bool: and m.name not in self._plates ) - def _pyro_sample(self, msg): + def _pyro_sample(self, msg) -> None: if not msg["infer"].get("_original_pyrosample_dist", None): return msg["stop"] = True @@ -658,7 +662,9 @@ def __getattr__(self, name: str) -> Any: else pyro.sample( fullname, prior, - infer={"_original_pyrosample_dist": prior}, + infer=_PyroSampleInferDict( + _original_pyrosample_dist=prior + ), ) ) context.set(fullname, value)