Skip to content

Commit

Permalink
move hdi function below general function
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Sep 11, 2024
1 parent e9c69c0 commit f710f25
Showing 1 changed file with 63 additions and 63 deletions.
126 changes: 63 additions & 63 deletions pymc_marketing/mmm/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,69 +243,6 @@ def _plot_hdi_selection(
SelToString = Callable[[Selection], str]


def plot_hdi(
curve: xr.DataArray,
non_grid_names: set[str],
hdi_kwargs: dict | None = None,
subplot_kwargs: dict[str, Any] | None = None,
plot_kwargs: dict[str, Any] | None = None,
axes: npt.NDArray[Axes] | None = None,
same_axes: bool = False,
colors: Iterable[str] | None = None,
legend: bool = False,
sel_to_string: SelToString | None = None,
) -> tuple[plt.Figure, npt.NDArray[Axes]]:
"""Plot hdi of the curve across coords.
Parameters
----------
curve : xr.DataArray
Curve to plot
non_grid_names : set[str]
The names to exclude from the grid. chain and draw are
excluded automatically
n : int, optional
Number of samples to plot
rng : np.random.Generator, optional
Random number generator
axes : npt.NDArray[plt.Axes], optional
Axes to plot on
subplot_kwargs : dict, optional
Additional kwargs to while creating the fig and axes
plot_kwargs : dict, optional
Kwargs for the plot function
Returns
-------
tuple[plt.Figure, npt.NDArray[plt.Axes]]
Figure and the axes
"""
get_plot_data = _create_get_hdi_plot_data(hdi_kwargs or {})
make_selection = _make_hdi_selection
plot_selection = _plot_hdi_selection

plot_kwargs = plot_kwargs or {}
plot_kwargs = {**{"alpha": 0.25}, **plot_kwargs}

return _plot_across_coord(
curve=curve,
non_grid_names=non_grid_names,
get_plot_data=get_plot_data,
make_selection=make_selection,
plot_selection=plot_selection,
subplot_kwargs=subplot_kwargs,
same_axes=same_axes,
axes=axes,
colors=colors,
legend=legend,
plot_kwargs=plot_kwargs,
patch=True,
line=False,
sel_to_string=sel_to_string,
)


def random_samples(
rng: np.random.Generator,
n: int,
Expand Down Expand Up @@ -467,6 +404,69 @@ def create_title(sel):
return fig, return_axes


def plot_hdi(
curve: xr.DataArray,
non_grid_names: set[str],
hdi_kwargs: dict | None = None,
subplot_kwargs: dict[str, Any] | None = None,
plot_kwargs: dict[str, Any] | None = None,
axes: npt.NDArray[Axes] | None = None,
same_axes: bool = False,
colors: Iterable[str] | None = None,
legend: bool = False,
sel_to_string: SelToString | None = None,
) -> tuple[plt.Figure, npt.NDArray[Axes]]:
"""Plot hdi of the curve across coords.
Parameters
----------
curve : xr.DataArray
Curve to plot
non_grid_names : set[str]
The names to exclude from the grid. chain and draw are
excluded automatically
n : int, optional
Number of samples to plot
rng : np.random.Generator, optional
Random number generator
axes : npt.NDArray[plt.Axes], optional
Axes to plot on
subplot_kwargs : dict, optional
Additional kwargs to while creating the fig and axes
plot_kwargs : dict, optional
Kwargs for the plot function
Returns
-------
tuple[plt.Figure, npt.NDArray[plt.Axes]]
Figure and the axes
"""
get_plot_data = _create_get_hdi_plot_data(hdi_kwargs or {})
make_selection = _make_hdi_selection
plot_selection = _plot_hdi_selection

plot_kwargs = plot_kwargs or {}
plot_kwargs = {**{"alpha": 0.25}, **plot_kwargs}

return _plot_across_coord(
curve=curve,
non_grid_names=non_grid_names,
get_plot_data=get_plot_data,
make_selection=make_selection,
plot_selection=plot_selection,
subplot_kwargs=subplot_kwargs,
same_axes=same_axes,
axes=axes,
colors=colors,
legend=legend,
plot_kwargs=plot_kwargs,
patch=True,
line=False,
sel_to_string=sel_to_string,
)


def plot_samples(
curve: xr.DataArray,
non_grid_names: set[str],
Expand Down

0 comments on commit f710f25

Please sign in to comment.