From 306bb30bdd2b371792a521efaa7fde124e363a6f Mon Sep 17 00:00:00 2001 From: Thomas Aarholt Date: Fri, 4 Oct 2024 11:12:45 +0200 Subject: [PATCH] Use function-scoped new_dims to handle type hint varying throughout function We don't want to allow the user to pass a `dims=[None, None]` to our function, but the current behaviour sets `dims=[None] * N` at the end of `determine_coords`. To handle this, I created a `new_dims` with a larger type scope which matches the return type of `dims` in `determine_coords`. Then I did the same within def Data to support this new type hint. --- pymc/data.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index a0d6893cb1..e136868b99 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -218,9 +218,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: def determine_coords( model, value: pd.DataFrame | pd.Series | xr.DataArray, - dims: Sequence[str | None] | None = None, + dims: Sequence[str] | None = None, coords: dict[str, Sequence | np.ndarray] | None = None, -) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None]]: +) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]: """Determines coordinate values from data or the model (via ``dims``).""" if coords is None: coords = {} @@ -265,9 +265,10 @@ def determine_coords( if dims is None: # TODO: Also determine dim names from the index - dims = [None] * np.ndim(value) - - return coords, dims + new_dims: Sequence[str] | Sequence[None] = [None] * np.ndim(value) + else: + new_dims = dims + return coords, new_dims def ConstantData( @@ -363,7 +364,7 @@ def Data( The name for this variable. value : array_like or pandas.Series, pandas.Dataframe A value to associate with this variable. - dims : str or tuple of str, optional + dims : str, tuple of str or tuple of None, optional Dimension names of the random variables (as opposed to the shapes of these random variables). 
Use this when ``value`` is a pandas Series or DataFrame. The ``dims`` will then be the name of the Series / DataFrame's columns. See ArviZ @@ -448,13 +449,16 @@ def Data( expected=x.ndim, ) + new_dims: Sequence[str] | Sequence[None] | None if infer_dims_and_coords: - coords, dims = determine_coords(model, value, dims) + coords, new_dims = determine_coords(model, value, dims) + else: + new_dims = dims - if dims: + if new_dims: xshape = x.shape # Register new dimension lengths - for d, dname in enumerate(dims): + for d, dname in enumerate(new_dims): if dname not in model.dim_lengths: model.add_coord( name=dname, @@ -464,6 +468,6 @@ def Data( length=xshape[d], ) - model.register_data_var(x, dims=dims) + model.register_data_var(x, dims=new_dims) return x