Skip to content

Commit

Permalink
Safe exp in parsing (#33640)
Browse files Browse the repository at this point in the history
  • Loading branch information
haraschax authored Sep 24, 2024
1 parent 6dfc154 commit deb6b72
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions selfdrive/modeld/parse_model_outputs.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import numpy as np
from openpilot.selfdrive.modeld.constants import ModelConstants

def sigmoid(x):
def safe_exp(x, out=None):
# -11 is around 10**14, more causes float16 overflow
clipped_x = np.clip(x, -11, np.inf)
return 1. / (1. + np.exp(-clipped_x))
return np.exp(np.clip(x, -np.inf, 11), out=out)

def sigmoid(x):
return 1. / (1. + safe_exp(-x))

def softmax(x, axis=-1):
x -= np.max(x, axis=axis, keepdims=True)
if x.dtype == np.float32 or x.dtype == np.float64:
np.exp(x, out=x)
safe_exp(x, out=x)
else:
x = np.exp(x)
x = safe_exp(x)
x /= np.sum(x, axis=axis, keepdims=True)
return x

Expand Down Expand Up @@ -46,7 +48,7 @@ def parse_mdn(self, name, outs, in_N=0, out_N=1, out_shape=None):

n_values = (raw.shape[2] - out_N)//2
pred_mu = raw[:,:,:n_values]
pred_std = np.exp(raw[:,:,n_values: 2*n_values])
pred_std = safe_exp(raw[:,:,n_values: 2*n_values])

if in_N > 1:
weights = np.zeros((raw.shape[0], in_N, out_N), dtype=raw.dtype)
Expand Down

0 comments on commit deb6b72

Please sign in to comment.