diff --git a/source/generator_collection.py b/source/generator_collection.py index fa854d1..f593942 100644 --- a/source/generator_collection.py +++ b/source/generator_collection.py @@ -96,7 +96,7 @@ def __init__(self, json_path): else: self.randomize = 1 - self.raw_data = io.imread((self.raw_data_file)) + self.data = io.imread((self.raw_data_file)) self.preprocess() local_data = self.data.flatten() self.local_mean = np.mean(local_data) @@ -116,7 +116,7 @@ def preprocess(self): self.num_frames, self.img_rows, self.img_cols = self.data.shape def detrend(self, order=2): - trace = np.mean(self.raw_data, axis=(1, 2)) + trace = np.mean(self.data, axis=(1, 2)) X = np.arange(1, trace.shape[0] + 1) X = X.reshape(X.shape[0], 1) pf = PolynomialFeatures(order) @@ -124,7 +124,8 @@ def detrend(self, order=2): md = LinearRegression() md.fit(Xp, trace) self.trend = md.predict(Xp) - self.data = self.raw_data - np.reshape(self.trend, (self.trend.shape[0], 1, 1)) + self.data = (self.data - np.reshape(self.trend, (self.trend.shape[0], 1, 1))).astype('float32') + def __len__(self): "Denotes the total number of batches" @@ -293,7 +294,7 @@ def __getitem__(self, index): # Generate indexes of the batch indexes = np.arange(index * self.batch_size, (index + 1) * self.batch_size) - + shuffle_indexes = self.list_samples[indexes] input_full = np.zeros(