From 94aeed1378aa4fef5bf9e483513637b57711a49a Mon Sep 17 00:00:00 2001 From: Manuel Martinez Date: Sat, 27 Apr 2024 20:05:17 -0700 Subject: [PATCH] fix: resolve failing tests with pandas>=2.2 Fixes failing tests that were due to a change in [pandas>=2.2](https://pandas.pydata.org/docs/reference/api/pandas.core.groupby.DataFrameGroupBy.apply.html) related to using `.apply()` on grouped-by objects. Fixes a bunch of warnings related to dubious type mixing within columns. It also finally fixes the import issue due to using the same name everywhere and adds a unit test to verify that the import is not broken again. It also deprecates Python 3.8 since it's almost EOL. --- .github/workflows/release.yml | 2 +- pyproject.toml | 40 +++++++++++----------------------- src/stochatreat/__init__.py | 3 +++ src/stochatreat/stochatreat.py | 36 ++++++++++++++---------------- tests/test_import.py | 4 ++++ tests/test_io.py | 4 +++- 6 files changed, 40 insertions(+), 49 deletions(-) create mode 100644 tests/test_import.py diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f0f3446..2711864 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v3 diff --git a/pyproject.toml b/pyproject.toml index 1c350d7..e05bc6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,24 +40,19 @@ Source = "https://github.com/manmartgarc/stochatreat/" path = "src/stochatreat/__about__.py" [tool.hatch.envs.default] -dependencies = [ - "pytest", - "coverage[toml]>=6.5", -] +installer = "uv" +dependencies = ["pytest", "pytest-cov", "pytest-xdist", "coverage[toml]>=6.5"] [tool.hatch.envs.default.scripts] -test = "pytest {args:tests}" -test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - 
"test-cov", - "cov-report" -] +test = "pytest -n auto {args:tests}" +test-cov = "pytest -n auto --cov {args:tests}" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] + +[tool.hatch.envs.all] +installer = "uv" [[tool.hatch.envs.all.matrix]] -python = ["3.8", "3.9", "3.10", "3.11", "3.12"] +python = ["3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.types] dependencies = ["mypy>=1.0.0"] @@ -68,32 +63,23 @@ check = "mypy --install-types --non-interactive {args:src/stochatreat tests}" source_pkgs = ["stochatreat", "tests"] branch = true parallel = true -omit = [ - "src/stochatreat/__about__.py", -] +omit = ["src/stochatreat/__about__.py"] [tool.coverage.paths] stochatreat = ["src/stochatreat", "*/stochatreat/src/stochatreat"] tests = ["tests", "*/stochatreat/tests"] [tool.coverage.report] -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHEKCING" -] +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHEKCING"] [tool.ruff] line-length = 79 [tool.ruff.lint] -ignore = [ - "TRY003" -] +ignore = ["TRY003"] [tool.mypy] check_untyped_defs = true show_error_codes = true pretty = true ignore_missing_imports = true - diff --git a/src/stochatreat/__init__.py b/src/stochatreat/__init__.py index e69de29..fa5d5bc 100644 --- a/src/stochatreat/__init__.py +++ b/src/stochatreat/__init__.py @@ -0,0 +1,3 @@ +from stochatreat.stochatreat import stochatreat + +__all__ = ["stochatreat"] diff --git a/src/stochatreat/stochatreat.py b/src/stochatreat/stochatreat.py index d1066c3..58949dc 100644 --- a/src/stochatreat/stochatreat.py +++ b/src/stochatreat/stochatreat.py @@ -1,6 +1,3 @@ -""" - -""" from __future__ import annotations from typing import Literal @@ -160,9 +157,10 @@ def stochatreat( data = data.groupby("stratum_id").apply( lambda x: x.sample( n=reduced_sizes[x.name], random_state=random_state - ) + ), + include_groups=False, ) - + data["stratum_id"] = data.index.get_level_values(0) data = 
data.droplevel(level="stratum_id") # Treatment assignment proceeds in two stages within each stratum: @@ -189,21 +187,20 @@ def stochatreat( if misfit_strategy == "global": # separate the global misfits - misfit_data = ( - data.groupby("stratum_id") - .apply( - lambda x: x.sample( - n=(x.shape[0] % lcm_prob_denominators), - replace=False, - random_state=random_state, - ) - ) - .droplevel(level="stratum_id") + misfit_data = data.groupby("stratum_id").apply( + lambda x: x.sample( + n=(x.shape[0] % lcm_prob_denominators), + replace=False, + random_state=random_state, + ), + include_groups=False, ) + misfit_data["stratum_id"] = misfit_data.index.get_level_values(0) + misfit_data = misfit_data.droplevel(level="stratum_id") good_form_data = data.drop(index=misfit_data.index) # assign the misfits their own stratum and concatenate - misfit_data.loc[:, "stratum_id"] = np.Inf + misfit_data.loc[:, "stratum_id"] = -1 data = pd.concat([good_form_data, misfit_data]) # ========================================================================= @@ -225,9 +222,8 @@ def stochatreat( fake_rep = pd.DataFrame( fake.values.repeat(fake["fake"], axis=0), columns=fake.columns ) - - data.loc[:, "fake"] = False - fake_rep.loc[:, "fake"] = True + data.loc[:, "fake"] = 0 + fake_rep.loc[:, "fake"] = 1 data = pd.concat([data, fake_rep], sort=False).sort_values(by="stratum_id") @@ -240,7 +236,7 @@ def stochatreat( # lookup treatment name for permutations. This works because we flatten # row-major style, i.e. one row after another. 
data.loc[:, "treat"] = treat_mask[permutations].flatten(order="C") - data = data[~data["fake"]].drop(columns=["fake"]) + data = data[data["fake"] == 0].drop(columns=["fake"]) # re-assign type - as it might have changed with the addition of fake data data[idx_col] = data[idx_col].astype(idx_col_type) diff --git a/tests/test_import.py b/tests/test_import.py new file mode 100644 index 0000000..cb8920a --- /dev/null +++ b/tests/test_import.py @@ -0,0 +1,4 @@ +def test_import(): + from stochatreat import stochatreat + + assert callable(stochatreat) diff --git a/tests/test_io.py b/tests/test_io.py index 176d49d..6b6c84d 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -240,7 +240,9 @@ def treatments_dict_rand_index(): "stratum": [0] * 40 + [1] * 30 + [2] * 30, } ) - data = data.set_index(pd.Index(np.random.choice(300, 100, replace=False))) + data = data.set_index( + pd.Index(np.random.choice(300, 100, replace=False)).astype(np.int64) + ) idx_col = "id" treatments = stochatreat(