Skip to content

Commit

Permalink
move IPW integration test to better test file
Browse files Browse the repository at this point in the history
  • Loading branch information
drbenvincent committed Aug 7, 2024
1 parent cc62438 commit 644cf6b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 53 deletions.
51 changes: 51 additions & 0 deletions causalpy/tests/test_integration_pymc_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import arviz as az
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -597,6 +598,56 @@ def test_iv_reg():
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]


@pytest.mark.integration
def test_inverse_prop():
"""Test the InversePropensityWeighting class."""
df = cp.load_data("nhefs")
sample_kwargs = {
"tune": 100,
"draws": 500,
"chains": 2,
"cores": 2,
"random_seed": 100,
}
result = cp.InversePropensityWeighting(
df,
formula="trt ~ 1 + age + race",
outcome_variable="outcome",
weighting_scheme="robust",
model=cp.pymc_models.PropensityScore(sample_kwargs=sample_kwargs),
)
assert isinstance(result.idata, az.InferenceData)
ps = result.idata.posterior["p"].mean(dim=("chain", "draw"))
w1, w2, _, _ = result.make_doubly_robust_adjustment(ps)
assert isinstance(w1, pd.Series)
assert isinstance(w2, pd.Series)
w1, w2, n1, nw = result.make_raw_adjustments(ps)
assert isinstance(w1, pd.Series)
assert isinstance(w2, pd.Series)
w1, w2, n1, n2 = result.make_robust_adjustments(ps)
assert isinstance(w1, pd.Series)
assert isinstance(w2, pd.Series)
w1, w2, n1, n2 = result.make_overlap_adjustments(ps)
assert isinstance(w1, pd.Series)
assert isinstance(w2, pd.Series)
ate_list = result.get_ate(0, result.idata)
assert isinstance(ate_list, list)
ate_list = result.get_ate(0, result.idata, method="raw")
assert isinstance(ate_list, list)
ate_list = result.get_ate(0, result.idata, method="robust")
assert isinstance(ate_list, list)
ate_list = result.get_ate(0, result.idata, method="overlap")
assert isinstance(ate_list, list)
fig, axs = result.plot_ate(prop_draws=1, ate_draws=10)
assert isinstance(fig, plt.Figure)
assert isinstance(axs, list)
assert all(isinstance(ax, plt.Axes) for ax in axs)
fig, axs = result.plot_balance_ecdf("age")
assert isinstance(fig, plt.Figure)
assert isinstance(axs, list)
assert all(isinstance(ax, plt.Axes) for ax in axs)


# DEPRECATION WARNING TESTS ============================================================


Expand Down
53 changes: 0 additions & 53 deletions causalpy/tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@
Miscellaneous unit tests
"""

import arviz as az
import pandas as pd
from matplotlib import pyplot as plt

import causalpy as cp

sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}
Expand All @@ -41,52 +37,3 @@ def test_regression_kink_gradient_change():
assert cp.RegressionKink._eval_gradient_change(0, 0, -2, 1) == -2.0
assert cp.RegressionKink._eval_gradient_change(-1, -1, -2, 1) == -1.0
assert cp.RegressionKink._eval_gradient_change(1, 0, -2, 1) == -1.0


def test_inverse_prop():
"""Test the InversePropensityWeighting class."""
df = cp.load_data("nhefs")
sample_kwargs = {
"tune": 100,
"draws": 500,
"chains": 2,
"cores": 2,
"random_seed": 100,
}
result = cp.InversePropensityWeighting(
df,
formula="trt ~ 1 + age + race",
outcome_variable="outcome",
weighting_scheme="robust",
model=cp.pymc_models.PropensityScore(sample_kwargs=sample_kwargs),
)
assert isinstance(result.idata, az.InferenceData)
ps = result.idata.posterior["p"].mean(dim=("chain", "draw"))
w1, w2, _, _ = result.make_doubly_robust_adjustment(ps)
assert isinstance(w1, pd.Series)
assert isinstance(w2, pd.Series)
w1, w2, n1, nw = result.make_raw_adjustments(ps)
assert isinstance(w1, pd.Series)
assert isinstance(w2, pd.Series)
w1, w2, n1, n2 = result.make_robust_adjustments(ps)
assert isinstance(w1, pd.Series)
assert isinstance(w2, pd.Series)
w1, w2, n1, n2 = result.make_overlap_adjustments(ps)
assert isinstance(w1, pd.Series)
assert isinstance(w2, pd.Series)
ate_list = result.get_ate(0, result.idata)
assert isinstance(ate_list, list)
ate_list = result.get_ate(0, result.idata, method="raw")
assert isinstance(ate_list, list)
ate_list = result.get_ate(0, result.idata, method="robust")
assert isinstance(ate_list, list)
ate_list = result.get_ate(0, result.idata, method="overlap")
assert isinstance(ate_list, list)
fig, axs = result.plot_ate(prop_draws=1, ate_draws=10)
assert isinstance(fig, plt.Figure)
assert isinstance(axs, list)
assert all(isinstance(ax, plt.Axes) for ax in axs)
fig, axs = result.plot_balance_ecdf("age")
assert isinstance(fig, plt.Figure)
assert isinstance(axs, list)
assert all(isinstance(ax, plt.Axes) for ax in axs)

0 comments on commit 644cf6b

Please sign in to comment.