[FIX] Fix Tweedie loss (#1164)
elephaint authored Oct 10, 2024
1 parent 137e3a8 commit 69706e4
Showing 2 changed files with 18 additions and 24 deletions.
23 changes: 10 additions & 13 deletions nbs/losses.pytorch.ipynb
@@ -1586,8 +1586,8 @@
     "        alpha = alpha.expand(shape)\n",
     "        beta = beta.expand(shape)\n",
     "\n",
-    "        N = torch.poisson(rate)\n",
-    "        gamma = torch.distributions.gamma.Gamma(N*alpha, beta)\n",
+    "        N = torch.poisson(rate) + 1e-5\n",
+    "        gamma = torch.distributions.gamma.Gamma(N * alpha, beta)\n",
     "        samples = gamma.sample()\n",
     "        samples[N==0] = 0\n",
     "\n",
@@ -1603,29 +1603,26 @@
     "        return a - b\n",
     "\n",
     "def tweedie_domain_map(input: torch.Tensor):\n",
-    "    \"\"\" Tweedie Domain Map\n",
-    "    Maps input into distribution constraints, by construction input's \n",
-    "    last dimension is of matching `distr_args` length.\n",
-    "\n",
-    "    **Parameters:**<br>\n",
-    "    `input`: tensor, of dimensions [B,T,H,theta] or [B,H,theta].<br>\n",
+    "    \"\"\"\n",
+    "    Maps output of neural network to domain of distribution loss\n",
     "\n",
     "    **Returns:**<br>\n",
     "    `(log_mu,)`: tuple with tensors of Tweedie distribution arguments.<br>\n",
     "    \"\"\"\n",
-    "    # log_mu, probs = torch.tensor_split(input, 2, dim=-1)\n",
     "    return (input.squeeze(-1),)\n",
     "\n",
     "def tweedie_scale_decouple(output, loc=None, scale=None):\n",
-    "    \"\"\" Tweedie Scale Decouple\n",
+    "    \"\"\"Tweedie Scale Decouple\n",
+    "\n",
     "    Stabilizes model's output optimization, by learning total\n",
     "    count and logits based on anchoring `loc`, `scale`.\n",
     "    Also adds Tweedie domain protection to the distribution parameters.\n",
     "    \"\"\"\n",
     "    log_mu = output[0]\n",
-    "    log_mu = F.softplus(log_mu)\n",
-    "    log_mu = torch.clamp(log_mu, 1e-9, 37)\n",
     "    if (loc is not None) and (scale is not None):\n",
-    "        log_mu += torch.log(loc) # TODO : rho scaling\n",
+    "        log_mu += torch.log(loc)\n",
+    "\n",
+    "    log_mu = torch.clamp(log_mu, 1e-9, 37)\n",
     "    return (log_mu,)"
    ]
   },
19 changes: 8 additions & 11 deletions neuralforecast/losses/pytorch.py
@@ -1000,7 +1000,7 @@ def sample(self, sample_shape=torch.Size()):
         alpha = alpha.expand(shape)
         beta = beta.expand(shape)
 
-        N = torch.poisson(rate)
+        N = torch.poisson(rate) + 1e-5
         gamma = torch.distributions.gamma.Gamma(N * alpha, beta)
         samples = gamma.sample()
         samples[N == 0] = 0
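
Why the epsilon matters in this hunk: `torch.distributions.gamma.Gamma` requires a strictly positive `concentration`, and wherever the Poisson draw `N` is zero the old `N * alpha` was exactly zero, which fails argument validation (or yields NaN samples with validation off). Below is a minimal standalone sketch of the compound Poisson-Gamma sampling being patched, not the library's code; it applies the epsilon only inside the Gamma construction so the integer draw still drives the exact-zero mask.

```python
import torch

def sample_compound_poisson_gamma(rate, alpha, beta, eps=1e-5):
    # Poisson count of Gamma events per entry (a float tensor, may be zero).
    N = torch.poisson(rate)
    # Guard: Gamma validates concentration > 0, and N * alpha is exactly 0
    # wherever N == 0; the epsilon keeps the concentration strictly positive.
    gamma = torch.distributions.gamma.Gamma((N + eps) * alpha, beta)
    samples = gamma.sample()
    # A zero Poisson count corresponds to the Tweedie point mass at zero.
    samples[N == 0] = 0.0
    return samples

# Toy parameters; shapes and values chosen only for illustration.
rate = torch.full((5,), 2.0)
alpha = torch.full((5,), 1.5)
beta = torch.full((5,), 0.5)
print(sample_compound_poisson_gamma(rate, alpha, beta))
```

The committed version instead adds the epsilon to `N` itself, so entries whose count was zero draw from a Gamma with concentration `1e-5 * alpha`; draws from a shape parameter that small are numerically indistinguishable from zero, which preserves the zero mass in practice.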
@@ -1018,17 +1018,10 @@ def log_prob(self, y_true):
 
 
 def tweedie_domain_map(input: torch.Tensor):
-    """Tweedie Domain Map
-    Maps input into distribution constraints, by construction input's
-    last dimension is of matching `distr_args` length.
-
-    **Parameters:**<br>
-    `input`: tensor, of dimensions [B,T,H,theta] or [B,H,theta].<br>
+    """
+    Maps output of neural network to domain of distribution loss
 
     **Returns:**<br>
     `(log_mu,)`: tuple with tensors of Tweedie distribution arguments.<br>
     """
-    # log_mu, probs = torch.tensor_split(input, 2, dim=-1)
     return (input.squeeze(-1),)
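
The domain map itself is an identity on the single Tweedie parameter; this hunk only rewrites the docstring to say so (and drops the leftover commented-out `tensor_split`). A small shape check, with assumed dimensions, shows what callers get back:

```python
import torch

def tweedie_domain_map(input: torch.Tensor):
    # log_mu is left unconstrained here; the protection (clamping) happens
    # later in tweedie_scale_decouple. Only the trailing parameter axis,
    # which has length 1 for Tweedie, is squeezed away.
    return (input.squeeze(-1),)

# Assumed shapes: batch of 4 windows, horizon 12, one distribution argument.
net_output = torch.randn(4, 12, 1)
(log_mu,) = tweedie_domain_map(net_output)
assert log_mu.shape == (4, 12)
```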


@@ -1040,8 +1033,12 @@ def tweedie_scale_decouple(output, loc=None, scale=None):
     Also adds Tweedie domain protection to the distribution parameters.
     """
     log_mu = output[0]
-    log_mu = F.softplus(log_mu)
-    log_mu = torch.clamp(log_mu, 1e-9, 37)
     if (loc is not None) and (scale is not None):
-        log_mu += torch.log(loc) # TODO : rho scaling
+        log_mu += torch.log(loc)
+
+    log_mu = torch.clamp(log_mu, 1e-9, 37)
     return (log_mu,)
 
 # %% ../../nbs/losses.pytorch.ipynb 67
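
One reading of this hunk: clamping `log_mu` into `[1e-9, 37]` before adding `log(loc)` let the scaler shift push the parameter back outside the protected range, so the intended upper bound no longer held on the value actually used downstream; moving the clamp after the shift restores it, and dropping the extra `F.softplus` avoids double-transforming the raw output (softplus already forced `log_mu >= 0`, i.e. a mean of at least 1, on top of the clamp). A toy comparison, with values picked only to make the difference visible:

```python
import torch

log_mu = torch.tensor([40.0])  # hypothetical raw network output
loc = torch.tensor([1000.0])   # hypothetical scaler anchor (loc > 0)

# Old order: clamp first, then shift -- the bound of 37 is silently exceeded.
old = torch.clamp(log_mu, 1e-9, 37) + torch.log(loc)
# New order: shift first, then clamp -- the bound holds on the final value.
new = torch.clamp(log_mu + torch.log(loc), 1e-9, 37)

print(old.item())  # ~43.9, outside the protected range
print(new.item())  # 37.0, inside it
```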
