Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@

# Jupyter notebooks
JupyterNB/

.idea/

visualization/
Expand Down
34 changes: 25 additions & 9 deletions gng.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import networkx as nx
import matplotlib.pyplot as plt
from sklearn import decomposition
import copy
import os


__authors__ = 'Adrien Guille'
__email__ = 'adrien.guille@univ-lyon2.fr'
Expand All @@ -18,10 +21,13 @@

class GrowingNeuralGas:

def __init__(self, input_data):
def __init__(self, input_data,output_folder):
self.network = None
self.output_folder = output_folder
self.data = input_data
self.units_created = 0
if not os.path.exists(output_folder):
os.makedirs(output_folder)
plt.style.use('ggplot')

def find_nearest_units(self, observation):
Expand All @@ -35,12 +41,21 @@ def find_nearest_units(self, observation):
return ranking

def prune_connections(self, a_max):
listToRemoveE = []
listToRemoveU = []
for u, v, attributes in self.network.edges(data=True):
if attributes['age'] > a_max:
self.network.remove_edge(u, v)
listToRemoveE.append([u, v])
for u in self.network.nodes():
if self.network.degree(u) == 0:
self.network.remove_node(u)
listToRemoveU.append(u)
try:
self.network.remove_edges_from(listToRemoveE)
self.network.remove_nodes_from(listToRemoveU)
except:
print('Error while removing...')
print('Edges to remove',listToRemoveE)
print('Nodes to remove',listToRemoveU)

def fit_network(self, e_b, e_n, a_max, l, a, d, passes=1, plot_evolution=False):
# logging variables
Expand Down Expand Up @@ -70,7 +85,7 @@ def fit_network(self, e_b, e_n, a_max, l, a, d, passes=1, plot_evolution=False):
s_1 = nearest_units[0]
s_2 = nearest_units[1]
# 3. increment the age of all edges emanating from s_1
for u, v, attributes in self.network.edges_iter(data=True, nbunch=[s_1]):
for u, v, attributes in self.network.edges(data=True, nbunch=[s_1]):
self.network.add_edge(u, v, age=attributes['age']+1)
# 4. add the squared distance between the observation and the nearest unit in input space
self.network.node[s_1]['error'] += spatial.distance.euclidean(observation, self.network.node[s_1]['vector'])**2
Expand All @@ -91,12 +106,13 @@ def fit_network(self, e_b, e_n, a_max, l, a, d, passes=1, plot_evolution=False):
steps += 1
if steps % l == 0:
if plot_evolution:
self.plot_network('visualization/sequence/' + str(sequence) + '.png')
print(str(self.output_folder)+'/' + str(sequence) + '.png')
self.plot_network(str(self.output_folder)+'/' + str(sequence) + '.png')
sequence += 1
# 8.a determine the unit q with the maximum accumulated error
q = 0
error_max = 0
for u in self.network.nodes_iter():
for u in self.network.nodes():
if self.network.node[u]['error'] > error_max:
error_max = self.network.node[u]['error']
q = u
Expand All @@ -123,13 +139,13 @@ def fit_network(self, e_b, e_n, a_max, l, a, d, passes=1, plot_evolution=False):
self.network.node[r]['error'] = self.network.node[q]['error']
# 9. decrease all error variables by multiplying them with a constant d
error = 0
for u in self.network.nodes_iter():
for u in self.network.nodes():
error += self.network.node[u]['error']
accumulated_local_error.append(error)
network_order.append(self.network.order())
network_size.append(self.network.size())
total_units.append(self.units_created)
for u in self.network.nodes_iter():
for u in self.network.nodes():
self.network.node[u]['error'] *= d
if self.network.degree(nbunch=[u]) == 0:
print(u)
Expand All @@ -155,7 +171,7 @@ def plot_network(self, file_path):
plt.clf()
plt.scatter(self.data[:, 0], self.data[:, 1])
node_pos = {}
for u in self.network.nodes_iter():
for u in self.network.nodes():
vector = self.network.node[u]['vector']
node_pos[u] = (vector[0], vector[1])
nx.draw(self.network, pos=node_pos)
Expand Down