Skip to content

Commit

Permalink
Merge pull request #31 from manmartgarc/fix-pandas
Browse files Browse the repository at this point in the history
fix: resolve failing tests with pandas>=2.2
  • Loading branch information
manmartgarc authored Apr 28, 2024
2 parents 77817d5 + 94aeed1 commit c023860
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 49 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
Expand Down
40 changes: 13 additions & 27 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,19 @@ Source = "https://github.com/manmartgarc/stochatreat/"
path = "src/stochatreat/__about__.py"

[tool.hatch.envs.default]
dependencies = [
"pytest",
"coverage[toml]>=6.5",
]
installer = "uv"
dependencies = ["pytest", "pytest-cov", "pytest-xdist", "coverage[toml]>=6.5"]
[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
test-cov = "coverage run -m pytest {args:tests}"
cov-report = [
"- coverage combine",
"coverage report",
]
cov = [
"test-cov",
"cov-report"
]
test = "pytest -n auto {args:tests}"
test-cov = "pytest -n auto --cov {args:tests}"
cov-report = ["- coverage combine", "coverage report"]
cov = ["test-cov", "cov-report"]

[tool.hatch.envs.all]
installer = "uv"

[[tool.hatch.envs.all.matrix]]
python = ["3.8", "3.9", "3.10", "3.11", "3.12"]
python = ["3.9", "3.10", "3.11", "3.12"]

[tool.hatch.envs.types]
dependencies = ["mypy>=1.0.0"]
Expand All @@ -68,32 +63,23 @@ check = "mypy --install-types --non-interactive {args:src/stochatreat tests}"
source_pkgs = ["stochatreat", "tests"]
branch = true
parallel = true
omit = [
"src/stochatreat/__about__.py",
]
omit = ["src/stochatreat/__about__.py"]

[tool.coverage.paths]
stochatreat = ["src/stochatreat", "*/stochatreat/src/stochatreat"]
tests = ["tests", "*/stochatreat/tests"]

[tool.coverage.report]
exclude_lines = [
"no cov",
"if __name__ == .__main__.:",
"if TYPE_CHEKCING"
]
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING"]

[tool.ruff]
line-length = 79

[tool.ruff.lint]
ignore = [
"TRY003"
]
ignore = ["TRY003"]

[tool.mypy]
check_untyped_defs = true
show_error_codes = true
pretty = true
ignore_missing_imports = true

3 changes: 3 additions & 0 deletions src/stochatreat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from stochatreat.stochatreat import stochatreat

__all__ = ["stochatreat"]
36 changes: 16 additions & 20 deletions src/stochatreat/stochatreat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
"""
"""
from __future__ import annotations

from typing import Literal
Expand Down Expand Up @@ -160,9 +157,10 @@ def stochatreat(
data = data.groupby("stratum_id").apply(
lambda x: x.sample(
n=reduced_sizes[x.name], random_state=random_state
)
),
include_groups=False,
)

data["stratum_id"] = data.index.get_level_values(0)
data = data.droplevel(level="stratum_id")

# Treatment assignment proceeds in two stages within each stratum:
Expand All @@ -189,21 +187,20 @@ def stochatreat(

if misfit_strategy == "global":
# separate the global misfits
misfit_data = (
data.groupby("stratum_id")
.apply(
lambda x: x.sample(
n=(x.shape[0] % lcm_prob_denominators),
replace=False,
random_state=random_state,
)
)
.droplevel(level="stratum_id")
misfit_data = data.groupby("stratum_id").apply(
lambda x: x.sample(
n=(x.shape[0] % lcm_prob_denominators),
replace=False,
random_state=random_state,
),
include_groups=False,
)
misfit_data["stratum_id"] = misfit_data.index.get_level_values(0)
misfit_data = misfit_data.droplevel(level="stratum_id")
good_form_data = data.drop(index=misfit_data.index)

# assign the misfits their own stratum and concatenate
misfit_data.loc[:, "stratum_id"] = np.Inf
misfit_data.loc[:, "stratum_id"] = -1
data = pd.concat([good_form_data, misfit_data])

# =========================================================================
Expand All @@ -225,9 +222,8 @@ def stochatreat(
fake_rep = pd.DataFrame(
fake.values.repeat(fake["fake"], axis=0), columns=fake.columns
)

data.loc[:, "fake"] = False
fake_rep.loc[:, "fake"] = True
data.loc[:, "fake"] = 0
fake_rep.loc[:, "fake"] = 1

data = pd.concat([data, fake_rep], sort=False).sort_values(by="stratum_id")

Expand All @@ -240,7 +236,7 @@ def stochatreat(
# lookup treatment name for permutations. This works because we flatten
# row-major style, i.e. one row after another.
data.loc[:, "treat"] = treat_mask[permutations].flatten(order="C")
data = data[~data["fake"]].drop(columns=["fake"])
data = data[data["fake"] == 0].drop(columns=["fake"])

# re-assign type - as it might have changed with the addition of fake data
data[idx_col] = data[idx_col].astype(idx_col_type)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test_import():
    """Smoke test: the package root exposes ``stochatreat`` as a callable."""
    # Import inside the test so a broken package surfaces as a failure of
    # this test rather than as a collection error for the whole module.
    from stochatreat import stochatreat as entry_point

    assert callable(entry_point)
4 changes: 3 additions & 1 deletion tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,9 @@ def treatments_dict_rand_index():
"stratum": [0] * 40 + [1] * 30 + [2] * 30,
}
)
data = data.set_index(pd.Index(np.random.choice(300, 100, replace=False)))
data = data.set_index(
pd.Index(np.random.choice(300, 100, replace=False)).astype(np.int64)
)
idx_col = "id"

treatments = stochatreat(
Expand Down

0 comments on commit c023860

Please sign in to comment.