Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 committed Jul 17, 2024
1 parent c436953 commit f869040
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _copy_to_script_wrapper(fn):
import pyro.params.param_store
from pyro.ops.provenance import detach_provenance
from pyro.poutine.messenger import Messenger
from pyro.poutine.runtime import _PYRO_PARAM_STORE
from pyro.poutine.runtime import _PYRO_PARAM_STORE, InferDict

_MODULE_LOCAL_PARAMS: bool = False

Expand Down Expand Up @@ -234,6 +234,10 @@ def __get__(
return value


class _PyroSampleInferDict(InferDict):
_original_pyrosample_dist: pyro.distributions.Distribution


class PyroSamplePlateScope(Messenger):
"""
Handler for executing PyroSample statements in a more intuitive plate context.
Expand All @@ -243,7 +247,7 @@ def __init__(self, allowed_plates: Iterable[str] = ()):
self._inner_allowed_plates = frozenset(allowed_plates)

def __enter__(self):
self._plates: frozenset[str] = (
self._plates = (
frozenset(p.name for p in pyro.poutine.runtime.get_plates())
| self._inner_allowed_plates
)
Expand All @@ -255,7 +259,7 @@ def _is_local_plate(self, m: Messenger) -> bool:
and m.name not in self._plates
)

def _pyro_sample(self, msg):
def _pyro_sample(self, msg) -> None:
if not msg["infer"].get("_original_pyrosample_dist", None):
return
msg["stop"] = True
Expand Down Expand Up @@ -658,7 +662,9 @@ def __getattr__(self, name: str) -> Any:
else pyro.sample(
fullname,
prior,
infer={"_original_pyrosample_dist": prior},
infer=_PyroSampleInferDict(
_original_pyrosample_dist=prior
),
)
)
context.set(fullname, value)
Expand Down

0 comments on commit f869040

Please sign in to comment.