Skip to content

Commit

Permalink
Use function-scoped new_dims to handle type hint varying throughout f…
Browse files Browse the repository at this point in the history
…unction

We don't want to allow the user to pass `dims=[None, None]` to our function, but the current behaviour
sets `dims=[None] * N` at the end of `determine_coords`.

To handle this, I created a `new_dims` variable with a wider type annotation, which matches
the type of the `dims` value returned by `determine_coords`.

Then I did the same within `Data` to support the new return type hint.
  • Loading branch information
thomasaarholt committed Oct 4, 2024
1 parent 840ffea commit 306bb30
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size:
def determine_coords(
model,
value: pd.DataFrame | pd.Series | xr.DataArray,
dims: Sequence[str | None] | None = None,
dims: Sequence[str] | None = None,
coords: dict[str, Sequence | np.ndarray] | None = None,
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None]]:
) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]:
"""Determines coordinate values from data or the model (via ``dims``)."""
if coords is None:
coords = {}
Expand Down Expand Up @@ -265,9 +265,10 @@ def determine_coords(

if dims is None:
# TODO: Also determine dim names from the index
dims = [None] * np.ndim(value)

return coords, dims
new_dims: Sequence[str] | Sequence[None] = [None] * np.ndim(value)
else:
new_dims = dims
return coords, new_dims


def ConstantData(
Expand Down Expand Up @@ -363,7 +364,7 @@ def Data(
The name for this variable.
value : array_like or pandas.Series, pandas.Dataframe
A value to associate with this variable.
dims : str or tuple of str, optional
dims : str, tuple of str or tuple of None, optional
Dimension names of the random variables (as opposed to the shapes of these
random variables). Use this when ``value`` is a pandas Series or DataFrame. The
``dims`` will then be the name of the Series / DataFrame's columns. See ArviZ
Expand Down Expand Up @@ -448,13 +449,16 @@ def Data(
expected=x.ndim,
)

new_dims: Sequence[str] | Sequence[None] | None
if infer_dims_and_coords:
coords, dims = determine_coords(model, value, dims)
coords, new_dims = determine_coords(model, value, dims)
else:
new_dims = dims

if dims:
if new_dims:
xshape = x.shape
# Register new dimension lengths
for d, dname in enumerate(dims):
for d, dname in enumerate(new_dims):
if dname not in model.dim_lengths:
model.add_coord(
name=dname,
Expand All @@ -464,6 +468,6 @@ def Data(
length=xshape[d],
)

model.register_data_var(x, dims=dims)
model.register_data_var(x, dims=new_dims)

return x

0 comments on commit 306bb30

Please sign in to comment.