diff --git a/deepy/layers/recurrent.py b/deepy/layers/recurrent.py index afb39a2..b83487c 100644 --- a/deepy/layers/recurrent.py +++ b/deepy/layers/recurrent.py @@ -92,17 +92,18 @@ def compute_step(self, state, lstm_cell=None, input=None, additional_inputs=None return outputs @neural_computation - def get_initial_states(self, input_var): + def get_initial_states(self, input_var, init_state=None): """ :type input_var: T.var :rtype: dict """ initial_states = {} for state in self.state_names: - if self._input_type == 'sequence' and input_var.ndim == 2: - init_state = T.alloc(np.cast[FLOATX](0.), self.hidden_size) - else: - init_state = T.alloc(np.cast[FLOATX](0.), input_var.shape[0], self.hidden_size) + if state != "state" or not init_state: + if self._input_type == 'sequence' and input_var.ndim == 2: + init_state = T.alloc(np.cast[FLOATX](0.), self.hidden_size) + else: + init_state = T.alloc(np.cast[FLOATX](0.), input_var.shape[0], self.hidden_size) initial_states[state] = init_state return initial_states