-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
97 lines (75 loc) · 2.89 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from deap.gp import graph
from networkx.drawing.nx_agraph import graphviz_layout
import matplotlib.pyplot as plt
import networkx as nx
from collections import Counter
# Large integer constant (one million); dead_end_ returns it unconditionally.
BIG_INT = 10 ** 6


def dead_end_(x, y):
    """Discard both arguments and return ``BIG_INT``.

    Args:
        x: ignored.
        y: ignored.

    Returns:
        The module-level ``BIG_INT`` value.
    """
    result = BIG_INT
    return result
def same_modi_count(individual):
    """Return how often the most frequent ``modi`` node name occurs.

    Every node of *individual* whose ``name`` contains the substring
    ``'modi'`` is tallied by its exact name.

    Args:
        individual: an iterable of nodes exposing a ``name`` attribute.

    Returns:
        The largest tally among those names, or 0 when no name contains
        ``'modi'``.
    """
    tally = Counter(node.name for node in individual if 'modi' in node.name)
    return max(tally.values(), default=0)
def nodes_count(individual):
    """Return ``(size, gate_count)`` for *individual*.

    Args:
        individual: an iterable of nodes exposing ``arity`` and ``name``;
            ``str(individual)`` is also inspected.

    Returns:
        A tuple ``(size, gate_count)`` where

        * ``size`` is the number of nodes with non-zero arity, and
        * ``gate_count`` is the number of those nodes whose name contains
          neither ``'modi'`` nor ``'end'``, plus the number of ``'+'``
          characters in ``str(individual)`` (presumably OR operations are
          rendered as ``+`` -- TODO confirm against the primitive set).
    """
    non_terminals = [node for node in individual if node.arity != 0]
    gates = sum(
        1
        for node in non_terminals
        if not ('modi' in node.name or 'end' in node.name)
    )
    return len(non_terminals), gates + str(individual).count('+')
def _collapse_modi_nodes(individual, nodes, edges, labels):
    """Rewrite a ``deap.gp.graph`` triple so ``modi`` nodes become output nodes.

    Each node whose label contains ``'modi'`` is removed. Its parent is wired
    directly to the modi node's non-modi children, looking through chains of
    nested modi nodes. A new node labelled ``'out<k>'`` is wired to those same
    children, where ``k`` is parsed from the label ``'modi<k>'``. Afterwards,
    nodes whose label contains ``'end'`` are dropped together with their edges,
    and every node left without any edge is removed.

    Args:
        individual: the tree; only ``individual.num_outputs`` is read here.
        nodes: list of node ids from ``graph``; ``nodes[-1] + 1`` is used as
            the first id for the new output nodes.
        edges: list of ``(from, to)`` pairs from ``graph``.
        labels: dict mapping node id to its label from ``graph``.

    Returns:
        New ``(nodes, edges, labels)``; the arguments are not mutated.
    """
    # Terminal labels may be non-strings (e.g. numeric constants), so
    # substring tests are performed on their text form.
    text = {n: str(label) for n, label in labels.items()}
    modis = [n for n in labels if 'modi' in text[n]]
    modi_set = set(modis)

    # Child lists in original edge order, used to look through modi nodes.
    children = {}
    for parent, child in edges:
        children.setdefault(parent, []).append(child)

    def resolve(node):
        # Non-modi nodes feeding *node*, skipping through nested modi nodes.
        found = []
        for child in children.get(node, ()):
            if child in modi_set:
                found.extend(resolve(child))
            else:
                found.append(child)
        return found

    new_edges = []
    for parent, child in edges:
        if parent in modi_set:
            continue  # re-emitted below from the matching output node
        if child in modi_set:
            new_edges.extend((parent, target) for target in resolve(child))
        else:
            new_edges.append((parent, child))

    # One extra node per program output, numbered after the existing nodes.
    first_out = nodes[-1] + 1
    outs = [first_out + i for i in range(individual.num_outputs)]
    new_labels = {n: label for n, label in labels.items() if n not in modi_set}
    for i, out in enumerate(outs):
        new_labels[out] = 'out' + str(i)
    for modi in modis:
        out = outs[int(text[modi].replace('modi', ''))]
        new_edges.extend((out, target) for target in resolve(modi))

    # Drop 'end' nodes and any edge touching them.
    ends = {n for n, label in new_labels.items() if 'end' in str(label)}
    new_edges = [(f, t) for f, t in new_edges if f not in ends and t not in ends]

    # Keep only nodes that still take part in at least one edge.
    connected = {n for edge in new_edges for n in edge}
    new_nodes = [n for n in nodes + outs if n in connected]
    new_labels = {n: label for n, label in new_labels.items() if n in connected}
    return new_nodes, new_edges, new_labels


def plot_modi_tree(individual, visualize_output=False):
    """Draw *individual* as a directed graph and show it with matplotlib.

    The graph comes from ``deap.gp.graph``. It is laid out with graphviz's
    ``dot`` program on a reversed view of the edges, so arrows run from
    arguments to the node consuming them. The plot title shows
    ``str(individual)`` and ``individual.fitness``.

    Args:
        individual: the DEAP tree to draw. It must expose ``num_outputs``
            when *visualize_output* is true, and ``fitness`` in all cases.
        visualize_output: when true, ``modi<k>`` nodes are replaced by
            explicit ``out<k>`` nodes, ``end`` nodes are removed, and nodes
            left isolated are dropped before drawing.
    """
    nodes, edges, labels = graph(individual)
    if visualize_output:
        nodes, edges, labels = _collapse_modi_nodes(individual, nodes, edges, labels)
    g = nx.DiGraph()
    g.add_nodes_from(nodes)
    g.add_edges_from(edges)
    g = nx.reverse_view(g)
    pos = graphviz_layout(g, prog="dot")
    nx.draw_networkx_nodes(g, pos, node_size=700)
    nx.draw_networkx_edges(g, pos, arrowsize=24)
    nx.draw_networkx_labels(g, pos, labels)
    plt.title(str(individual) + '\n' + str(individual.fitness))
    plt.show()