Skip to content

Commit

Permalink
Maintenance fixes (#3398)
Browse files Browse the repository at this point in the history
* TorchFix issues and attrgetter

* parameter name fix
  • Loading branch information
ordabayevy authored Sep 19, 2024
1 parent e914e19 commit 88ae262
Show file tree
Hide file tree
Showing 14 changed files with 55 additions and 83 deletions.
2 changes: 1 addition & 1 deletion pyro/distributions/omt_mvn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def backward(ctx, grad_output):
diff_L_ab = 0.5 * sum_leftmost(g_ja * epsilon_jb + g_R_inv * z_ja, -2)

Sigma_inv = torch.mm(R_inv, R_inv.t())
V, D, _ = torch.svd(Sigma_inv + jitter)
V, D, _ = torch.linalg.svd(Sigma_inv + jitter)
D_outer = D.unsqueeze(-1) + D.unsqueeze(0)

expand_tuple = tuple([-1] * (z.dim() - 1) + [dim, dim])
Expand Down
2 changes: 1 addition & 1 deletion pyro/distributions/transforms/householder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, u_unnormed=None):
# Construct normalized vectors for Householder transform
def u(self):
u_unnormed = self.u_unnormed() if callable(self.u_unnormed) else self.u_unnormed
norm = torch.norm(u_unnormed, p=2, dim=-1, keepdim=True)
norm = torch.linalg.norm(u_unnormed, ord=2, dim=-1, keepdim=True)
return torch.div(u_unnormed, norm)

def _call(self, x):
Expand Down
4 changes: 2 additions & 2 deletions pyro/distributions/transforms/sylvester.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ def Q(self, x):
u = self.u()
partial_Q = torch.eye(
self.input_dim, dtype=x.dtype, layout=x.layout, device=x.device
) - 2.0 * torch.ger(u[0], u[0])
) - 2.0 * torch.outer(u[0], u[0])

for idx in range(1, self.u_unnormed.size(-2)):
partial_Q = torch.matmul(
partial_Q, torch.eye(self.input_dim) - 2.0 * torch.ger(u[idx], u[idx])
partial_Q, torch.eye(self.input_dim) - 2.0 * torch.outer(u[idx], u[idx])
)

return partial_Q
Expand Down
17 changes: 9 additions & 8 deletions pyro/infer/autoguide/effect.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from operator import attrgetter
from typing import Callable, Optional, Tuple, Union

import torch
Expand All @@ -14,7 +15,7 @@
from pyro.poutine.runtime import get_plates

from .initialization import init_to_feasible, init_to_mean
from .utils import deep_getattr, deep_setattr, helpful_support_errors
from .utils import deep_setattr, helpful_support_errors


