diff --git a/docs/source/notebooks/mmm/mmm_multidimensional_example.ipynb b/docs/source/notebooks/mmm/mmm_multidimensional_example.ipynb index 8f02136d..e710c01f 100644 --- a/docs/source/notebooks/mmm/mmm_multidimensional_example.ipynb +++ b/docs/source/notebooks/mmm/mmm_multidimensional_example.ipynb @@ -44,7 +44,28 @@ "name": "stdout", "output_type": "stream", "text": [ - "Object `MultiDimensionalMMM.VanillaMultiDimensionalMMM` not found.\n" + "\u001b[0;31mInit signature:\u001b[0m\n", + "\u001b[0mVanillaMultiDimensionalMMM\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mdate_column\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mchannel_columns\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtarget_column\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0madstock\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mpymc_marketing\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmmm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomponents\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madstock\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAdstockTransformation\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0msaturation\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mpymc_marketing\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmmm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomponents\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msaturation\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSaturationTransformation\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtime_varying_intercept\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtime_varying_media\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mdims\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtuple\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmodel_config\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdict\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0msampler_config\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdict\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mvalidate_data\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mcontrol_columns\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0myearly_seasonality\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0madstock_first\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m Docstring example.\n", + "\u001b[0;31mInit docstring:\u001b[0m Define the constructor method.\n", + "\u001b[0;31mFile:\u001b[0m ~/Documents/GitHub/pymc-marketing/pymc_marketing/mmm/MultiDimensionalMMM.py\n", + "\u001b[0;31mType:\u001b[0m type\n", + "\u001b[0;31mSubclasses:\u001b[0m " ] } ], @@ -74,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -90,7 +111,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -125,57 +146,57 @@ "
PandasIndex(Index(['Argentina', 'Chile', 'Colombia', 'Venezuela'], dtype='object', name='country'))
<xarray.Dataset> Size: 512B\n", + "<xarray.Dataset> Size: 544B\n", "Dimensions: (date: 10, country: 4)\n", "Coordinates:\n", " * date (date) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10\n", - " * country (country) <U7 112B 'Canada' 'France' 'Germany' 'USA'\n", + " * country (country) <U9 144B 'Argentina' 'Chile' 'Colombia' 'Venezuela'\n", "Data variables:\n", - " y (date, country) float64 320B 598.0 883.0 880.0 ... 904.0 1.094e+03\n", + " y (date, country) float64 320B 794.0 608.0 703.0 ... 609.0 536.0\n", "Attributes:\n", - " created_at: 2024-09-04T17:09:09.410542+00:00\n", + " created_at: 2024-09-05T14:26:26.205763+00:00\n", " arviz_version: 0.19.0\n", " inference_library: pymc\n", - " inference_library_version: 5.15.1
PandasIndex(Index(['Argentina', 'Chile', 'Colombia', 'Venezuela'], dtype='object', name='country'))
<xarray.Dataset> Size: 680B\n", + "<xarray.Dataset> Size: 712B\n", "Dimensions: (date: 10, country: 4, channel: 2)\n", "Coordinates:\n", " * date (date) datetime64[ns] 80B 2023-01-01 2023-01-02 ... 2023-01-10\n", - " * country (country) <U7 112B 'Canada' 'France' 'Germany' 'USA'\n", + " * country (country) <U9 144B 'Argentina' 'Chile' 'Colombia' 'Venezuela'\n", " * channel (channel) <U1 8B 'a' 'b'\n", "Data variables:\n", - " channel_data (date, country, channel) int32 320B 295 143 226 ... 404 283\n", - " target (date, country) int32 160B 598 883 880 1056 ... 748 904 1094\n", + " channel_data (date, country, channel) int32 320B 202 246 100 ... 125 263\n", + " target (date, country) int32 160B 794 608 703 849 ... 1134 609 536\n", "Attributes:\n", - " created_at: 2024-09-04T17:09:09.411481+00:00\n", + " created_at: 2024-09-05T14:26:26.206696+00:00\n", " arviz_version: 0.19.0\n", " inference_library: pymc\n", - " inference_library_version: 5.15.1
PandasIndex(Index(['Argentina', 'Chile', 'Colombia', 'Venezuela'], dtype='object', name='country'))
PandasIndex(Index(['a', 'b'], dtype='object', name='channel'))