Skip to content

Commit

Permalink
Enhance recurrent
Browse files Browse the repository at this point in the history
  • Loading branch information
zomux committed Nov 17, 2016
1 parent bf2b0cb commit 3943682
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions deepy/layers/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,18 @@ def compute_step(self, state, lstm_cell=None, input=None, additional_inputs=None
return outputs

@neural_computation
def get_initial_states(self, input_var):
def get_initial_states(self, input_var, init_state=None):
"""
:type input_var: T.var
:rtype: dict
"""
initial_states = {}
for state in self.state_names:
if self._input_type == 'sequence' and input_var.ndim == 2:
init_state = T.alloc(np.cast[FLOATX](0.), self.hidden_size)
else:
init_state = T.alloc(np.cast[FLOATX](0.), input_var.shape[0], self.hidden_size)
if state != "state" or not init_state:
if self._input_type == 'sequence' and input_var.ndim == 2:
init_state = T.alloc(np.cast[FLOATX](0.), self.hidden_size)
else:
init_state = T.alloc(np.cast[FLOATX](0.), input_var.shape[0], self.hidden_size)
initial_states[state] = init_state
return initial_states

Expand Down

0 comments on commit 3943682

Please sign in to comment.