fixing a bug in the training param (wrong loading due to copy paste error)
BDonnot committed Jun 29, 2020
1 parent 1431433 commit a0d2750
Showing 2 changed files with 27 additions and 17 deletions.
10 changes: 5 additions & 5 deletions l2rpn_baselines/test/test_trainingparam.py
@@ -30,13 +30,13 @@ def test_loadback(self):
 
     def test_loadback_modified(self):
         for el in TrainingParam._int_attr:
-            self._test_attr(el, 1)
-            self._test_attr(el, None)
+            self._aux_test_attr(el, 1)
+            self._aux_test_attr(el, None)
         for el in TrainingParam._float_attr:
-            self._test_attr(el, 1.)
-            self._test_attr(el, None)
+            self._aux_test_attr(el, 1.)
+            self._aux_test_attr(el, None)
 
-    def _test_attr(self, attr, val):
+    def _aux_test_attr(self, attr, val):
         """
         test that i can modify an attribut and then load the training parameters the correct way
         """
34 changes: 22 additions & 12 deletions l2rpn_baselines/utils/TrainingParam.py
@@ -19,6 +19,7 @@ class TrainingParam(object):
     ----------
     buffer_size: ``int``
         Size of the replay buffer
+
     minibatch_size: ``int``
         Size of the training minibatch
     update_freq: ``int``
@@ -149,13 +150,13 @@ def __init__(self,
 
         self.random_sample_datetime_start = random_sample_datetime_start
 
-        self.buffer_size = buffer_size
-        self.minibatch_size = minibatch_size
-        self.min_observation = min_observation
+        self.buffer_size = int(buffer_size)
+        self.minibatch_size = int(minibatch_size)
+        self.min_observation = int(min_observation)
         self._final_epsilon = float(final_epsilon)  # have on average 1 random action per day of approx 288 timesteps at the end (never kill completely the exploration)
         self._initial_epsilon = float(initial_epsilon)
         self.step_for_final_epsilon = float(step_for_final_epsilon)
-        self.lr = lr
+        self.lr = float(lr)
         self.lr_decay_steps = float(lr_decay_steps)
         self.lr_decay_rate = float(lr_decay_rate)
 
@@ -164,14 +165,15 @@ def __init__(self,
         self.max_value_grad = max_value_grad
         self.max_loss = max_loss
 
-        self.last_step = 0
+        self.last_step = int(0)
         self.num_frames = int(num_frames)
         self.discount_factor = float(discount_factor)
         self.tau = float(tau)
-        self.update_freq = update_freq
-        self.min_iter = min_iter
-        self.max_iter = max_iter
-        self._update_nb_iter = update_nb_iter
+        self.update_freq = int(update_freq)
+        self.min_iter = int(min_iter)
+        self.max_iter = int(max_iter)
+        self._1_update_nb_iter = None
+        self._update_nb_iter = int(update_nb_iter)
         if step_increase_nb_iter is None:
             # 0 and None have the same effect: it disable the feature
             step_increase_nb_iter = 0
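One plausible reason for the explicit int() and float() casts above (my reading, not stated in the commit): these parameters are meant to be serialised and loaded back, and values coming from a JSON file or user code can arrive as floats or strings, so casting at construction keeps the attribute types consistent. A minimal, hypothetical sketch of the kind of type drift such casts guard against:

import json

# Hypothetical round trip: a value saved to JSON can come back with a
# different type than the integer the training loop expects.
saved = json.dumps({"minibatch_size": 32.0, "update_freq": "256"})
loaded = json.loads(saved)

minibatch_size = int(loaded["minibatch_size"])   # 32.0 (float) -> 32 (int)
update_freq = int(loaded["update_freq"])         # "256" (str)  -> 256 (int)
print(minibatch_size, update_freq)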
@@ -285,7 +287,7 @@ def from_dict(tmp):
             else:
                 setattr(res, attr_nm, None)
         res.update_nb_iter = res._update_nb_iter
-        res.update_nb_iter = res._initial_epsilon
+        res.initial_epsilon = res._initial_epsilon
         res._compute_exp_facto()
         return res
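
For context, a minimal sketch (not part of the repository) of how the copy-paste error fixed above could surface. It assumes TrainingParam exposes a to_dict() counterpart to the from_dict() shown here and that the class is importable from l2rpn_baselines.utils. Before this commit, the replaced line overwrote update_nb_iter with the stored epsilon value and never restored initial_epsilon, so a saved-then-reloaded object no longer matched the original:

from l2rpn_baselines.utils import TrainingParam  # import path assumed

params = TrainingParam()
params.initial_epsilon = 0.5   # customise the exploration schedule
params.update_nb_iter = 128

# Round trip through the dictionary representation (to_dict assumed).
restored = TrainingParam.from_dict(params.to_dict())

# With the old line, restored.update_nb_iter held the epsilon value and
# restored.initial_epsilon kept its default; after the fix both match.
assert restored == params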

@@ -320,7 +322,15 @@ def __eq__(self, other):
         for el in self._int_attr:
             me_ = getattr(self, el)
             oth_ = getattr(other, el)
-            if me_ != oth_:
+            if me_ is None and oth_ is not None:
+                res = False
+                break
+            if oth_ is None and me_ is not None:
+                res = False
+                break
+            if me_ is None and oth_ is None:
+                continue
+            if int(me_) != int(oth_):
                 res = False
                 break
         if res:
@@ -335,7 +345,7 @@ def __eq__(self, other):
                 break
             if me_ is None and oth_ is None:
                 continue
-            if abs(float(me_) != float(oth_)) > self._tol_float_equal:
+            if abs(float(me_) - float(oth_)) > self._tol_float_equal:
                 res = False
                 break
         return res
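
A standalone sketch of the comparison pattern the corrected __eq__ follows (simplified, outside the class, with an illustrative tolerance value): attributes may be None on either side, and floats are compared within a tolerance. The removed expression abs(float(me_) != float(oth_)) fed a boolean into abs(), which effectively reduced the check to exact equality (for any tolerance below 1) and defeated the tolerance:

def attr_equal(a, b, tol=1e-8):
    # Simplified version of the None-aware check used above:
    # both None -> equal, exactly one None -> different,
    # otherwise compare numerically within the tolerance.
    if a is None and b is None:
        return True
    if a is None or b is None:
        return False
    return abs(float(a) - float(b)) <= tol

print(attr_equal(None, None))   # True
print(attr_equal(None, 0.5))    # False
print(attr_equal(0.1, 0.1))     # True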