class AutoMessengerMeta(type(GuideMessenger), type(PyroModule)):
Expand Down Expand Up @@ -175,8 +176,8 @@ def get_posterior(

def _get_params(self, name: str, prior: Distribution):
try:
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
loc = attrgetter(name)(self.locs)
scale = attrgetter(name)(self.scales)
return loc, scale
except AttributeError:
pass
Expand Down Expand Up @@ -287,10 +288,10 @@ def get_posterior(

def _get_params(self, name: str, prior: Distribution):
try:
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
loc = attrgetter(name)(self.locs)
scale = attrgetter(name)(self.scales)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
weight = deep_getattr(self.weights, name)
weight = attrgetter(name)(self.weights)
return loc, scale, weight
else:
return loc, scale
Expand Down Expand Up @@ -427,8 +428,8 @@ def get_posterior(

def _get_params(self, name: str, prior: Distribution):
try:
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
loc = attrgetter(name)(self.locs)
scale = attrgetter(name)(self.scales)
return loc, scale
except AttributeError:
pass
Expand Down
17 changes: 9 additions & 8 deletions pyro/infer/autoguide/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import ABCMeta, abstractmethod
from collections import OrderedDict, defaultdict
from contextlib import ExitStack
from operator import attrgetter
from types import SimpleNamespace
from typing import Callable, Dict, Optional, Set, Tuple, Union

Expand All @@ -23,7 +24,7 @@

from .guides import AutoGuide
from .initialization import InitMessenger, init_to_feasible
from .utils import deep_getattr, deep_setattr, helpful_support_errors
from .utils import deep_setattr, helpful_support_errors


# Helper to dispatch to concrete subclasses of AutoGaussian, e.g.
Expand Down Expand Up @@ -287,8 +288,8 @@ def _transform_values(
for name, site in self._factors.items():
if site["is_observed"]:
continue
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
loc = attrgetter(name)(self.locs)
scale = attrgetter(name)(self.scales)
unconstrained = aux_values[name] * scale + loc

# Transform to constrained space.
Expand Down Expand Up @@ -335,7 +336,7 @@ def _setup_prototype(self, *args, **kwargs):
# Create sparse -> dense precision scatter indices.
self._dense_scatter = {}
for d, site in self._factors.items():
prec_sqrt_shape = deep_getattr(self.prec_sqrts, d).shape
prec_sqrt_shape = attrgetter(d)(self.prec_sqrts).shape
info_vec_shape = prec_sqrt_shape[:-1]
precision_shape = prec_sqrt_shape[:-1] + prec_sqrt_shape[-2:-1]
index1 = torch.zeros(info_vec_shape, dtype=torch.long)
Expand Down Expand Up @@ -425,8 +426,8 @@ def _dense_get_mvn(self):
flat_info_vec = torch.zeros(self._dense_size)
flat_precision = torch.zeros(self._dense_size**2)
for d, (index1, index2) in self._dense_scatter.items():
white_vec = deep_getattr(self.white_vecs, d)
prec_sqrt = deep_getattr(self.prec_sqrts, d)
white_vec = attrgetter(d)(self.white_vecs)
prec_sqrt = attrgetter(d)(self.prec_sqrts)
info_vec = (prec_sqrt @ white_vec[..., None])[..., 0]
precision = prec_sqrt @ prec_sqrt.transpose(-1, -2)
flat_info_vec.scatter_add_(0, index1, info_vec.reshape(-1))
Expand Down Expand Up @@ -505,8 +506,8 @@ def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch.Tensor]:
batch_shape = torch.Size(
p.size for p in sorted(self._plates[d], key=lambda p: p.dim)
)
white_vec = deep_getattr(self.white_vecs, d)
prec_sqrt = deep_getattr(self.prec_sqrts, d)
white_vec = attrgetter(d)(self.white_vecs)
prec_sqrt = attrgetter(d)(self.prec_sqrts)
factors[d] = funsor.gaussian.Gaussian(
white_vec=white_vec.reshape(batch_shape + white_vec.shape[-1:]),
prec_sqrt=prec_sqrt.reshape(batch_shape + prec_sqrt.shape[-2:]),
Expand Down
7 changes: 4 additions & 3 deletions pyro/infer/autoguide/guides.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def model():
import warnings
import weakref
from contextlib import ExitStack
from operator import attrgetter

import torch
from torch import nn
Expand All @@ -38,7 +39,7 @@ def model():
from pyro.poutine.util import site_is_subsample

from .initialization import InitMessenger, init_to_feasible, init_to_median
from .utils import _product, deep_getattr, deep_setattr, helpful_support_errors
from .utils import _product, deep_setattr, helpful_support_errors


def prototype_hide_fn(msg):
Expand Down Expand Up @@ -491,8 +492,8 @@ def _setup_prototype(self, *args, **kwargs):
)

def _get_loc_and_scale(self, name):
site_loc = deep_getattr(self.locs, name)
site_scale = deep_getattr(self.scales, name)
site_loc = attrgetter(name)(self.locs)
site_scale = attrgetter(name)(self.scales)
return site_loc, site_scale

def forward(self, *args, **kwargs):
Expand Down
19 changes: 10 additions & 9 deletions pyro/infer/autoguide/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from collections import OrderedDict, defaultdict
from contextlib import ExitStack
from operator import attrgetter
from types import SimpleNamespace
from typing import Callable, Dict, Optional, Union

Expand All @@ -19,7 +20,7 @@

from .guides import AutoGuide
from .initialization import InitMessenger, init_to_feasible
from .utils import deep_getattr, deep_setattr, helpful_support_errors
from .utils import deep_setattr, helpful_support_errors


def _config_auxiliary(msg):
Expand Down Expand Up @@ -274,11 +275,11 @@ def get_deltas(self, save_params=None):

# Sample zero-mean blockwise independent Delta/Normal/MVN.
log_density = 0.0
loc = deep_getattr(self.locs, name)
loc = attrgetter(name)(self.locs)
zero = torch.zeros_like(loc)
conditional = self.conditionals[name]
if callable(conditional):
aux_value = deep_getattr(self.conds, name)()
aux_value = attrgetter(name)(self.conds)()
elif conditional == "delta":
aux_value = zero
elif conditional == "normal":
Expand All @@ -287,7 +288,7 @@ def get_deltas(self, save_params=None):
dist.Normal(zero, 1).to_event(1),
infer={"is_auxiliary": True},
)
scale = deep_getattr(self.scales, name)
scale = attrgetter(name)(self.scales)
aux_value = aux_value * scale
if compute_density:
log_density = (-scale.log()).expand_as(aux_value)
Expand All @@ -299,8 +300,8 @@ def get_deltas(self, save_params=None):
dist.Normal(zero, 1).to_event(1),
infer={"is_auxiliary": True},
)
scale = deep_getattr(self.scales, name)
scale_tril = deep_getattr(self.scale_trils, name)
scale = attrgetter(name)(self.scales)
scale_tril = attrgetter(name)(self.scale_trils)
aux_value = aux_value @ scale_tril.T * scale
if compute_density:
log_density = (
Expand All @@ -318,9 +319,9 @@ def get_deltas(self, save_params=None):
# Note: these shear transforms have no effect on the Jacobian
# determinant, and can therefore be excluded from the log_density
# computation below, even for nonlinear dep().
deps = deep_getattr(self.deps, name)
deps = attrgetter(name)(self.deps)
for upstream in self.dependencies.get(name, {}):
dep = deep_getattr(deps, upstream)
dep = attrgetter(upstream)(deps)
aux_value = aux_value + dep(aux_values[upstream])
aux_values[name] = aux_value

Expand Down Expand Up @@ -368,7 +369,7 @@ def forward(self, *args, **kwargs):
def median(self, *args, **kwargs):
result = {}
for name, site in self._sorted_sites:
loc = deep_getattr(self.locs, name).detach()
loc = attrgetter(name)(self.locs).detach()
shape = self._batch_shapes[name] + self._unconstrained_event_shapes[name]
loc = loc.reshape(shape)
result[name] = biject_to(site["fn"].support)(loc)
Expand Down
6 changes: 0 additions & 6 deletions pyro/infer/autoguide/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@ def _product(shape):
return result


def deep_getattr(obj, key):
for part in key.split("."):
obj = getattr(obj, part)
return obj


def deep_setattr(obj, key, val):
"""
Set an attribute `key` on the object. If any of the prefix attributes do
Expand Down
4 changes: 2 additions & 2 deletions pyro/ops/welford.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def update(self, sample):
if self.diagonal:
self._m2 += delta_pre * delta_post
else:
self._m2 += torch.ger(delta_post, delta_pre)
self._m2 += torch.outer(delta_post, delta_pre)

def get_covariance(self, regularize=True):
if self.n_samples < 2:
Expand Down Expand Up @@ -72,7 +72,7 @@ def update(self, sample):
self._mean = self._mean + delta_pre / self.n_samples
delta_post = sample - self._mean
if self.head_size > 0:
self._m2_top = self._m2_top + torch.ger(
self._m2_top = self._m2_top + torch.outer(
delta_post[: self.head_size], delta_pre
)
else:
Expand Down
5 changes: 3 additions & 2 deletions pyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import OrderedDict
from contextlib import ExitStack, contextmanager
from inspect import isclass
from operator import attrgetter
from typing import Callable, Iterator, Optional, Sequence, Union

import torch
Expand All @@ -28,7 +29,7 @@
effectful,
)
from pyro.poutine.subsample_messenger import SubsampleMessenger
from pyro.util import deep_getattr, set_rng_seed # noqa: F401
from pyro.util import set_rng_seed # noqa: F401


def get_param_store() -> ParamStoreDict:
Expand Down Expand Up @@ -493,7 +494,7 @@ def module(
mod_name = _name
if _name in target_state_dict.keys():
if not is_param:
deep_getattr(nn_module, mod_name)._parameters[param_name] = (
attrgetter(mod_name)(nn_module)._parameters[param_name] = (
target_state_dict[_name]
)
else:
Expand Down
9 changes: 0 additions & 9 deletions pyro/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import functools
import math
import numbers
import random
Expand Down Expand Up @@ -704,14 +703,6 @@ def ignore_experimental_warning():
yield


def deep_getattr(obj: object, name: str) -> Any:
"""
Python getattr() for arbitrarily deep attributes
Throws an AttributeError if bad attribute
"""
return functools.reduce(getattr, name.split("."), obj)


class timed:
def __enter__(self, timer=timeit.default_timer):
self.start = timer()
Expand Down
16 changes: 12 additions & 4 deletions tests/infer/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,13 @@ def test_importance_guide(self):
self.model, guide=self.guide, num_samples=5000
).run()
marginal = EmpiricalMarginal(posterior)
assert_equal(0, torch.norm(marginal.mean - self.loc_mean).item(), prec=0.01)
assert_equal(
0, torch.norm(marginal.variance.sqrt() - self.loc_stddev).item(), prec=0.1
0, torch.linalg.norm(marginal.mean - self.loc_mean).item(), prec=0.01
)
assert_equal(
0,
torch.linalg.norm(marginal.variance.sqrt() - self.loc_stddev).item(),
prec=0.1,
)

@pytest.mark.init(rng_seed=0)
Expand All @@ -89,7 +93,11 @@ def test_importance_prior(self):
self.model, guide=None, num_samples=10000
).run()
marginal = EmpiricalMarginal(posterior)
assert_equal(0, torch.norm(marginal.mean - self.loc_mean).item(), prec=0.01)
assert_equal(
0, torch.norm(marginal.variance.sqrt() - self.loc_stddev).item(), prec=0.1
0, torch.linalg.norm(marginal.mean - self.loc_mean).item(), prec=0.01
)
assert_equal(
0,
torch.linalg.norm(marginal.variance.sqrt() - self.loc_stddev).item(),
prec=0.1,
)
28 changes: 1 addition & 27 deletions tests/ops/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from pyro.ops.linalg import rinverse
from tests.common import assert_close, assert_equal
from tests.common import assert_equal


@pytest.mark.parametrize(
Expand Down Expand Up @@ -35,29 +35,3 @@ def test_sym_rinverse(A, use_sym):
batched_A = A.unsqueeze(0).unsqueeze(0).expand(5, 4, d, d)
expected_A = torch.inverse(A).unsqueeze(0).unsqueeze(0).expand(5, 4, d, d)
assert_equal(rinverse(batched_A, sym=use_sym), expected_A, prec=1e-8)


# Tests migration from torch.triangular_solve -> torch.linalg.solve_triangular
@pytest.mark.filterwarnings("ignore:torch.triangular_solve is deprecated")
@pytest.mark.parametrize("upper", [False, True], ids=["lower", "upper"])
def test_triangular_solve(upper):
b = torch.randn(5, 6)
A = torch.randn(5, 5)
expected = torch.triangular_solve(b, A, upper=upper).solution
actual = torch.linalg.solve_triangular(A, b, upper=upper)
assert_close(actual, expected)
A = A.triu() if upper else A.tril()
assert_close(A @ actual, b)


# Tests migration from torch.triangular_solve -> torch.linalg.solve_triangular
@pytest.mark.filterwarnings("ignore:torch.triangular_solve is deprecated")
@pytest.mark.parametrize("upper", [False, True], ids=["lower", "upper"])
def test_triangular_solve_transpose(upper):
b = torch.randn(5, 6)
A = torch.randn(5, 5)
expected = torch.triangular_solve(b, A, upper=upper, transpose=True).solution
actual = torch.linalg.solve_triangular(A.T, b, upper=not upper)
assert_close(actual, expected)
A = A.triu() if upper else A.tril()
assert_close(A.T @ actual, b)
Loading

0 comments on commit 88ae262

Please sign in to comment.