diff --git a/fglib/tests/test_inference.py b/fglib/tests/test_inference.py index 7d38f97..f817603 100644 --- a/fglib/tests/test_inference.py +++ b/fglib/tests/test_inference.py @@ -134,5 +134,55 @@ def test_msa(self): pass +class TestExample(unittest.TestCase): + + def test_readme(self): + # Create factor graph + fg = graphs.FactorGraph() + + # Create variable nodes + x1 = nodes.VNode("x1", rv.Discrete) # with 2 states (Bernoulli) + x2 = nodes.VNode("x2", rv.Discrete) # with 3 states + x3 = nodes.VNode("x3", rv.Discrete) + x4 = nodes.VNode("x4", rv.Discrete) + + # Create factor nodes (with joint distributions) + dist_fa = [[0.3, 0.2, 0.1], + [0.3, 0.0, 0.1]] + fa = nodes.FNode("fa", rv.Discrete(dist_fa, x1, x2)) + + dist_fb = [[0.3, 0.2], + [0.3, 0.0], + [0.1, 0.1]] + fb = nodes.FNode("fb", rv.Discrete(dist_fb, x2, x3)) + + dist_fc = [[0.3, 0.2], + [0.3, 0.0], + [0.1, 0.1]] + fc = nodes.FNode("fc", rv.Discrete(dist_fc, x2, x4)) + + # Add nodes to factor graph + fg.set_nodes([x1, x2, x3, x4]) + fg.set_nodes([fa, fb, fc]) + + # Add edges to factor graph + fg.set_edge(x1, fa) + fg.set_edge(fa, x2) + fg.set_edge(x2, fb) + fg.set_edge(fb, x3) + fg.set_edge(x2, fc) + fg.set_edge(fc, x4) + + # Perform sum-product algorithm on factor graph + # and request belief of variable node x4 + belief = inference.sum_product(fg, x4) + + # Print belief of variables + # print("Belief of variable node x4:") + # print(belief) + + npt.assert_almost_equal(belief.pmf, np.array([0.63, 0.36]), decimal=2) + + if __name__ == "__main__": unittest.main()