diff --git a/nbs/losses.pytorch.ipynb b/nbs/losses.pytorch.ipynb
index efcff01a..d8d333dd 100644
--- a/nbs/losses.pytorch.ipynb
+++ b/nbs/losses.pytorch.ipynb
@@ -1586,8 +1586,8 @@
     "        alpha = alpha.expand(shape)\n",
     "        beta = beta.expand(shape)\n",
     "\n",
-    "        N = torch.poisson(rate)\n",
-    "        gamma = torch.distributions.gamma.Gamma(N*alpha, beta)\n",
+    "        N = torch.poisson(rate) + 1e-5\n",
+    "        gamma = torch.distributions.gamma.Gamma(N * alpha, beta)\n",
     "        samples = gamma.sample()\n",
     "        samples[N==0] = 0\n",
     "\n",
@@ -1603,29 +1603,26 @@
     "        return a - b\n",
     "\n",
     "def tweedie_domain_map(input: torch.Tensor):\n",
-    "    \"\"\" Tweedie Domain Map\n",
-    "    Maps input into distribution constraints, by construction input's \n",
-    "    last dimension is of matching `distr_args` length.\n",
-    "\n",
-    "    **Parameters:**<br>\n",
-    "    `input`: tensor, of dimensions [B,T,H,theta] or [B,H,theta].<br>\n",
+    "    \"\"\"\n",
+    "    Maps output of neural network to domain of distribution loss\n",
     "\n",
-    "    **Returns:**<br>\n",
-    "    `(log_mu,)`: tuple with tensors of Tweedie distribution arguments.<br>\n",
     "    \"\"\"\n",
-    "    # log_mu, probs = torch.tensor_split(input, 2, dim=-1)\n",
     "    return (input.squeeze(-1),)\n",
     "\n",
     "def tweedie_scale_decouple(output, loc=None, scale=None):\n",
-    "    \"\"\" Tweedie Scale Decouple\n",
+    "    \"\"\"Tweedie Scale Decouple\n",
     "\n",
     "    Stabilizes model's output optimization, by learning total\n",
     "    count and logits based on anchoring `loc`, `scale`.\n",
     "    Also adds Tweedie domain protection to the distribution parameters.\n",
     "    \"\"\"\n",
     "    log_mu = output[0]\n",
+    "    log_mu = F.softplus(log_mu)\n",
+    "    log_mu = torch.clamp(log_mu, 1e-9, 37)\n",
     "    if (loc is not None) and (scale is not None):\n",
-    "        log_mu += torch.log(loc) # TODO : rho scaling\n",
+    "        log_mu += torch.log(loc)\n",
+    "\n",
+    "    log_mu = torch.clamp(log_mu, 1e-9, 37)\n",
     "    return (log_mu,)"
    ]
   },
diff --git a/neuralforecast/losses/pytorch.py b/neuralforecast/losses/pytorch.py
index a65b1c53..a713b5b3 100644
--- a/neuralforecast/losses/pytorch.py
+++ b/neuralforecast/losses/pytorch.py
@@ -1000,7 +1000,7 @@ def sample(self, sample_shape=torch.Size()):
         alpha = alpha.expand(shape)
         beta = beta.expand(shape)
 
-        N = torch.poisson(rate)
+        N = torch.poisson(rate) + 1e-5
         gamma = torch.distributions.gamma.Gamma(N * alpha, beta)
         samples = gamma.sample()
         samples[N == 0] = 0
@@ -1018,17 +1018,10 @@ def log_prob(self, y_true):
 
 
 def tweedie_domain_map(input: torch.Tensor):
-    """Tweedie Domain Map
-    Maps input into distribution constraints, by construction input's
-    last dimension is of matching `distr_args` length.
-
-    **Parameters:**<br>
-    `input`: tensor, of dimensions [B,T,H,theta] or [B,H,theta].<br>
+    """
+    Maps output of neural network to domain of distribution loss
 
-    **Returns:**<br>
-    `(log_mu,)`: tuple with tensors of Tweedie distribution arguments.<br>
     """
-    # log_mu, probs = torch.tensor_split(input, 2, dim=-1)
     return (input.squeeze(-1),)
 
 
@@ -1040,8 +1033,12 @@ def tweedie_scale_decouple(output, loc=None, scale=None):
     Also adds Tweedie domain protection to the distribution parameters.
     """
     log_mu = output[0]
+    log_mu = F.softplus(log_mu)
+    log_mu = torch.clamp(log_mu, 1e-9, 37)
     if (loc is not None) and (scale is not None):
-        log_mu += torch.log(loc)  # TODO : rho scaling
+        log_mu += torch.log(loc)
+
+    log_mu = torch.clamp(log_mu, 1e-9, 37)
     return (log_mu,)
 
 # %% ../../nbs/losses.pytorch.ipynb 67