diff --git a/causalpy/tests/test_integration_pymc_examples.py b/causalpy/tests/test_integration_pymc_examples.py index 013e722e..6f36764f 100644 --- a/causalpy/tests/test_integration_pymc_examples.py +++ b/causalpy/tests/test_integration_pymc_examples.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import arviz as az import numpy as np import pandas as pd import pytest @@ -597,6 +598,56 @@ def test_iv_reg(): assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"] +@pytest.mark.integration +def test_inverse_prop(): + """Test the InversePropensityWeighting class.""" + df = cp.load_data("nhefs") + sample_kwargs = { + "tune": 100, + "draws": 500, + "chains": 2, + "cores": 2, + "random_seed": 100, + } + result = cp.InversePropensityWeighting( + df, + formula="trt ~ 1 + age + race", + outcome_variable="outcome", + weighting_scheme="robust", + model=cp.pymc_models.PropensityScore(sample_kwargs=sample_kwargs), + ) + assert isinstance(result.idata, az.InferenceData) + ps = result.idata.posterior["p"].mean(dim=("chain", "draw")) + w1, w2, _, _ = result.make_doubly_robust_adjustment(ps) + assert isinstance(w1, pd.Series) + assert isinstance(w2, pd.Series) + w1, w2, n1, nw = result.make_raw_adjustments(ps) + assert isinstance(w1, pd.Series) + assert isinstance(w2, pd.Series) + w1, w2, n1, n2 = result.make_robust_adjustments(ps) + assert isinstance(w1, pd.Series) + assert isinstance(w2, pd.Series) + w1, w2, n1, n2 = result.make_overlap_adjustments(ps) + assert isinstance(w1, pd.Series) + assert isinstance(w2, pd.Series) + ate_list = result.get_ate(0, result.idata) + assert isinstance(ate_list, list) + ate_list = result.get_ate(0, result.idata, method="raw") + assert isinstance(ate_list, list) + ate_list = result.get_ate(0, result.idata, method="robust") + assert isinstance(ate_list, list) + ate_list = result.get_ate(0, result.idata, method="overlap") + assert isinstance(ate_list, list) + fig, axs = result.plot_ate(prop_draws=1, ate_draws=10) + assert isinstance(fig, plt.Figure) + assert isinstance(axs, list) + assert all(isinstance(ax, plt.Axes) for ax in axs) + fig, axs = result.plot_balance_ecdf("age") + assert isinstance(fig, plt.Figure) + assert isinstance(axs, list) + assert all(isinstance(ax, plt.Axes) for ax in axs) + + # DEPRECATION WARNING TESTS ============================================================ diff --git a/causalpy/tests/test_misc.py b/causalpy/tests/test_misc.py index 01c34449..9a1e3b64 100644 --- a/causalpy/tests/test_misc.py +++ b/causalpy/tests/test_misc.py @@ -15,10 +15,6 @@ Miscellaneous unit tests """ -import arviz as az -import pandas as pd -from matplotlib import pyplot as plt - import causalpy as cp sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2} @@ -41,52 +37,3 @@ def test_regression_kink_gradient_change(): assert cp.RegressionKink._eval_gradient_change(0, 0, -2, 1) == -2.0 assert cp.RegressionKink._eval_gradient_change(-1, -1, -2, 1) == -1.0 assert cp.RegressionKink._eval_gradient_change(1, 0, -2, 1) == -1.0 - - -def test_inverse_prop(): - """Test the InversePropensityWeighting class.""" - df = cp.load_data("nhefs") - sample_kwargs = { - "tune": 100, - "draws": 500, - "chains": 2, - "cores": 2, - "random_seed": 100, - } - result = cp.InversePropensityWeighting( - df, - formula="trt ~ 1 + age + race", - outcome_variable="outcome", - weighting_scheme="robust", - model=cp.pymc_models.PropensityScore(sample_kwargs=sample_kwargs), - ) - assert isinstance(result.idata, az.InferenceData) - ps = result.idata.posterior["p"].mean(dim=("chain", "draw")) - w1, w2, _, _ = result.make_doubly_robust_adjustment(ps) - assert isinstance(w1, pd.Series) - assert isinstance(w2, pd.Series) - w1, w2, n1, nw = result.make_raw_adjustments(ps) - assert isinstance(w1, pd.Series) - assert isinstance(w2, pd.Series) - w1, w2, n1, n2 = result.make_robust_adjustments(ps) - assert isinstance(w1, pd.Series) - assert isinstance(w2, pd.Series) - w1, w2, n1, n2 = result.make_overlap_adjustments(ps) - assert isinstance(w1, pd.Series) - assert isinstance(w2, pd.Series) - ate_list = result.get_ate(0, result.idata) - assert isinstance(ate_list, list) - ate_list = result.get_ate(0, result.idata, method="raw") - assert isinstance(ate_list, list) - ate_list = result.get_ate(0, result.idata, method="robust") - assert isinstance(ate_list, list) - ate_list = result.get_ate(0, result.idata, method="overlap") - assert isinstance(ate_list, list) - fig, axs = result.plot_ate(prop_draws=1, ate_draws=10) - assert isinstance(fig, plt.Figure) - assert isinstance(axs, list) - assert all(isinstance(ax, plt.Axes) for ax in axs) - fig, axs = result.plot_balance_ecdf("age") - assert isinstance(fig, plt.Figure) - assert isinstance(axs, list) - assert all(isinstance(ax, plt.Axes) for ax in axs)