diff --git a/pymc/data.py b/pymc/data.py index a0d6893cb1..e136868b99 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -218,9 +218,9 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: def determine_coords( model, value: pd.DataFrame | pd.Series | xr.DataArray, - dims: Sequence[str | None] | None = None, + dims: Sequence[str] | None = None, coords: dict[str, Sequence | np.ndarray] | None = None, -) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str | None]]: +) -> tuple[dict[str, Sequence | np.ndarray], Sequence[str] | Sequence[None]]: """Determines coordinate values from data or the model (via ``dims``).""" if coords is None: coords = {} @@ -265,9 +265,10 @@ def determine_coords( if dims is None: # TODO: Also determine dim names from the index - dims = [None] * np.ndim(value) - - return coords, dims + new_dims: Sequence[str] | Sequence[None] = [None] * np.ndim(value) + else: + new_dims = dims + return coords, new_dims def ConstantData( @@ -363,7 +364,7 @@ def Data( The name for this variable. value : array_like or pandas.Series, pandas.Dataframe A value to associate with this variable. - dims : str or tuple of str, optional + dims : str, tuple of str or tuple of None, optional Dimension names of the random variables (as opposed to the shapes of these random variables). Use this when ``value`` is a pandas Series or DataFrame. The ``dims`` will then be the name of the Series / DataFrame's columns. See ArviZ @@ -448,13 +449,16 @@ def Data( expected=x.ndim, ) + new_dims: Sequence[str] | Sequence[None] | None if infer_dims_and_coords: - coords, dims = determine_coords(model, value, dims) + coords, new_dims = determine_coords(model, value, dims) + else: + new_dims = dims - if dims: + if new_dims: xshape = x.shape # Register new dimension lengths - for d, dname in enumerate(dims): + for d, dname in enumerate(new_dims): if dname not in model.dim_lengths: model.add_coord( name=dname, @@ -464,6 +468,6 @@ def Data( length=xshape[d], ) - model.register_data_var(x, dims=dims) + model.register_data_var(x, dims=new_dims) return x