diff --git a/models/get_3d.py b/models/get_3d.py index 53a4a7b..42a7938 100644 --- a/models/get_3d.py +++ b/models/get_3d.py @@ -7,7 +7,7 @@ print ("hum") import numpy as np import sys - +import keras def get_data(datafile): #get data for training #print ('Loading Data from .....', datafile) @@ -20,7 +20,11 @@ def get_data(datafile): X = X.astype(np.float32) y = y.astype(np.float32) y = y/100. - ecal = np.squeeze(np.sum(X, axis=(1, 2, 3))) + if keras.backend.image_data_format() !='channels_last': + X =np.moveaxis(X, -1, 1) + ecal = np.squeeze(np.sum(X, axis=(2, 3, 4))) + else: + ecal = np.squeeze(np.sum(X, axis=(1, 2, 3))) print (X.shape) print (y.shape) print (ecal.shape) diff --git a/nnlo/train/GanModel.py b/nnlo/train/GanModel.py index 93389d3..965657f 100644 --- a/nnlo/train/GanModel.py +++ b/nnlo/train/GanModel.py @@ -113,8 +113,15 @@ def _Model(**args): else: return Model(**args) def discriminator(fixed_bn = False, discr_drop_out=0.2): + if keras.backend.image_data_format() =='channels_last': + dshape=(25, 25, 25,1) + daxis=(1,2,3) + else: + dshape=(1, 25, 25, 25) + daxis=(2,3,4) + + image = Input(shape=dshape, name='image') - image = Input(shape=( 25, 25, 25,1 ), name='image') bnm=2 if fixed_bn else 0 f=(5,5,5) @@ -163,19 +170,22 @@ def discriminator(fixed_bn = False, discr_drop_out=0.2): fake = _Dense(1, activation='sigmoid', name='classification')(dnn_out) aux = _Dense(1, activation='linear', name='energy')(dnn_out) - ecal = Lambda(lambda x: K.sum(x, axis=(1, 2, 3)), name='sum_cell')(image) + ecal = Lambda(lambda x: K.sum(x, daxis), name='sum_cell')(image) return _Model(output=[fake, aux, ecal], input=image, name='discriminator_model') def generator(latent_size=200, return_intermediate=False, with_bn=True): - + if keras.backend.image_data_format() =='channels_last': + dim = (7,7,8,8) + else: + dim = (8, 7, 7,8) latent = Input(shape=(latent_size, )) bnm=0 x = _Dense(64 * 7* 7, init='glorot_normal', name='gen_dense1' )(latent) - x = Reshape((7, 7,8, 8))(x) + x = Reshape(dim)(x) x = _Conv3D(64, 6, 6, 8, border_mode='same', init='he_uniform', name='gen_c1' )(x) @@ -211,9 +221,14 @@ def generator(latent_size=200, return_intermediate=False, with_bn=True): return _Model(input=[latent], output=fake_image, name='generator_model') def get_sums(images): - sumsx = np.squeeze(np.sum(images, axis=(2,3))) - sumsy = np.squeeze(np.sum(images, axis=(1,3))) - sumsz = np.squeeze(np.sum(images, axis=(1,2))) + if keras.backend.image_data_format() =='channels_last': + sumsx = np.squeeze(np.sum(images, axis=(2,3))) + sumsy = np.squeeze(np.sum(images, axis=(1,3))) + sumsz = np.squeeze(np.sum(images, axis=(1,2))) + else: + sumsx = np.squeeze(np.sum(images, axis=(3,4))) + sumsy = np.squeeze(np.sum(images, axis=(2,4))) + sumsz = np.squeeze(np.sum(images, axis=(2,3))) return sumsx, sumsy, sumsz def get_moments(images, sumsx, sumsy, sumsz, totalE, m): @@ -535,6 +550,10 @@ def make_opt(**args): loss=['binary_crossentropy', 'mean_absolute_percentage_error', 'mean_absolute_percentage_error'], loss_weights=self.discr_loss_weights ) + if kv2: + self.discriminator.trainable = True #workaround for keras 2 bug + + self.combined.metrics_names = self.discriminator.metrics_names #print ("disc metrics",self.discriminator.metrics_names) #print ("comb metrics",self.combined.metrics_names)