Skip to content

Commit

Permalink
Make KroneckerNormal a SymbolicRV with a valid signature
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 21, 2024
1 parent 94f6264 commit 0d77a07
Showing 1 changed file with 23 additions and 31 deletions.
54 changes: 23 additions & 31 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1880,35 +1880,30 @@ def logp(value, mu, rowchol, colchol):
return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet


class KroneckerNormalRV(RandomVariable):
name = "kroneckernormal"
class KroneckerNormalRV(SymbolicRandomVariable):
ndim_supp = 1
ndims_params = [1, 0, 2]
dtype = "floatX"
_print_name = ("KroneckerNormal", "\\operatorname{KroneckerNormal}")

def _supp_shape_from_params(self, dist_params, param_shapes=None):
return supp_shape_from_ref_param_shape(
ndim_supp=self.ndim_supp,
dist_params=dist_params,
param_shapes=param_shapes,
ref_param_idx=0,
)

def rng_fn(self, rng, mu, sigma, *covs, size=None):
size = size if size else covs[-1]
covs = covs[:-1] if covs[-1] == size else covs

cov = reduce(scipy.linalg.kron, covs)

if sigma:
cov = cov + sigma**2 * np.eye(cov.shape[0])
@classmethod
def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
mu = pt.as_tensor(mu)
sigma = pt.as_tensor(sigma)
covs = [pt.as_tensor(cov) for cov in covs]
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

x = multivariate_normal.rng_fn(rng=rng, mean=mu, cov=cov, size=size)
return x
cov = reduce(pt.linalg.kron, covs)
cov = cov + sigma**2 * pt.eye(cov.shape[-2])
next_rng, draws = multivariate_normal(mean=mu, cov=cov, size=size, rng=rng).owner.outputs

covs_sig = ",".join(f"(a{i},b{i})" for i in range(len(covs)))
signature = f"[rng],[size],(m),(),{covs_sig}->[rng],(m)"

kroneckernormal = KroneckerNormalRV()
return KroneckerNormalRV(
inputs=[rng, size, mu, sigma, *covs],
outputs=[next_rng, draws],
signature=signature,
)(rng, size, mu, sigma, *covs)


class KroneckerNormal(Continuous):
Expand Down Expand Up @@ -1999,7 +1994,8 @@ class KroneckerNormal(Continuous):
.. [1] Saatchi, Y. (2011). "Scalable inference for structured Gaussian process models"
"""

rv_op = kroneckernormal
rv_type = KroneckerNormalRV
rv_op = KroneckerNormalRV.rv_op

@classmethod
def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs):
Expand All @@ -2024,14 +2020,10 @@ def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs)

return super().dist([mu, sigma, *covs], **kwargs)

def support_point(rv, size, mu, covs, chols, evds):
mean = mu
if not rv_size_is_none(size):
support_point_size = pt.concatenate([size, mu.shape])
mean = pt.full(support_point_size, mu)
return mean
def support_point(rv, rng, size, mu, sigma, *covs):
return pt.full_like(rv, mu)

def logp(value, mu, sigma, *covs):
def logp(value, rng, size, mu, sigma, *covs):
"""
Calculate log-probability of Multivariate Normal distribution
with Kronecker-structured covariance at specified value.
Expand Down

0 comments on commit 0d77a07

Please sign in to comment.