diff --git a/lightweight_mmm/plot.py b/lightweight_mmm/plot.py index cd1c74f..8d9a101 100644 --- a/lightweight_mmm/plot.py +++ b/lightweight_mmm/plot.py @@ -805,6 +805,7 @@ def plot_media_channel_posteriors( channel_names = np.arange(np.shape(media_channel_posteriors)[1]) fig, axes = plt.subplots( nrows=n_media_channels, ncols=n_geos, figsize=fig_size) + media_channel_posteriors = np.asarray(media_channel_posteriors) for channel_i, channel_axis in enumerate(axes): if isinstance(channel_axis, np.ndarray): for geo_i, geo_axis in enumerate(channel_axis):