-
Notifications
You must be signed in to change notification settings - Fork 0
/
ctrnn.py
97 lines (77 loc) · 3.41 KB
/
ctrnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from __future__ import division
import math
class CTRNNNode:
    '''A single CTRNN node; input and bias nodes reuse it without stepping it.

    Each node holds parallel lists of parent nodes and connection weights.
    Updating is two-phase: timestep() computes next_output from the parents'
    current output values, and update_output() publishes it, so every node in
    a layer reads outputs from the same timestep.
    '''

    def __init__(self, gain=0, tau=0):
        '''Create a node.

        gain -- multiplier applied to the internal state y before the sigmoid.
        tau  -- time constant; timestep() divides by it, so it must be non-zero
                for any node that is stepped. The default of 0 is only usable for
                nodes whose output is assigned directly (inputs, bias).
        '''
        self.parents = []       # upstream nodes, parallel to self.weights
        self.weights = []       # connection weight for each parent
        self.gain = gain
        self.tau = tau
        self.output = 0         # value visible to child nodes
        self.next_output = 0    # value computed by timestep(), not yet published
        self.y = 0              # internal state integrated by timestep()

    def add_parent(self, parent, weight):
        '''Connect parent as an input to this node with the given weight.'''
        self.parents.append(parent)
        self.weights.append(weight)

    def reset(self):
        '''Reset all internal state (output, next_output and y) to 0.'''
        self.output = 0
        self.next_output = 0
        self.y = 0

    def timestep(self):
        '''Compute the next output level and update the internal state.

        Performs one forward-Euler step (step size 1) of
            tau * dy/dt = -y + sum(weight_i * parent_i.output)
        and stores 1 / (1 + exp(gain * y)) in next_output. The new value is not
        visible to other nodes until update_output() is called.

        Should not be used for input nodes: with tau == 0 this raises
        ZeroDivisionError.
        '''
        s = sum(w * p.output for w, p in zip(self.weights, self.parents))
        self.y += (s - self.y) / self.tau
        # NOTE(review): exp(+gain*y) makes this the logistic of -gain*y, the
        # opposite sign to the usual sigma(x) = 1 / (1 + exp(-x)). It is
        # equivalent to negating the gain, so it is kept as-is to preserve the
        # behaviour of existing parameter sets. Confirm this is intentional.
        try:
            self.next_output = 1 / (1 + math.exp(self.y * self.gain))
        except OverflowError:
            # exp() overflows for gain*y above ~709. The true value is then
            # below the smallest normal float, so 0 is the correct limit
            # (previously this crashed the simulation).
            self.next_output = 0.0

    def update_output(self):
        '''Publish next_output as output.

        Called after all nodes in a layer have run timestep(), so they all
        see outputs from the same timestep.
        '''
        self.output = self.next_output
class CTRNN:
    '''Implements a single instance of a Beer-type Continuous-Time Recurrent Neural Network'''

    def __init__(self, num_input, num_hidden, num_output,
                 weight_list, bias_list, gain_list, tau_list):
        '''CTRNN constructor.

        num_input, num_hidden, num_output -- layer sizes.
        weight_list -- connection weights, consumed from the END of the list:
            the last element is the weight from input node 0 to hidden node 0.
            Consumption order:
              * for each hidden node: one weight from every input node, then
                one from every hidden node (including itself);
              * then for each output node: one weight from every hidden node,
                then one from every output node (including itself).
            That is num_hidden*(num_input+num_hidden)
            + num_output*(num_hidden+num_output) weights in total.
            Any extra leading entries are ignored.
        bias_list -- one bias weight per hidden node, then one per output node,
            also consumed from the end of the list.
        gain_list, tau_list -- per-node gain and time constant, indexed from the
            front: hidden nodes at 0..num_hidden-1, then the output nodes.

        The caller's weight_list and bias_list are not modified.
        Raises IndexError if any list is too short.
        '''
        # Pop from private copies so the caller's lists are not emptied as a
        # side effect. Popping from the end keeps the same assignment order.
        weights = list(weight_list)
        biases = list(bias_list)

        # Create all the nodes
        self.bias_node = CTRNNNode()
        self.bias_node.output = 1   # constant-1 source; its weight acts as a bias
        self.input_nodes = [CTRNNNode() for _ in range(num_input)]
        self.hidden_nodes = [CTRNNNode(gain_list[i], tau_list[i])
                             for i in range(num_hidden)]
        self.output_nodes = [CTRNNNode(gain_list[num_hidden + i],
                                       tau_list[num_hidden + i])
                             for i in range(num_output)]

        # Create connections between them
        for node in self.hidden_nodes:
            for parent in self.input_nodes + self.hidden_nodes:
                node.add_parent(parent, weights.pop())
            node.add_parent(self.bias_node, biases.pop())
        for node in self.output_nodes:
            for parent in self.hidden_nodes + self.output_nodes:
                node.add_parent(parent, weights.pop())
            node.add_parent(self.bias_node, biases.pop())

    def reset(self):
        '''Reset internal state of every hidden and output node.

        Input and bias nodes are not touched. The bias output must stay 1, and
        input outputs are assigned by timestep() from sensor_input.
        '''
        for node in self.hidden_nodes + self.output_nodes:
            node.reset()

    def timestep(self, sensor_input):
        '''Compute new output levels for all nodes, and return output from output nodes.

        sensor_input -- sequence of values copied, in order, onto the input
            nodes. It must not be longer than num_input (IndexError otherwise).

        Hidden nodes are all stepped and then published together, so they see
        each other's previous-step outputs. Output nodes are then stepped. They
        see the hidden layer's freshly published outputs and each other's
        previous-step outputs.
        '''
        for i, value in enumerate(sensor_input):
            self.input_nodes[i].output = value
        for layer in (self.hidden_nodes, self.output_nodes):
            for node in layer:
                node.timestep()
            for node in layer:
                node.update_output()
        return [node.output for node in self.output_nodes]