diff --git a/.gitignore b/.gitignore index 6dc9626..f4ddb62 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ + +# Jupyter notebooks +JupyterNB/ + .idea/ visualization/ diff --git a/gng.py b/gng.py index 82cee4e..d546629 100644 --- a/gng.py +++ b/gng.py @@ -5,6 +5,9 @@ import networkx as nx import matplotlib.pyplot as plt from sklearn import decomposition +import copy +import os + __authors__ = 'Adrien Guille' __email__ = 'adrien.guille@univ-lyon2.fr' @@ -18,10 +21,13 @@ class GrowingNeuralGas: - def __init__(self, input_data): + def __init__(self, input_data,output_folder): self.network = None + self.output_folder = output_folder self.data = input_data self.units_created = 0 + if not os.path.exists(output_folder): + os.makedirs(output_folder) plt.style.use('ggplot') def find_nearest_units(self, observation): @@ -35,12 +41,21 @@ def find_nearest_units(self, observation): return ranking def prune_connections(self, a_max): + listToRemoveE = [] + listToRemoveU = [] for u, v, attributes in self.network.edges(data=True): if attributes['age'] > a_max: - self.network.remove_edge(u, v) + listToRemoveE.append([u, v]) for u in self.network.nodes(): if self.network.degree(u) == 0: - self.network.remove_node(u) + listToRemoveU.append(u) + try: + self.network.remove_edges_from(listToRemoveE) + self.network.remove_nodes_from(listToRemoveU) + except: + print('Error while removing...') + print('Edges to remove',listToRemoveE) + print('Nodes to remove',listToRemoveU) def fit_network(self, e_b, e_n, a_max, l, a, d, passes=1, plot_evolution=False): # logging variables @@ -70,7 +85,7 @@ def fit_network(self, e_b, e_n, a_max, l, a, d, passes=1, plot_evolution=False): s_1 = nearest_units[0] s_2 = nearest_units[1] # 3. increment the age of all edges emanating from s_1 - for u, v, attributes in self.network.edges_iter(data=True, nbunch=[s_1]): + for u, v, attributes in self.network.edges(data=True, nbunch=[s_1]): self.network.add_edge(u, v, age=attributes['age']+1) # 4. add the squared distance between the observation and the nearest unit in input space self.network.node[s_1]['error'] += spatial.distance.euclidean(observation, self.network.node[s_1]['vector'])**2 @@ -91,12 +106,13 @@ def fit_network(self, e_b, e_n, a_max, l, a, d, passes=1, plot_evolution=False): steps += 1 if steps % l == 0: if plot_evolution: - self.plot_network('visualization/sequence/' + str(sequence) + '.png') + print(str(self.output_folder)+'/' + str(sequence) + '.png') + self.plot_network(str(self.output_folder)+'/' + str(sequence) + '.png') sequence += 1 # 8.a determine the unit q with the maximum accumulated error q = 0 error_max = 0 - for u in self.network.nodes_iter(): + for u in self.network.nodes(): if self.network.node[u]['error'] > error_max: error_max = self.network.node[u]['error'] q = u @@ -123,13 +139,13 @@ def fit_network(self, e_b, e_n, a_max, l, a, d, passes=1, plot_evolution=False): self.network.node[r]['error'] = self.network.node[q]['error'] # 9. decrease all error variables by multiplying them with a constant d error = 0 - for u in self.network.nodes_iter(): + for u in self.network.nodes(): error += self.network.node[u]['error'] accumulated_local_error.append(error) network_order.append(self.network.order()) network_size.append(self.network.size()) total_units.append(self.units_created) - for u in self.network.nodes_iter(): + for u in self.network.nodes(): self.network.node[u]['error'] *= d if self.network.degree(nbunch=[u]) == 0: print(u) @@ -155,7 +171,7 @@ def plot_network(self, file_path): plt.clf() plt.scatter(self.data[:, 0], self.data[:, 1]) node_pos = {} - for u in self.network.nodes_iter(): + for u in self.network.nodes(): vector = self.network.node[u]['vector'] node_pos[u] = (vector[0], vector[1]) nx.draw(self.network, pos=node_pos)