diff --git a/docs/source/notebooks/mmm/mmm_multidimensional_example.ipynb b/docs/source/notebooks/mmm/mmm_multidimensional_example.ipynb index 8f02136d..e710c01f 100644 --- a/docs/source/notebooks/mmm/mmm_multidimensional_example.ipynb +++ b/docs/source/notebooks/mmm/mmm_multidimensional_example.ipynb @@ -44,7 +44,28 @@ "name": "stdout", "output_type": "stream", "text": [ - "Object `MultiDimensionalMMM.VanillaMultiDimensionalMMM` not found.\n" + "\u001b[0;31mInit signature:\u001b[0m\n", + "\u001b[0mVanillaMultiDimensionalMMM\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mdate_column\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mchannel_columns\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtarget_column\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0madstock\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mpymc_marketing\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmmm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomponents\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madstock\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAdstockTransformation\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0msaturation\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mpymc_marketing\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmmm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomponents\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msaturation\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSaturationTransformation\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtime_varying_intercept\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtime_varying_media\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mdims\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtuple\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmodel_config\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdict\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0msampler_config\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdict\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mvalidate_data\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mcontrol_columns\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0myearly_seasonality\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m|\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0madstock_first\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m Docstring example.\n", + "\u001b[0;31mInit docstring:\u001b[0m Define the constructor method.\n", + "\u001b[0;31mFile:\u001b[0m ~/Documents/GitHub/pymc-marketing/pymc_marketing/mmm/MultiDimensionalMMM.py\n", + "\u001b[0;31mType:\u001b[0m type\n", + "\u001b[0;31mSubclasses:\u001b[0m " ] } ], @@ -74,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -90,7 +111,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -125,57 +146,57 @@ " \n", " 0\n", " 2023-01-01\n", - " USA\n", - " 321\n", - " 373\n", - " 1056\n", + " Venezuela\n", + " 293\n", + " 238\n", + " 849\n", " \n", " \n", " 1\n", " 2023-01-02\n", - " USA\n", - " 256\n", + " Venezuela\n", " 374\n", - " 774\n", + " 241\n", + " 880\n", " \n", " \n", " 2\n", " 2023-01-03\n", - " USA\n", - " 243\n", - " 385\n", - " 868\n", + " Venezuela\n", + " 479\n", + " 380\n", + " 1348\n", " \n", " \n", " 3\n", " 2023-01-04\n", - " USA\n", - " 450\n", - " 234\n", - " 830\n", + " Venezuela\n", + " 346\n", + " 315\n", + " 1084\n", " \n", " \n", " 4\n", " 2023-01-05\n", - " USA\n", - " 354\n", - " 339\n", - " 950\n", + " Venezuela\n", + " 411\n", + " 393\n", + " 1220\n", " \n", " \n", "\n", "" ], "text/plain": [ - " date country a b target\n", - "0 2023-01-01 USA 321 373 1056\n", - "1 2023-01-02 USA 256 374 774\n", - "2 2023-01-03 USA 243 385 868\n", - "3 2023-01-04 USA 450 234 830\n", - "4 2023-01-05 USA 354 339 950" + " date country a b target\n", + "0 2023-01-01 Venezuela 293 238 849\n", + "1 2023-01-02 Venezuela 374 241 880\n", + "2 2023-01-03 Venezuela 479 380 1348\n", + "3 2023-01-04 Venezuela 346 315 1084\n", + "4 2023-01-05 Venezuela 411 393 1220" ] }, - "execution_count": 56, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -202,7 +223,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -226,7 +247,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "03f86f4cf36c4bcfbf7c919427aff42c", + "model_id": "91da39d084814abc8345b437c66d1c9d", "version_major": 2, "version_minor": 0 }, @@ -278,8 +299,8 @@ "
\n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • created_at :
    2024-09-05T14:26:26.203327+00:00
    arviz_version :
    0.19.0
    inference_library :
    pymc
    inference_library_version :
    5.15.1
    sampling_time :
    3.0827558040618896
    tuning_steps :
    1000

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • country
    PandasIndex
    PandasIndex(Index(['Argentina', 'Chile', 'Colombia', 'Venezuela'], dtype='object', name='country'))
  • created_at :
    2024-09-05T14:26:26.205763+00:00
    arviz_version :
    0.19.0
    inference_library :
    pymc
    inference_library_version :
    5.15.1

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • country
    PandasIndex
    PandasIndex(Index(['Argentina', 'Chile', 'Colombia', 'Venezuela'], dtype='object', name='country'))
  • channel
    PandasIndex
    PandasIndex(Index(['a', 'b'], dtype='object', name='channel'))
  • created_at :
    2024-09-05T14:26:26.206696+00:00
    arviz_version :
    0.19.0
    inference_library :
    pymc
    inference_library_version :
    5.15.1

  • \n", " \n", " \n", " \n", @@ -2583,18 +2614,18 @@ "\t> constant_data" ] }, - "execution_count": 57, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "mmm.fit(X=df.drop(columns=\"target\"), y=df.drop(columnas=channels))" + "mmm.fit(X=df.drop(columns=\"target\"), y=df.drop(columns=channels))" ] }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -2606,10 +2637,10 @@ "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", "clusterdate (10) x country (4) x channel (2)\n", "\n", @@ -2627,8 +2658,8 @@ "\n", "\n", "clusterchannel (2)\n", - "\n", - "channel (2)\n", + "\n", + "channel (2)\n", "\n", "\n", "\n", @@ -2647,13 +2678,13 @@ "Deterministic\n", "\n", "\n", - "\n", + "\n", "channel_data->channel_contributions\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "mu\n", "\n", "mu\n", @@ -2661,19 +2692,11 @@ "Deterministic\n", "\n", "\n", - "\n", + "\n", "channel_contributions->mu\n", "\n", "\n", "\n", - "\n", - "\n", - "target\n", - "\n", - "target\n", - "~\n", - "Data\n", - "\n", "\n", "\n", "y\n", @@ -2683,11 +2706,19 @@ "Normal\n", "\n", "\n", - "\n", + "\n", "mu->y\n", "\n", "\n", "\n", + "\n", + "\n", + "target\n", + "\n", + "target\n", + "~\n", + "Data\n", + "\n", "\n", "\n", "y->target\n", @@ -2703,51 +2734,51 @@ "Normal\n", "\n", "\n", - "\n", + "\n", "intercept->mu\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "saturation_beta\n", - "\n", - "saturation_beta\n", - "~\n", - "HalfNormal\n", + "adstock_alpha\n", + "\n", + "adstock_alpha\n", + "~\n", + "Beta\n", "\n", - "\n", - "\n", - "saturation_beta->channel_contributions\n", - "\n", - "\n", + "\n", + "\n", + "adstock_alpha->channel_contributions\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "adstock_alpha\n", - "\n", - "adstock_alpha\n", - "~\n", - "Beta\n", + "saturation_beta\n", + "\n", + "saturation_beta\n", + "~\n", + "HalfNormal\n", "\n", - "\n", + "\n", "\n", - "adstock_alpha->channel_contributions\n", - "\n", - "\n", + "saturation_beta->channel_contributions\n", + "\n", + "\n", "\n", "\n", "\n", "saturation_lam\n", - "\n", - "saturation_lam\n", - "~\n", - "Gamma\n", + "\n", + "saturation_lam\n", + "~\n", + "Gamma\n", "\n", "\n", "\n", "saturation_lam->channel_contributions\n", - "\n", + "\n", "\n", "\n", "\n", @@ -2759,7 +2790,7 @@ "HalfNormal\n", "\n", "\n", - "\n", + "\n", "y_sigma->y\n", "\n", "\n", @@ -2768,10 +2799,10 @@ "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 58, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -2782,7 +2813,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -2800,7 +2831,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -2824,7 +2855,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "46ce03cb22ce4498a5a07e14cc19549e", + "model_id": "37cc1e5641e64bdbaa0f6c7d8408c9a9", "version_major": 2, "version_minor": 0 }, @@ -2876,8 +2907,8 @@ "
    \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • created_at :
    2024-09-05T14:26:35.964194+00:00
    arviz_version :
    0.19.0
    inference_library :
    pymc
    inference_library_version :
    5.15.1
    sampling_time :
    2.6974968910217285
    tuning_steps :
    1000

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • created_at :
    2024-09-05T14:26:35.966608+00:00
    arviz_version :
    0.19.0
    inference_library :
    pymc
    inference_library_version :
    5.15.1

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • created_at :
    2024-09-05T14:26:35.967389+00:00
    arviz_version :
    0.19.0
    inference_library :
    pymc
    inference_library_version :
    5.15.1

  • \n", " \n", " \n", " \n", @@ -5107,7 +5138,7 @@ "\t> constant_data" ] }, - "execution_count": 61, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -5118,7 +5149,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -5130,10 +5161,10 @@ "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", "clusterdate (10) x channel (2)\n", "\n", @@ -5146,8 +5177,8 @@ "\n", "\n", "clusterchannel (2)\n", - "\n", - "channel (2)\n", + "\n", + "channel (2)\n", "\n", "\n", "\n", @@ -5166,13 +5197,13 @@ "Deterministic\n", "\n", "\n", - "\n", + "\n", "channel_data->channel_contributions\n", "\n", "\n", "\n", "\n", - "\n", + "\n", "mu\n", "\n", "mu\n", @@ -5180,19 +5211,11 @@ "Deterministic\n", "\n", "\n", - "\n", + "\n", "channel_contributions->mu\n", "\n", "\n", "\n", - "\n", - "\n", - "target\n", - "\n", - "target\n", - "~\n", - "Data\n", - "\n", "\n", "\n", "y\n", @@ -5202,11 +5225,19 @@ "Normal\n", "\n", "\n", - "\n", + "\n", "mu->y\n", "\n", "\n", "\n", + "\n", + "\n", + "target\n", + "\n", + "target\n", + "~\n", + "Data\n", + "\n", "\n", "\n", "y->target\n", @@ -5222,7 +5253,7 @@ "HalfNormal\n", "\n", "\n", - "\n", + "\n", "y_sigma->y\n", "\n", "\n", @@ -5236,61 +5267,61 @@ "Normal\n", "\n", "\n", - "\n", + "\n", "intercept->mu\n", "\n", "\n", "\n", - "\n", + "\n", "\n", - "saturation_beta\n", - "\n", - "saturation_beta\n", - "~\n", - "HalfNormal\n", + "adstock_alpha\n", + "\n", + "adstock_alpha\n", + "~\n", + "Beta\n", "\n", - "\n", - "\n", - "saturation_beta->channel_contributions\n", - "\n", - "\n", + "\n", + "\n", + "adstock_alpha->channel_contributions\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "adstock_alpha\n", - "\n", - "adstock_alpha\n", - "~\n", - "Beta\n", + "saturation_beta\n", + "\n", + "saturation_beta\n", + "~\n", + "HalfNormal\n", "\n", - "\n", + "\n", "\n", - "adstock_alpha->channel_contributions\n", - "\n", - "\n", + "saturation_beta->channel_contributions\n", + "\n", + "\n", "\n", "\n", "\n", "saturation_lam\n", - "\n", - "saturation_lam\n", - "~\n", - "Gamma\n", + "\n", + "saturation_lam\n", + "~\n", + "Gamma\n", "\n", "\n", "\n", "saturation_lam->channel_contributions\n", - "\n", + "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 62, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" }