Skip to content

Commit

Permalink
implement conditional probs (#39)
Browse files Browse the repository at this point in the history
* implement conditional probs

* remove accidental

* test other index names

* remove the floating url
  • Loading branch information
wd60622 authored Jan 14, 2024
1 parent 4bdc926 commit 8dbe4a6
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 3 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ Or install directly from GitHub for the latest functionality.

## Features

https://wd60622.github.io/latent-calendar
- Integrated automatically into `pandas` with [`cal` attribute on DataFrames and Series](https://wd60622.github.io/latent-calendar/modules/extensions)
- Compatible with [`scikit-learn` pipelines and transformers](https://wd60622.github.io/latent-calendar/examples/model/sklearn-compat)
- [Transform and visualize data on a weekly calendar](https://wd60622.github.io/latent-calendar/examples/cal-attribute)
Expand Down
14 changes: 14 additions & 0 deletions latent_calendar/const.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Constants used to create the full vocabulary of the dataset."""
import calendar
from itertools import product
from typing import Dict, List, Union
Expand Down Expand Up @@ -41,6 +42,19 @@ def dicretized_hours(minutes: int) -> List[float]:
def create_full_vocab(
days_in_week: int, minutes: int, as_multiindex: bool = True
) -> Union[pd.MultiIndex, List[str]]:
"""Create the full vocabulary of the dataset.
Args:
days_in_week: Number of days in the week.
minutes: Number of minutes to discretize the hours by.
as_multiindex: Whether to return a multiindex or a list of strings.
Returns:
The full vocabulary of the dataset.
Either a MultiIndex or a list of strings.
"""

if not as_multiindex:
return [
format_dow_hour(day_of_week, hour)
Expand Down
47 changes: 47 additions & 0 deletions latent_calendar/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,29 @@ def timestamp_features(

return transformer.fit_transform(self._obj.rename(name).to_frame())

def conditional_probabilities(
self,
*,
level: Union[int, str] = 0,
) -> pd.Series:
"""Calculate conditional probabilities for each the row over the level.
Args:
level: level of the column MultiIndex.
Default 0 or day_of_week
Returns:
Series with conditional probabilities
"""

if not isinstance(self._obj.index, pd.MultiIndex):
raise ValueError(
"Series is expected to have a MultiIndex with the last column as the vocab."
)

return self._obj.div(self._obj.groupby(level=level).sum(), level=level)

def plot(
self,
*,
Expand Down Expand Up @@ -270,6 +293,30 @@ def normalize(self, kind: str) -> pd.DataFrame:

raise ValueError(f"kind must be one of ['max', 'probs'], got {kind}")

def conditional_probabilities(
self,
*,
level: Union[int, str] = 0,
) -> pd.DataFrame:
"""Calculate conditional probabilities for each row over the level.
Args:
level: level of the columns MultiIndex.
Default 0 or day_of_week
Returns:
DataFrame with conditional probabilities
"""
if not isinstance(self._obj.columns, pd.MultiIndex):
raise ValueError(
"DataFrame is expected to have a MultiIndex with the last column as the vocab."
)

return self._obj.div(
self._obj.groupby(level=level, axis=1).sum(), level=level, axis=1
)

def timestamp_features(
self,
column: str,
Expand Down
4 changes: 2 additions & 2 deletions latent_calendar/plot/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def plot_model_predictions(
"""Plot the model predictions compared to the test data.
Args:
X_to_predict: Training data for the model
X_test: Testing data for the model
X_to_predict: Data for the model
X_holdout: Holdout data for the model
model: LatentCalendar model instance
divergent: Option to change the data displayed
axes: list of 3 axes to plot this data
Expand Down
58 changes: 58 additions & 0 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,27 @@ def test_series_extensions(ser) -> None:
assert isinstance(ax, plt.Axes)


@pytest.fixture
def ser_row(ser) -> pd.Series:
return pd.Series(1, index=FULL_VOCAB)


@pytest.mark.parametrize(
"level, axis",
[
("day_of_week", 1),
(0, 1),
("hour", 0),
(1, 0),
],
)
def test_series_conditional_probabilities(ser_row, level, axis) -> None:
result = ser_row.cal.conditional_probabilities(level=level).unstack().sum(axis=axis)
# All the probabilities should sum to 1

assert (result.round() == 1).all()


@pytest.fixture
def df() -> pd.DataFrame:
"""Generate some fake data."""
Expand Down Expand Up @@ -212,3 +233,40 @@ def test_wide_dataframe_extensions(df_wide: pd.DataFrame) -> None:
pd.testing.assert_frame_equal(
df_wide.cal.sum_next_hours(hours=next_hours), df_answer
)


@pytest.fixture
def df_wide_subset() -> pd.DataFrame:
columns = pd.MultiIndex.from_tuples(
[
(0, 0),
(0, 1),
(0, 2),
(1, 0),
(1, 1),
(1, 2),
],
names=["day_of_week", "hour"],
)

data = np.ones((3, 6))
return pd.DataFrame(data, columns=columns).sort_index(axis=1)


@pytest.mark.parametrize(
"level, answer",
[
("day_of_week", 1 / 3),
(0, 1 / 3),
("hour", 1 / 2),
(1, 1 / 2),
],
)
def test_dataframe_conditional_probabilities(
df_wide_subset: pd.DataFrame, level, answer
) -> None:
result = df_wide_subset.cal.conditional_probabilities(level=level)
expected = pd.DataFrame(
answer, index=df_wide_subset.index, columns=df_wide_subset.columns
)
pd.testing.assert_frame_equal(result, expected)

0 comments on commit 8dbe4a6

Please sign in to comment.