diff --git a/.bumpversion.cfg b/.bumpversion.cfg deleted file mode 100644 index 9bc2be0..0000000 --- a/.bumpversion.cfg +++ /dev/null @@ -1,8 +0,0 @@ -[bumpversion] -current_version = 0.0.15 -commit = False -tag = True - -[bumpversion:file:pyproject.toml] -search = version = "{current_version}" -replace = version = "{new_version}" diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..2125666 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=auto \ No newline at end of file diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000..47ab2b1 --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,73 @@ + +# Contributing Guide + +- Check the [GitHub Issues](https://github.com/manmartgarc/stochatreat/issues) for open issues that need attention. +- Follow the [How to submit a contribution](https://opensource.guide/how-to-contribute/#how-to-submit-a-contribution) Guide. + +- Make sure unit tests pass. Please read how to run unit tests [below](#tests). + +- If you are fixing a bug: + - If you are resolving an existing issue, reference the issue ID in a commit message `(e.g., fixed #XXXX)`. + - If the issue has not been reported, please add a detailed description of the bug in the Pull Request (PR). + - Please add a regression test case to check the bug is fixed. + +- If you are adding a new feature: + - Please open a suggestion issue first. + - Provide a convincing reason to add this feature and have it greenlighted before working on it. + - Add tests to cover the functionality. + +- Please follow [Style Guide for Python Code](https://www.python.org/dev/peps/pep-0008/). + +## Setting up development environment + +You can install the development environment, i.e. all the dependencies required to run all tests and checks that are run when you submit a PR, by following these steps: + +1. [Install](https://hatch.pypa.io/1.9/install/#installation) `hatch`. +2. 
Clone the repository: + ```bash + git clone https://github.com/manmartgarc/stochatreat + cd stochatreat + ``` +3. Confirm `hatch` picked up the project: + ```bash + hatch status + ``` + +## Tests + +To run tests in the default environment: + +```bash +hatch run default:test +``` + +To run tests in all environments: + +```bash +hatch run all:test +``` + +## Format + +When submitting a PR, the CI will run `hatch fmt --check` to check the linting and formatting of the code. You can run the formatter and linter locally with: + +```bash +hatch fmt +``` + +## Release + +- Run `hatch version` to update the version number file: + + ```bash + hatch version [major|minor|patch] + ``` + +- Commit the changes and push them to your fork. +- Tag the new version: + ```bash + git tag -a v0.0.0 -m "v0.0.0" + git push origin v0.0.0 + ``` +- Submit a PR. +- Once the PR is merged, run `hatch publish` to create a new release in PyPI. diff --git a/.github/workflows/test.yml b/.github/workflows/release.yml similarity index 50% rename from .github/workflows/test.yml rename to .github/workflows/release.yml index baa2d48..3c21b7f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/release.yml @@ -1,7 +1,7 @@ # This workflow will install Python dependencies, run tests and lint with a variety of Python versions # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions -name: test +name: Release CI on: push: @@ -14,28 +14,30 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: pip - - name: Install dependencies + - name: Install hatch run: | - python -m pip install --upgrade pip - pip
install -e .[dev] - - name: Lint with Ruff + python -m pip install hatch + - name: Check lint and format run: | - make lint - - name: Check style with Black - run: | - make style + hatch fmt --check - name: Type-checking with Mypy run: | - make test/mypy - - name: Test with pytest + hatch run types:check + - name: Run tests with pytest run: | - make test/pytest + hatch run +py=${{ matrix.python-version }} all:test-cov + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v4 + with: + fail_ci_if_error: true + files: coverage.json + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index 53d469b..09dca20 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .vscode *.ps1 +coverage.json # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/Makefile b/Makefile deleted file mode 100644 index 6116d6a..0000000 --- a/Makefile +++ /dev/null @@ -1,49 +0,0 @@ -.DEFAULT_GOAL := all - -all: clean lint style test dist - -clean: clean-build clean-pyc clean-test - -clean-build: - rm -fr build/ - rm -fr dist/ - rm -fr .eggs/ - find . -name '*.egg-info' -exec rm -fr {} + - find . -name '*.egg' -exec rm -f {} + - -clean-pyc: - find . -name '*.pyc' -exec rm -f {} + - find . -name '*.pyo' -exec rm -f {} + - find . -name '*~' -exec rm -f {} + - find . 
-name '__pycache__' -exec rm -fr {} + - -clean-test: - rm -fr .tox/ - rm -f .coverage - rm -fr htmlcov/ - rm -fr .pytest_cache - -dist: clean - python -m build - ls -lth dist/ - -lint: lint/ruff - -lint/ruff: - ruff src tests - -release: dist - twine upload dist/* - -style: style/black - -style/black: - black --check src tests - -test: test/mypy test/pytest - -test/pytest: - pytest - -test/mypy: - mypy \ No newline at end of file diff --git a/README.md b/README.md index ca7cb6b..9a5b7c5 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,10 @@ | | | |---|---| -|Build|[![Main Branch Tests](https://github.com/manmartgarc/stochatreat/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/manmartgarc/stochatreat/actions/workflows/test.yml) +|Build|[![Main Branch Tests](https://github.com/manmartgarc/stochatreat/actions/workflows/release.yml/badge.svg?branch=main)](https://github.com/manmartgarc/stochatreat/actions/workflows/release.yml) [![codecov](https://codecov.io/gh/manmartgarc/stochatreat/graph/badge.svg?token=llPoW2rWIN)](https://codecov.io/gh/manmartgarc/stochatreat) |PyPI| [![pypi](https://img.shields.io/pypi/v/stochatreat?logo=pypi)](https://pypi.org/project/stochatreat/) ![pypi-downloads](https://img.shields.io/pypi/dm/stochatreat?logo=pypi) |conda-forge| [![Conda](https://img.shields.io/conda/v/conda-forge/stochatreat?logo=conda-forge)](https://anaconda.org/conda-forge/stochatreat) ![conda-downloads](https://img.shields.io/conda/dn/conda-forge/stochatreat?logo=conda-forge) -|Meta| [![linting - Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) [![code style - Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![types - Mypy](https://img.shields.io/badge/types-Mypy-blue.svg)](https://github.com/python/mypy) [![License -
MIT](https://img.shields.io/badge/license-MIT-9400d3.svg)](https://spdx.org/licenses/) +|Meta| [![Hatch project](https://img.shields.io/badge/%F0%9F%A5%9A-Hatch-4051b5.svg)](https://github.com/pypa/hatch) [![linting - Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) [![types - Mypy](https://img.shields.io/badge/types-Mypy-blue.svg)](https://github.com/python/mypy) [![License - MIT](https://img.shields.io/badge/license-MIT-9400d3.svg)](https://spdx.org/licenses/) --- @@ -119,6 +119,10 @@ nhood dummy 1 35 68 ``` +## Contributing + +If you'd like to contribute to the package, make sure you read the [contributing guide](https://github.com/manmartgarc/stochatreat/blob/main/.github/CONTRIBUTING.md). + ## References - `stochatreat` is totally inspired by [Alvaro Carril's](https://acarril.github.io/) fantastic STATA package: [`randtreat`](https://acarril.github.io/posts/randtreat), which was published in [The Stata Journal](https://www.stata-journal.com/article.html?article=st0490). 
diff --git a/pyproject.toml b/pyproject.toml index 1163828..5b5824f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,69 +1,77 @@ [build-system] -requires = ["setuptools"] -build-backend = "setuptools.build_meta" +requires = ["hatchling"] +build-backend = "hatchling.build" [project] name = "stochatreat" -version = "0.0.15" +dynamic = ["version"] description = 'Stratified random assignment using pandas' readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" keywords = [ "randomization", "block randomization", "stratified randomization", "stratified random assignment", - "strata" -] -authors = [ - { name = "Manuel Martinez", email = "manmartgarc@gmail.com" }, + "strata", ] +authors = [{ name = "Manuel Martinez", email = "manmartgarc@gmail.com" }] classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Operating System :: OS Independent", "License :: OSI Approved :: MIT License", ] -dependencies = ["pandas"] - -[project.optional-dependencies] -dev = [ - "black", - "build", - "bump2version", - "mypy", - "ruff", - "pytest", - "pytest-cov" -] +dependencies = ["pandas>=2.2"] [project.urls] Documentation = "https://github.com/manmartgarc/stochatreat/blob/main/README.md" Issues = "https://github.com/manmartgarc/stochatreat/issues" Source = "https://github.com/manmartgarc/stochatreat/" -[tools.setuptools.packages.find] -where = ["src"] -namespaces = false +[tool.hatch.version] +path = "src/stochatreat/__about__.py" + +[tool.hatch.envs.default] +installer = "uv" +dependencies = ["pytest", "pytest-cov", "pytest-xdist", "coverage[toml]>=6.5"] +[tool.hatch.envs.default.scripts] +test-cov = "pytest 
--cov=src/stochatreat tests/ --cov-report json -n auto" +test = "test-cov --no-cov" + +[tool.hatch.envs.all] +installer = "uv" + +[[tool.hatch.envs.all.matrix]] +python = ["3.9", "3.10", "3.11", "3.12"] -[tool.black] +[tool.hatch.envs.types] +dependencies = ["mypy>=1.0.0"] +[tool.hatch.envs.types.scripts] +check = "mypy --install-types --non-interactive src/stochatreat tests/" + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING", + "__version__ = ", +] + +[tool.ruff] line-length = 79 +[tool.ruff.lint] +ignore = ["TRY003"] + [tool.mypy] -python_version = 3.8 +check_untyped_defs = true +show_error_codes = true +pretty = true ignore_missing_imports = true -packages = ["src"] - -[tool.pytest.ini_options] -addopts = [ - "--cov=stochatreat", - "--cov-branch", - "--cov-report=term-missing", - "--durations=5" -] \ No newline at end of file diff --git a/src/stochatreat/__about__.py b/src/stochatreat/__about__.py index 311f216..39e62e7 100644 --- a/src/stochatreat/__about__.py +++ b/src/stochatreat/__about__.py @@ -1 +1,2 @@ -__version__ = "0.0.14" +# pragma: no cover +__version__ = "0.0.19" diff --git a/src/stochatreat/__init__.py b/src/stochatreat/__init__.py index d230bef..fa5d5bc 100644 --- a/src/stochatreat/__init__.py +++ b/src/stochatreat/__init__.py @@ -1,13 +1,3 @@ -# -*- coding: utf-8 -*- -""" -Created on Wed Jul 10 12:16:55 2019 +from stochatreat.stochatreat import stochatreat -=============================================================================== -@author: Manuel Martinez -@project: stochatreat -=============================================================================== -""" -from stochatreat.stochatreat import stochatreat # noqa: F401 - -__version__ = "0.0.14" -__author__ = "Manuel Martinez" +__all__ = ["stochatreat"] diff --git a/src/stochatreat/py.typed b/src/stochatreat/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/stochatreat/stochatreat.py
b/src/stochatreat/stochatreat.py index 986bede..a913a5a 100644 --- a/src/stochatreat/stochatreat.py +++ b/src/stochatreat/stochatreat.py @@ -1,32 +1,34 @@ -# -*- coding: utf-8 -*- -""" -Created on Thursday, 8th November 2018 2:34:47 pm -=============================================================================== -@filename: stochatreat.py -@author: Manuel Martinez (manmartgarc@gmail.com) -@project: stochatreat -@purpose: Define a function that assign treatments over an arbitrary - number of strata. -=============================================================================== +"""Stratified random assignment of treatments to units. + +This module provides a function to assign treatments to units in a +stratified manner. The function is designed to work with pandas +dataframes and is able to handle multiple strata. There are also different +strategies to deal with misfits (units that are left over after the +stratified assignment procedure). """ -from typing import List, Optional -from fractions import Fraction + +from __future__ import annotations + +import math +from typing import Literal import numpy as np import pandas as pd from stochatreat.utils import get_lcm_prob_denominators +MIN_ROW_N = 2 + def stochatreat( data: pd.DataFrame, - stratum_cols: List[str], + stratum_cols: list[str], treats: int, - probs: Optional[List[float]] = None, + probs: list[float] | None = None, random_state: int = 42, - idx_col: Optional[str] = None, - size: Optional[int] = None, - misfit_strategy: str = "stratum", + idx_col: str | None = None, + size: int | None = None, + misfit_strategy: Literal["stratum", "global"] = "stratum", ) -> pd.DataFrame: """ Takes a dataframe and an arbitrary number of treatments over an @@ -72,7 +74,7 @@ def stochatreat( treats=2, # including control idx_col='myid', # unique id column random_state=42) # seed for rng - >>> data = data.merge(treats, how='left', on='myid') + >>> data = data.merge(treats, how="left", on="myid") Multiple strata: >>> 
treats = stochatreat(data=data, @@ -81,10 +83,9 @@ def stochatreat( probs=[1/3, 2/3], idx_col='myid', random_state=42) - >>> data = data.merge(treats, how='left', on='myid') + >>> data = data.merge(treats, how="left", on="myid") """ - # pylint: disable=invalid-name - R = np.random.RandomState(random_state) + rand = np.random.RandomState(random_state) # ========================================================================= # do checks @@ -99,46 +100,51 @@ def stochatreat( probs_np = np.array([frac] * len(treatment_ids)) elif probs is not None: probs_np = np.array(probs) - probs_sum = float(np.array([Fraction(f).limit_denominator() for f in probs]).sum()) - if probs_sum != 1: - raise ValueError("The probabilities must add up to 1") - - assertmsg = "length of treatments and probs must be the same" - assert len(treatment_ids) == len(probs_np), assertmsg + if not math.isclose(probs_np.sum(), 1, rel_tol=1e-9): + error_msg = "The probabilities must add up to 1" + raise ValueError(error_msg) + if len(probs_np) != len(treatment_ids): + error_msg = ( + "The number of probabilities must match the number of " + "treatments" + ) + raise ValueError(error_msg) # check if dataframe is empty if data.empty: - raise ValueError("Make sure that your dataframe is not empty.") + error_msg = "Make sure that your dataframe is not empty." + raise ValueError(error_msg) # check length of data - if len(data) < 2: - raise ValueError("Make sure your data has enough observations.") + if len(data) < MIN_ROW_N: + error_msg = "Your dataframe at least needs to have 2 rows." + raise ValueError(error_msg) # if idx_col parameter was not defined. if idx_col is None: data = data.rename_axis("index", axis="index").reset_index() idx_col = "index" elif not isinstance(idx_col, str): - raise TypeError("idx_col has to be a string.") + error_msg = "idx_col has to be a string." 
+ raise TypeError(error_msg) # retrieve type to check and re-assign in the end idx_col_type = data[idx_col].dtype # check for unique identifiers if data[idx_col].duplicated(keep=False).sum() > 0: - raise ValueError("Values in idx_col are not unique.") + error_msg = "The values in idx_col are not unique." + raise ValueError(error_msg) # if size is larger than sample universe if size is not None and size > len(data): - raise ValueError("Size argument is larger than the sample universe.") + error_msg = "Size argument is larger than the sample universe." + raise ValueError(error_msg) # deal with multiple strata if isinstance(stratum_cols, str): stratum_cols = [stratum_cols] - if misfit_strategy not in ("stratum", "global"): - raise ValueError("the strategy must be one of 'stratum' or 'global'") - # sort data - useful to preserve correspondence between `idx_col` and # assignments data = data.sort_values(by=idx_col) @@ -147,7 +153,7 @@ def stochatreat( data["stratum_id"] = data.groupby(stratum_cols).ngroup() # keep only ids and concatenated strata - data = data[[idx_col] + ["stratum_id"]].copy() + data = data[[idx_col, "stratum_id"]].copy() # apply weights to each stratum if sampling is wanted if size is not None: @@ -161,13 +167,12 @@ def stochatreat( data = data.groupby("stratum_id").apply( lambda x: x.sample( n=reduced_sizes[x.name], random_state=random_state - ) + ), + include_groups=False, ) - + data["stratum_id"] = data.index.get_level_values(0) data = data.droplevel(level="stratum_id") - assert sum(reduced_sizes) == len(data) - # Treatment assignment proceeds in two stages within each stratum: # 1. In as far as units can be neatly divided in the proportions given by # prob they are so divided. 
@@ -192,21 +197,20 @@ def stochatreat( if misfit_strategy == "global": # separate the global misfits - misfit_data = ( - data.groupby("stratum_id") - .apply( - lambda x: x.sample( - n=(x.shape[0] % lcm_prob_denominators), - replace=False, - random_state=random_state, - ) - ) - .droplevel(level="stratum_id") + misfit_data = data.groupby("stratum_id").apply( + lambda x: x.sample( + n=(x.shape[0] % lcm_prob_denominators), + replace=False, + random_state=random_state, + ), + include_groups=False, ) + misfit_data["stratum_id"] = misfit_data.index.get_level_values(0) + misfit_data = misfit_data.droplevel(level="stratum_id") good_form_data = data.drop(index=misfit_data.index) # assign the misfits their own stratum and concatenate - misfit_data.loc[:, "stratum_id"] = np.Inf + misfit_data.loc[:, "stratum_id"] = -1 data = pd.concat([good_form_data, misfit_data]) # ========================================================================= @@ -228,28 +232,25 @@ def stochatreat( fake_rep = pd.DataFrame( fake.values.repeat(fake["fake"], axis=0), columns=fake.columns ) - - data.loc[:, "fake"] = False - fake_rep.loc[:, "fake"] = True + data.loc[:, "fake"] = 0 + fake_rep.loc[:, "fake"] = 1 data = pd.concat([data, fake_rep], sort=False).sort_values(by="stratum_id") # generate random permutations without loop by generating large number of # random values and sorting row (meaning one permutation) wise permutations = np.argsort( - R.rand(len(data) // lcm_prob_denominators, lcm_prob_denominators), + rand.rand(len(data) // lcm_prob_denominators, lcm_prob_denominators), axis=1, ) # lookup treatment name for permutations. This works because we flatten # row-major style, i.e. one row after another. 
data.loc[:, "treat"] = treat_mask[permutations].flatten(order="C") - data = data[~data["fake"]].drop(columns=["fake"]) + data = data[data["fake"] == 0].drop(columns=["fake"]) # re-assign type - as it might have changed with the addition of fake data data[idx_col] = data[idx_col].astype(idx_col_type) data["treat"] = data["treat"].astype(np.int64) - assert data["treat"].isnull().sum() == 0 - return data diff --git a/src/stochatreat/utils.py b/src/stochatreat/utils.py index b1ce817..09ff168 100644 --- a/src/stochatreat/utils.py +++ b/src/stochatreat/utils.py @@ -1,28 +1,13 @@ -import sys +from collections.abc import Iterable from fractions import Fraction -from typing import Iterable +from math import lcm -if sum(sys.version_info[:2]) < 12: - from functools import reduce - from math import gcd # type: ignore - def lcm(*args): - """ - Helper function to compute the Lowest Common Multiple of a list of - integers - """ - return reduce(lambda a, b: a * b // gcd(a, b), args) - -else: - from math import lcm # type: ignore - - -def get_lcm_prob_denominators(probs: Iterable[float]): +def get_lcm_prob_denominators(probs: Iterable[float]) -> int: """ Helper function to compute the LCM of the denominators of the probabilities """ - prob_denominators = [ + prob_denominators = ( Fraction(prob).limit_denominator().denominator for prob in probs - ] - lcm_prob_denominators = lcm(*prob_denominators) - return lcm_prob_denominators + ) + return lcm(*prob_denominators) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_stochatreat_assignment.py b/tests/test_assignment.py similarity index 94% rename from tests/test_stochatreat_assignment.py rename to tests/test_assignment.py index 80b4ecf..f8f2749 100644 --- a/tests/test_stochatreat_assignment.py +++ b/tests/test_assignment.py @@ -1,6 +1,7 @@ import numpy as np import pandas as pd import pytest + from stochatreat.stochatreat import stochatreat from stochatreat.utils 
import get_lcm_prob_denominators @@ -11,18 +12,16 @@ @pytest.fixture(params=[10_000, 100_000]) def df(request): - N = request.param - df = pd.DataFrame( + n = request.param + return pd.DataFrame( data={ - "id": np.arange(N), - "dummy": [1] * N, - "stratum1": np.random.randint(1, 100, size=N), - "stratum2": np.random.randint(0, 2, size=N), + "id": np.arange(n), + "dummy": [1] * n, + "stratum1": np.random.randint(1, 100, size=n), + "stratum2": np.random.randint(0, 2, size=n), } ) - return df - # a set of treatment assignment probabilities to throw at many tests standard_probs = [ @@ -31,6 +30,7 @@ def df(request): [0.5, 0.5], [2 / 3, 1 / 3], [0.9, 0.1], + [1 / 2, 1 / 3, 1 / 6], ] # a set of stratum column combinations from the above df fixture to throw at @@ -46,19 +46,17 @@ def df(request): # no misfits @pytest.fixture def df_no_misfits(): - N = 1_000 + n = 1_000 stratum_size = 10 - df = pd.DataFrame( + return pd.DataFrame( data={ - "id": np.arange(N), + "id": np.arange(n), "stratum": np.repeat( - np.arange(N / stratum_size), repeats=stratum_size + np.arange(n / stratum_size), repeats=stratum_size ), } ) - return df - probs_no_misfits = [ [0.1, 0.9], @@ -147,11 +145,11 @@ def test_stochatreat_only_misfits(probs): of units is sufficiently large -- relies on the Law of Large Numbers, not deterministic """ - N = 10_000 + n = 10_000 df = pd.DataFrame( data={ - "id": np.arange(N), - "stratum": np.arange(N), + "id": np.arange(n), + "stratum": np.arange(n), } ) treats = stochatreat( @@ -190,12 +188,10 @@ def get_within_strata_counts(treats): .reset_index() ) - counts = pd.merge( + return pd.merge( treatment_counts, stratum_counts, on="stratum_id", how="left" ) - return counts - def compute_count_diff(treats, probs): """ @@ -324,9 +320,10 @@ def test_stochatreat_stratum_ids(df, misfit_strategy, stratum_cols): if misfit_strategy == "global": # depending on whether there are misfits - assert (n_unique_stratum_ids == n_unique_strata) or ( - n_unique_stratum_ids - 1 == 
n_unique_strata - ) + assert n_unique_strata in { + n_unique_stratum_ids, + n_unique_stratum_ids - 1, + } else: assert n_unique_stratum_ids == n_unique_strata diff --git a/tests/test_import.py b/tests/test_import.py new file mode 100644 index 0000000..cb8920a --- /dev/null +++ b/tests/test_import.py @@ -0,0 +1,4 @@ +def test_import(): + from stochatreat import stochatreat + + assert callable(stochatreat) diff --git a/tests/test_stochatreat_io.py b/tests/test_io.py similarity index 89% rename from tests/test_stochatreat_io.py rename to tests/test_io.py index 494e30b..6b6c84d 100644 --- a/tests/test_stochatreat_io.py +++ b/tests/test_io.py @@ -1,6 +1,7 @@ import numpy as np import pandas as pd import pytest + from stochatreat.stochatreat import stochatreat @@ -9,7 +10,7 @@ def correct_params(): """ A set of valid parameters that can be passed to stochatreat() """ - params = { + return { "probs": [0.1, 0.9], "treat": 2, "data": pd.DataFrame( @@ -17,7 +18,6 @@ def correct_params(): ), "idx_col": "id", } - return params def test_input_invalid_probs(correct_params): @@ -25,7 +25,7 @@ def test_input_invalid_probs(correct_params): Tests that the function rejects probabilities that don't add up to one """ probs_not_sum_to_one = [0.1, 0.2] - with pytest.raises(Exception): + with pytest.raises(ValueError, match="The probabilities must add up to 1"): stochatreat( data=correct_params["data"], stratum_cols=["stratum"], @@ -41,7 +41,12 @@ def test_input_more_treats_than_probs(correct_params): different sizes """ treat_too_large = 3 - with pytest.raises(Exception): + with pytest.raises( + ValueError, + match=( + "The number of probabilities must match the number of treatments" + ), + ): stochatreat( data=correct_params["data"], stratum_cols=["stratum"], @@ -56,10 +61,12 @@ def test_input_empty_data(correct_params): Tests that the function raises an error when an empty dataframe is passed """ empty_data = pd.DataFrame() - with pytest.raises(ValueError): + with pytest.raises( + 
ValueError, match="Make sure that your dataframe is not empty." + ): stochatreat( data=empty_data, - stratum_cols="stratum", + stratum_cols=["stratum"], treats=correct_params["treat"], idx_col=correct_params["idx_col"], probs=correct_params["probs"], @@ -72,12 +79,12 @@ def test_input_idx_col_str(correct_params): string or None """ idx_col_not_str = 0 - with pytest.raises(TypeError): + with pytest.raises(TypeError, match="idx_col has to be a string."): stochatreat( data=correct_params["data"], stratum_cols=["stratum"], treats=correct_params["treat"], - idx_col=idx_col_not_str, + idx_col=idx_col_not_str, # type: ignore probs=correct_params["probs"], ) @@ -87,7 +94,9 @@ def test_input_invalid_size(correct_params): Tests that the function rejects a sampling size larger than the data count """ size_bigger_than_sampling_universe_size = 101 - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="Size argument is larger than the sample universe." + ): stochatreat( data=correct_params["data"], stratum_cols=["stratum"], @@ -106,7 +115,9 @@ def test_input_idx_col_unique(correct_params): data_with_idx_col_with_duplicates = pd.DataFrame( data={"id": 1, "stratum": np.arange(100)} ) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="The values in idx_col are not unique." 
+ ): stochatreat( data=data_with_idx_col_with_duplicates, stratum_cols=["stratum"], @@ -116,23 +127,6 @@ def test_input_idx_col_unique(correct_params): ) -def test_input_invalid_strategy(correct_params): - """ - Tests that the function raises an error if an invalid strategy string is - passed - """ - unknown_strategy = "unknown" - with pytest.raises(ValueError): - stochatreat( - data=correct_params["data"], - stratum_cols=["stratum"], - treats=correct_params["treat"], - idx_col=correct_params["idx_col"], - probs=correct_params["probs"], - misfit_strategy=unknown_strategy, - ) - - @pytest.fixture def treatments_dict(): """fixture of stochatreat() output to test output format""" @@ -152,15 +146,13 @@ def treatments_dict(): random_state=42, ) - treatments_dict = { + return { "data": data, "idx_col": idx_col, "size": size, "treatments": treatments, } - return treatments_dict - def test_output_type(treatments_dict): """ @@ -248,7 +240,9 @@ def treatments_dict_rand_index(): "stratum": [0] * 40 + [1] * 30 + [2] * 30, } ) - data = data.set_index(pd.Index(np.random.choice(300, 100, replace=False))) + data = data.set_index( + pd.Index(np.random.choice(300, 100, replace=False)).astype(np.int64) + ) idx_col = "id" treatments = stochatreat( @@ -259,7 +253,7 @@ def treatments_dict_rand_index(): random_state=42, ) - treatments_dict = { + return { "data": data, "stratum_cols": ["stratum"], "idx_col": idx_col, @@ -267,8 +261,6 @@ def treatments_dict_rand_index(): "n_treatments": treats, } - return treatments_dict - standard_probs = [ [0.1, 0.9], @@ -317,7 +309,7 @@ def test_output_index_and_idx_col_correspondence( treatments = stochatreat( data=data_with_rand_index, - stratum_cols="stratum", + stratum_cols=["stratum"], probs=probs, treats=2, idx_col=idx_col,