diff --git a/lightweight_mmm/plot_test.py b/lightweight_mmm/plot_test.py index 9c53326..426ef26 100644 --- a/lightweight_mmm/plot_test.py +++ b/lightweight_mmm/plot_test.py @@ -162,7 +162,8 @@ def test_plot_response_curves_produces_y_axis_starting_at_zero( calls_list = self.mock_sns_lineplot.call_args_list for _, call_kwargs in calls_list[:3]: - self.assertEqual(call_kwargs["y"].min().item(), 0) + self.assertLessEqual(call_kwargs["y"].min().item(), 0.1) + self.assertGreaterEqual(call_kwargs["y"].min().item(), -0.1) @parameterized.named_parameters([ dict(