Skip to content

Commit

Permalink
test dataset properties
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Apr 16, 2024
1 parent d4180d3 commit 1d5a08b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 10 deletions.
20 changes: 11 additions & 9 deletions latent_calendar/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,8 @@ def _create_local_file_name(name: str) -> Path:
def _load_data(
name: str, read_kwargs=None, save_kwargs=None, local_save: bool = False
) -> pd.DataFrame:
if read_kwargs is None:
read_kwargs = {}

if save_kwargs is None:
save_kwargs = {}
read_kwargs = read_kwargs or {}
save_kwargs = save_kwargs or {}

file = _create_local_file_name(name)

Expand All @@ -67,11 +64,12 @@ def _load_data(
return df


def load_online_transactions(local_save: bool = False) -> pd.DataFrame:
def load_online_transactions(local_save: bool = False, **read_kwargs) -> pd.DataFrame:
"""Kaggle Data for an non-store online retailer in UK. More information [here](https://www.kaggle.com/datasets/mashlyn/online-retail-ii-uci).
Args:
local_save: Whether to save the data locally if it doesn't exists.
read_kwargs: kwargs to pass to pd.read_csv
Returns:
Online transactions data from a non-store online retailer in UK.
Expand All @@ -80,12 +78,13 @@ def load_online_transactions(local_save: bool = False) -> pd.DataFrame:
name = "online_retail_II"
read_kwargs = {
"parse_dates": ["InvoiceDate"],
**read_kwargs,
}

return _load_data(name, read_kwargs=read_kwargs, local_save=local_save)


def load_chicago_bikes(local_save: bool = False) -> pd.DataFrame:
def load_chicago_bikes(local_save: bool = False, **read_kwargs) -> pd.DataFrame:
"""Bikesharing trip level data from Chicago's Divvy system.
Read more about the data source [here](https://data.cityofchicago.org/Transportation/Divvy-Trips/fg6s-gzvg).
Expand All @@ -94,6 +93,7 @@ def load_chicago_bikes(local_save: bool = False) -> pd.DataFrame:
Args:
local_save: Whether to save the data locally if it doesn't exists.
read_kwargs: kwargs to pass to pd.read_csv
Returns:
Trips data from Chicago's Divvy system.
Expand All @@ -103,23 +103,25 @@ def load_chicago_bikes(local_save: bool = False) -> pd.DataFrame:
read_kwargs = {
"parse_dates": ["started_at", "ended_at"],
"index_col": ["ride_id"],
**read_kwargs,
}

return _load_data(name, read_kwargs=read_kwargs, local_save=local_save)


def load_ufo_sightings(local_save: bool = False) -> pd.DataFrame:
def load_ufo_sightings(local_save: bool = False, **read_kwargs) -> pd.DataFrame:
"""UFO sightings over time around the world. More info [here](https://www.kaggle.com/datasets/camnugent/ufo-sightings-around-the-world).
Args:
local_save: Whether to save the data locally if it doesn't exists.
read_kwargs: kwargs to pass to pd.read_csv
Returns:
Sighting level data for UFOs.
"""
name = "ufo_sighting_data"
read_kwargs = {"low_memory": False}
read_kwargs = {"low_memory": False, **read_kwargs}
save_kwargs = {"index": False}

df = _load_data(
Expand Down
31 changes: 30 additions & 1 deletion tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
HERE,
)

DATASETS_DIR = Path(__file__).parents[1] / "datasets"


@pytest.mark.skipif(
os.environ.get("CI") != "INTEGRATION", reason="CI is not set to INTEGRATION"
Expand All @@ -24,11 +26,38 @@
(load_ufo_sightings, "ufo_sighting_data.csv"),
],
)
def test_load_func(load_func, local_file):
def test_load_func_local_save(load_func, local_file):
file: Path = HERE / local_file
file.unlink(missing_ok=True)

df = load_func(local_save=True)
df_second_time = load_func()

pd.testing.assert_frame_equal(df, df_second_time)


@pytest.mark.parametrize(
"load_func, local_file",
[
(load_chicago_bikes, "chicago-bikes.csv"),
(load_online_transactions, "online_retail_II.csv"),
(load_ufo_sightings, "ufo_sighting_data.csv"),
],
)
def test_load_func_subset(load_func, local_file: str) -> None:
actual_file: Path = DATASETS_DIR / local_file
file: Path = HERE / local_file

if file.exists():
file.unlink()

file.symlink_to(actual_file)

read_kwargs = {"nrows": 5}
df = load_func(**read_kwargs)

assert isinstance(df, pd.DataFrame)
assert len(df) == 5
assert df.dtypes.eq("datetime64[ns]").sum() >= 1

file.unlink()

0 comments on commit 1d5a08b

Please sign in to comment.