Skip to content

Commit

Permalink
fix(scalers): remove cast to float (#1115)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Aug 23, 2024
1 parent 6aed50e commit dbf02a4
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions nbs/common.scalers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
" **Returns:**<br>\n",
" `x_median`: torch.Tensor with normalized values.\n",
" \"\"\"\n",
" x_nan = x.float().masked_fill(mask<1, float(\"nan\"))\n",
" x_nan = x.masked_fill(mask<1, float(\"nan\"))\n",
" x_median, _ = x_nan.nanmedian(dim=dim, keepdim=keepdim)\n",
" x_median = torch.nan_to_num(x_median, nan=0.0)\n",
" return x_median\n",
Expand All @@ -150,7 +150,7 @@
" **Returns:**<br>\n",
" `x_mean`: torch.Tensor with normalized values.\n",
" \"\"\"\n",
" x_nan = x.float().masked_fill(mask<1, float(\"nan\"))\n",
" x_nan = x.masked_fill(mask<1, float(\"nan\"))\n",
" x_mean = x_nan.nanmean(dim=dim, keepdim=keepdim)\n",
" x_mean = torch.nan_to_num(x_mean, nan=0.0)\n",
" return x_mean"
Expand Down
4 changes: 2 additions & 2 deletions neuralforecast/common/_scalers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def masked_median(x, mask, dim=-1, keepdim=True):
**Returns:**<br>
`x_median`: torch.Tensor with normalized values.
"""
x_nan = x.float().masked_fill(mask < 1, float("nan"))
x_nan = x.masked_fill(mask < 1, float("nan"))
x_median, _ = x_nan.nanmedian(dim=dim, keepdim=keepdim)
x_median = torch.nan_to_num(x_median, nan=0.0)
return x_median
Expand All @@ -49,7 +49,7 @@ def masked_mean(x, mask, dim=-1, keepdim=True):
**Returns:**<br>
`x_mean`: torch.Tensor with normalized values.
"""
x_nan = x.float().masked_fill(mask < 1, float("nan"))
x_nan = x.masked_fill(mask < 1, float("nan"))
x_mean = x_nan.nanmean(dim=dim, keepdim=keepdim)
x_mean = torch.nan_to_num(x_mean, nan=0.0)
return x_mean
Expand Down

0 comments on commit dbf02a4

Please sign in to comment.