Skip to content

Commit

Permalink
add test case for dist with inner graph
Browse files Browse the repository at this point in the history
  • Loading branch information
aerubanov committed Sep 11, 2023
1 parent bb5fd48 commit 29f485b
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,26 @@ def test_custom_dist_default_moment(self, dist_params, size, expected, dist_fn):
CustomDist("x", *dist_params, dist=dist_fn, size=size)
assert_moment_is_expected(model, expected)

def test_custom_dist_custom_moment_inner_graph(self):
def scan_step(mu):
x = pm.Normal.dist(mu, 1)
x_update = collect_default_updates([x])
return x, x_update

def dist(mu, size):
# size = size.reshape(mu.shape)
ys, _ = pytensor.scan(
fn=scan_step,
sequences=[mu],
outputs_info=[None],
name="ys",
)
return pt.sum(ys)

with Model() as model:
CustomDist("x", pt.ones(2), dist=dist)
assert_moment_is_expected(model, 2)

def test_logcdf_inference(self):
def custom_dist(mu, sigma, size):
return pt.exp(pm.Normal.dist(mu, sigma, size=size))
Expand Down

0 comments on commit 29f485b

Please sign in to comment.