diff --git a/ot_tf.py b/ot_tf.py
new file mode 100644
index 0000000..fa58ae1
--- /dev/null
+++ b/ot_tf.py
@@ -0,0 +1,51 @@
+import tensorflow as tf
+#tf.compat.v1.disable_eager_execution()
+
+def sink(a, b, M, m_size, reg, numItermax=1000, stopThr=1e-9):
+    # we assume that no distances are null except those on the diagonal of the distance matrix
+
+    # a = tf.expand_dims(tf.ones(shape=(m_size[0],)) / m_size[0], axis=1)  # (na, 1)
+    # b = tf.expand_dims(tf.ones(shape=(m_size[1],)) / m_size[1], axis=1)  # (nb, 1)
+
+    # init data
+    Nini = m_size[0]
+    Nfin = m_size[1]
+
+    u = tf.expand_dims(tf.ones(Nini) / Nini, axis=1)  # (na, 1)
+    v = tf.expand_dims(tf.ones(Nfin) / Nfin, axis=1)  # (nb, 1)
+
+    K = tf.exp(-M / reg)  # (na, nb)
+
+    Kp = (1.0 / a) * K  # (na, 1) * (na, nb) = (na, nb)
+
+    cpt = tf.constant(0)
+    err = tf.constant(1.0)
+
+    c = lambda cpt, u, v, err: tf.logical_and(cpt < numItermax, err > stopThr)
+
+    def loop_func(cpt, u, v, err):
+        KtransposeU = tf.matmul(tf.transpose(a=K, perm=(1, 0)), u)  # (nb, na) x (na, 1) = (nb, 1)
+        v = tf.compat.v1.div(b, KtransposeU)  # (nb, 1)
+        u = 1.0 / tf.matmul(Kp, v)  # (na, 1)
+
+        # check the error only every 10th iteration to save time; the branch
+        # functions are defined inside the loop body so they see the updated u, v
+        def err_f1():
+            # compare the transport plan's column marginal to the target b
+            transp = u * (K * tf.squeeze(v))  # (na, 1) * ((na, nb) * (nb,)) = (na, nb)
+            err_ = tf.pow(tf.norm(tensor=tf.reduce_sum(input_tensor=transp, axis=0) - tf.squeeze(b), ord=1), 2)  # ()
+            return err_
+
+        def err_f2():
+            return err
+
+        err = tf.cond(pred=tf.equal(cpt % 10, 0), true_fn=err_f1, false_fn=err_f2)
+
+        cpt = tf.add(cpt, 1)
+        return cpt, u, v, err
+
+    _, u, v, _ = tf.while_loop(cond=c, body=loop_func, loop_vars=[cpt, u, v, err])
+
+    result = tf.reduce_sum(input_tensor=u * K * tf.reshape(v, (1, -1)) * M)
+
+    return result
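
For debugging, the same update can be written in plain NumPy (a sketch following POT's sinkhorn_knopp reference, of which sink() above is a TensorFlow transcription; sink_np is a hypothetical helper name):

    import numpy as np

    def sink_np(a, b, M, reg, numItermax=1000, stopThr=1e-9):
        # a: (na,1) and b: (nb,1) marginals; M: (na,nb) cost matrix
        K = np.exp(-M / reg)
        Kp = (1.0 / a) * K                  # (na,1) * (na,nb) = (na,nb)
        u = np.ones_like(a) / len(a)
        v = np.ones_like(b) / len(b)
        for cpt in range(numItermax):
            v = b / (K.T @ u)               # (nb,1)
            u = 1.0 / (Kp @ v)              # (na,1)
            if cpt % 10 == 0:
                # marginal violation of the current transport plan
                err = np.linalg.norm((u * K * v.T).sum(axis=0) - b.ravel(), 1)**2
                if err < stopThr:
                    break
        return np.sum(u * K * v.T * M)      # plan contracted with the cost

With uniform marginals a = b = np.ones((48,1))/48 and the same M and reg, this should track the value returned by sink().
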
diff --git a/qDenseCNN.py b/qDenseCNN.py
index 26a7580..2aaafa9 100644
--- a/qDenseCNN.py
+++ b/qDenseCNN.py
@@ -10,6 +10,71 @@
 import numpy as np
 import json
 
+# for sinkhorn metric
+import ot_tf
+import ot
+
+hexCoords = np.array([
+    [0.0, 0.0], [0.0, -2.4168015], [0.0, -4.833603], [0.0, -7.2504044],
+    [2.09301, -1.2083969], [2.09301, -3.6251984], [2.09301, -6.042], [2.09301, -8.458794],
+    [4.18602, -2.4168015], [4.18602, -4.833603], [4.18602, -7.2504044], [4.18602, -9.667198],
+    [6.27903, -3.6251984], [6.27903, -6.042], [6.27903, -8.458794], [6.27903, -10.875603],
+    [-8.37204, -10.271393], [-6.27903, -9.063004], [-4.18602, -7.854599], [-2.0930138, -6.6461945],
+    [-8.37204, -7.854599], [-6.27903, -6.6461945], [-4.18602, -5.4377975], [-2.0930138, -4.229393],
+    [-8.37204, -5.4377975], [-6.27903, -4.229393], [-4.18602, -3.020996], [-2.0930138, -1.8125992],
+    [-8.37204, -3.020996], [-6.27903, -1.8125992], [-4.18602, -0.6042023], [-2.0930138, 0.6042023],
+    [4.7092705, -12.386101], [2.6162605, -11.177696], [0.5232506, -9.969299], [-1.5697594, -8.760895],
+    [2.6162605, -13.594498], [0.5232506, -12.386101], [-1.5697594, -11.177696], [-3.6627693, -9.969299],
+    [0.5232506, -14.802895], [-1.5697594, -13.594498], [-3.6627693, -12.386101], [-5.7557793, -11.177696],
+    [-1.5697594, -16.0113], [-3.6627693, -14.802895], [-5.7557793, -13.594498], [-7.848793, -12.386101]])
+hexMetric = tf.constant( ot.dist(hexCoords, hexCoords, 'euclidean'), tf.float32)
+
+def myfunc(a):
+    reg = 0.5
+    y_true, y_pred = tf.split(a, num_or_size_splits=2, axis=1)
+    tf_sinkhorn_loss = ot_tf.sink(y_true, y_pred, hexMetric, (48, 48), reg)
+    return tf_sinkhorn_loss
+
+def sinkhorn_loss(y_true, y_pred):
+    y_true = K.cast(y_true, y_pred.dtype)
+    y_pred = K.reshape(y_pred, (-1,48,1))
+    y_true = K.reshape(y_true, (-1,48,1))
+    cc = tf.concat([y_true, y_pred], axis=2)
+    return K.mean( tf.map_fn(myfunc, cc), axis=(-1) )
+
+    # return K.mean( tf.map_fn(myfunc, y_true), axis=(-1) )
+    # return K.mean( tf.map_fn(myfunc, [y_true, y_pred]), axis=(-1) )
+    # tf_sinkhorn_loss = K.mean( tf.numpy_function(myfunc, [y_true, y_pred], y_pred.dtype) )
+    # return tf_sinkhorn_loss
+    # sy_true = tf.split(y_true, num_or_size_splits=K.shape(y_true)[0], axis=0)
+    # sy_pred = tf.split(y_pred, num_or_size_splits=K.shape(y_pred)[0], axis=0)
+    # losses = [ ot_tf.sink(sy_true[i], sy_pred[i], hexMetric, (48, 48), reg) for i in range(len(sy_true)) ]
+    # return losses[0]
+    # tf_sinkhorn_loss = K.mean( ot_tf.sink(y_true, y_pred, hexMetric, (48, 48), reg), axis=(-1) )
+
+def other_loss(y_true, y_pred):
+    y_true = K.cast(y_true, y_pred.dtype)
+    loss1 = K.mean(K.square(y_true - y_pred) * K.maximum(y_pred, y_true), axis=(-1))
+
+    # y_pred_rs = K.reshape(y_pred, (-1,48))
+    # y_true_rs = K.reshape(y_true, (-1,48))
+    # y_pred_x =
+
+    y_pred_pool = tf.nn.pool(y_pred,(2,2),'AVG',strides=[1,1])
+    y_true_pool = tf.nn.pool(y_true,(2,2),'AVG',strides=[1,1])
+    loss2 = K.mean(K.square(y_true_pool - y_pred_pool) * K.maximum(y_true_pool, y_pred_pool), axis=(-1))
+    #return loss1 + loss2
+    return loss1
+
+def weightedMSE(y_true, y_pred):
+    y_true = K.cast(y_true, y_pred.dtype)
+    loss = K.mean(K.square(y_true - y_pred) * K.maximum(y_pred, y_true), axis=(-1))
+    return loss
+
 class qDenseCNN:
 
     def __init__(self, name='', weights_f=''):
@@ -199,6 +264,11 @@ def init(self, printSummary=True): # keep_negitive = 0 on inputs, otherwise for
             self.autoencoder.compile(loss=self.weightedMSE, optimizer='adam')
             self.encoder.compile(loss=self.weightedMSE, optimizer='adam')
+        elif self.pams['loss'] == 'sink':
+            self.autoencoder.compile(loss=other_loss, optimizer='adam')
+            self.encoder.compile(loss=other_loss, optimizer='adam')
+            # self.autoencoder.compile(loss=sinkhorn_loss, optimizer='adam')
+            # self.encoder.compile(loss=sinkhorn_loss, optimizer='adam')
         elif self.pams['loss'] != '':
            self.autoencoder.compile(loss=self.pams['loss'], optimizer='adam')
            self.encoder.compile(loss=self.pams['loss'], optimizer='adam')
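
As a smoke test, the new losses can be exercised on a toy model before wiring them into qDenseCNN (a sketch: the one-layer Keras model and random data are hypothetical stand-ins, not the real architecture; it assumes qDenseCNN's own imports resolve):

    import numpy as np
    import tensorflow as tf
    from qDenseCNN import other_loss  # swap in sinkhorn_loss to exercise the OT path

    inp = tf.keras.Input(shape=(4, 4, 3))  # 48 cells, as in the 4x4x3 arrangement
    out = tf.keras.layers.Conv2D(3, 3, padding='same', activation='relu')(inp)
    model = tf.keras.Model(inp, out)
    model.compile(loss=other_loss, optimizer='adam')

    x = np.random.rand(32, 4, 4, 3).astype(np.float32)
    model.fit(x, x, epochs=1, verbose=0)   # target == input, autoencoder-style
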
diff --git a/scan_precision.py b/scan_precision.py
index a97f6e0..f083306 100644
--- a/scan_precision.py
+++ b/scan_precision.py
@@ -8,43 +8,38 @@
 import json
 
 from train import trainCNN
-
-def plotHist(x,y,ye, name, odir,xtitle, ytitle):
-    plt.figure()
-    plt.errorbar(x,y,ye)
-    plt.title('')
-    plt.ylabel(ytitle)
-    plt.xlabel(xtitle)
-    plt.legend(['others 16,6'], loc='upper right')
-    plt.savefig(odir+"/"+name+".png")
-    return
+from utils import plotGraphErr
 
 def plotScan(x,outs,name,odir,xtitle="n bits"):
     outs = pd.concat(outs)
     for metric in ['ssd','corr','emd']:
-        plotHist(x, outs[metric], outs[metric+'_err'], name+"_"+metric,
+        plotGraphErr(x, outs[metric], outs[metric+'_err'], name+"_"+metric,
                  odir,xtitle=xtitle,ytitle=metric)
+    outs.to_csv(odir+"/"+name+".csv")
     return
 
 def BitScan(options, args):
-    # test inputs
-    bits = [i+3 for i in range(6)]
-    bits = [i+3 for i in range(2)]
-    updates = [{'nBits_input':{'total': b, 'integer': 2}} for b in bits]
-    outputs = [trainCNN(options,args,u) for u in updates]
-    plotScan(bits,outputs,"test_input_bits",options.odir,xtitle="total input bits")
-
-    exit(0)
-
-    # test weights
-    bits = [i+1 for i in range(8)]
-    updates = [{'nBits_weight':{'total': 2*b+1, 'integer': b}} for b in bits]
-    outputs = [trainCNN(options,args,u) for u in updates]
-    plotScan(bits,outputs,"test_weight_bits",xtitle="total input bits")
-
-    emd, emde = zip(*[trainCNN(options,args,u) for u in updates])
-    plotScan(bits,emd,emde,"test_weight_bits")
+    if False:
+        # test inputs
+        bits = [i+3 for i in range(6)]
+        updates = [{'nBits_input':{'total': b, 'integer': 2}} for b in bits]
+        outputs = [trainCNN(options,args,u) for u in updates]
+        plotScan(bits,outputs,"test_input_bits",options.odir,xtitle="total input bits")
+
+    if False:
+        # test weights
+        bits = [i+1 for i in range(8)]
+        updates = [{'nBits_weight':{'total': 2*b+1, 'integer': b}} for b in bits]
+        outputs = [trainCNN(options,args,u) for u in updates]
+        plotScan(bits,outputs,"test_weight_bits",options.odir,xtitle="total weight bits")
+
+    if True:
+        # test encoded bits: keep the bit budget fixed at 64 = encoded_dim * nBits
+        bits = [4,6,8,10,12,16]
+        updates = [{'nBits_encod':{'total': b, 'integer': b//2},'encoded_dim':int(64/b)} for b in bits]
+        outputs = [trainCNN(options,args,u) for u in updates]
+        plotScan(bits,outputs,"test_encod_bits",options.odir,xtitle="bits per encoded node")
 
     exit(0)
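
Each scan point relies on trainCNN merging the update dict into the model parameters via a plain dict.update (see m['pams'].update(pam_updates) in train.py). Top-level keys are replaced wholesale, so every update must carry a complete nBits dict; a sketch with made-up numbers:

    pams   = {'nBits_encod': {'total': 32, 'integer': 4}, 'encoded_dim': 16}
    update = {'nBits_encod': {'total': 8,  'integer': 4}, 'encoded_dim': 8}
    pams.update(update)  # no deep merge: the whole nested dict is swapped
    print(pams)          # {'nBits_encod': {'total': 8, 'integer': 4}, 'encoded_dim': 8}
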
diff --git a/tests/test_sinkhorn.py b/tests/test_sinkhorn.py
new file mode 100644
index 0000000..6e369bc
--- /dev/null
+++ b/tests/test_sinkhorn.py
@@ -0,0 +1,49 @@
+import tensorflow as tf
+import numpy as np
+import sys
+sys.path.append("/home/therwig/data/sandbox/hgcal/Ecoder/")
+import ot_tf
+import ot
+#tf.compat.v1.disable_eager_execution()
+
+def main():
+
+    na = 48
+    nb = 48
+    reg = 0.5
+    a = tf.expand_dims(tf.ones(shape=(na,)) / na, axis=1)  # (na, 1)
+    b = tf.expand_dims(tf.ones(shape=(nb,)) / nb, axis=1)  # (nb, 1)
+    m = tf.constant( hexMetric(), tf.float32 )
+    tf_sinkhorn_loss = ot_tf.sink(a, b, m, (na, nb), reg)
+
+    # print('finish loads')
+    # print(a)
+    # print(b)
+    # print(m)
+    print(tf_sinkhorn_loss)
+
+    x = tf.ones(shape=(100,48,1,))
+    y = tf.split(x, num_or_size_splits=100, axis=0)
+    print(y[0])
+
+    return
+
+hexCoords = np.array([
+    [0.0, 0.0], [0.0, -2.4168015], [0.0, -4.833603], [0.0, -7.2504044],
+    [2.09301, -1.2083969], [2.09301, -3.6251984], [2.09301, -6.042], [2.09301, -8.458794],
+    [4.18602, -2.4168015], [4.18602, -4.833603], [4.18602, -7.2504044], [4.18602, -9.667198],
+    [6.27903, -3.6251984], [6.27903, -6.042], [6.27903, -8.458794], [6.27903, -10.875603],
+    [-8.37204, -10.271393], [-6.27903, -9.063004], [-4.18602, -7.854599], [-2.0930138, -6.6461945],
+    [-8.37204, -7.854599], [-6.27903, -6.6461945], [-4.18602, -5.4377975], [-2.0930138, -4.229393],
+    [-8.37204, -5.4377975], [-6.27903, -4.229393], [-4.18602, -3.020996], [-2.0930138, -1.8125992],
+    [-8.37204, -3.020996], [-6.27903, -1.8125992], [-4.18602, -0.6042023], [-2.0930138, 0.6042023],
+    [4.7092705, -12.386101], [2.6162605, -11.177696], [0.5232506, -9.969299], [-1.5697594, -8.760895],
+    [2.6162605, -13.594498], [0.5232506, -12.386101], [-1.5697594, -11.177696], [-3.6627693, -9.969299],
+    [0.5232506, -14.802895], [-1.5697594, -13.594498], [-3.6627693, -12.386101], [-5.7557793, -11.177696],
+    [-1.5697594, -16.0113], [-3.6627693, -14.802895], [-5.7557793, -13.594498], [-7.848793, -12.386101]])
+def hexMetric():
+    return ot.dist(hexCoords, hexCoords, 'euclidean')
+
+if __name__ == "__main__":
+    main()
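
An independent cross-check of the printed value against POT's reference solver, which solves the same entropic-OT problem in NumPy (a sketch, run from the tests/ directory):

    import numpy as np
    import ot
    from test_sinkhorn import hexCoords  # main() only runs under __main__, so importing is safe

    a = np.ones(48) / 48
    b = np.ones(48) / 48
    M = ot.dist(hexCoords, hexCoords, 'euclidean')
    print(ot.sinkhorn2(a, b, M, 0.5))  # should approximately match the TF result
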
diff --git a/train.py b/train.py
index 7465810..9a58db1 100644
--- a/train.py
+++ b/train.py
@@ -10,6 +10,8 @@
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 
+##from utils import plotHist
+
 import numba
 import json
 
@@ -32,22 +34,42 @@ def normalize(data,rescaleInputToMax=False):
         data[i] = 1.*data[i]/data[i].sum()
     return data,np.array(norm)
 
+def plotHist(vals,name,odir='.',xtitle="",ytitle="",nbins=40,
+             stats=True, logy=False, leg=None):
+    plt.figure(figsize=(6,4))
+    if leg:
+        plt.hist(vals,nbins,label=leg)
+    else:
+        plt.hist(vals,nbins)
+    ax = plt.axes()
+    plt.text(0.1, 0.9, name,transform=ax.transAxes)
+    if stats:
+        mu = np.mean(vals)
+        std = np.std(vals)
+        plt.text(0.1, 0.8, r'$\mu=%.3f,\ \sigma=%.3f$'%(mu,std),transform=ax.transAxes)
+    plt.xlabel(xtitle)
+    plt.ylabel(ytitle if ytitle else 'Entries')
+    if logy: plt.yscale('log')
+    plt.savefig(odir+"/"+name+".png")
+    plt.close()
+    return
+
 def split(shaped_data, validation_frac=0.2):
     N = round(len(shaped_data)*validation_frac)
     #randomly select 25% entries
-    index = np.random.choice(shaped_data.shape[0], N, replace=False)
+    val_index = np.random.choice(shaped_data.shape[0], N, replace=False)
     #select the indices of the other 75%
     full_index = np.array(range(0,len(shaped_data)))
-    train_index = np.logical_not(np.in1d(full_index,index))
+    train_index = np.logical_not(np.in1d(full_index,val_index))
 
-    val_input = shaped_data[index]
+    val_input = shaped_data[val_index]
     train_input = shaped_data[train_index]
     print('training shape',train_input.shape)
     print('validation shape',val_input.shape)
 
-    return val_input,train_input
+    return val_input,train_input,val_index
 
 def train(autoencoder,encoder,train_input,val_input,name,n_epochs=100):
@@ -67,7 +89,7 @@
     plt.xlabel('Epoch')
     plt.legend(['Train', 'Test'], loc='upper right')
     plt.savefig("history_%s.png"%name)
-    #plt.show()
+    plt.close()
 
     save_models(autoencoder,name)
 
@@ -89,28 +111,17 @@
     encoder.save_weights('%s.hdf5'%("encoder_"+name))
     decoder.save_weights('%s.hdf5'%("decoder_"+name))
     return
-
-
-def predict(x,autoencoder,encoder,reshape=True):
-    decoded_Q = autoencoder.predict(x)
-    encoded_Q = encoder.predict(x)
-
-    #need reshape for CNN layers
-    if reshape :
-        decoded_Q = np.reshape(decoded_Q,(len(decoded_Q),12,4))
-        encoded_shape = encoded_Q.shape
-        encoded_Q = np.reshape(encoded_Q,(len(encoded_Q),encoded_shape[3],encoded_shape[1]))
-
-    return decoded_Q, encoded_Q
 
 ### cross correlation of input/output
 def cross_corr(x,y):
     cov = np.cov(x.flatten(),y.flatten())
-    std = np.sqrt(np.diag(cov)+1e-10)
-    corr = cov / np.multiply.outer(std, std)
+    std = np.sqrt(np.diag(cov))
+    stdsqr = np.multiply.outer(std, std)
+    corr = np.divide(cov, stdsqr, out=np.zeros_like(cov), where=(stdsqr!=0))
     return corr[0,1]
 
 def ssd(x,y):
+    if (np.sum(x)==0 or np.sum(y)==0): return 1.
     ssd=np.sum(((x-y)**2).flatten())
     ssd = ssd/(np.sum(x**2)*np.sum(y**2))**0.5
     return ssd
@@ -131,11 +142,13 @@
     [0.5232506, -14.802895], [-1.5697594, -13.594498], [-3.6627693, -12.386101], [-5.7557793, -11.177696],
     [-1.5697594, -16.0113], [-3.6627693, -14.802895], [-5.7557793, -13.594498], [-7.848793, -12.386101]])
 hexMetric = ot.dist(hexCoords, hexCoords, 'euclidean')
 
-def emd(_x, _y, threshold=-1):
+MAXDIST = 16.08806614  # largest inter-cell distance in hexMetric
+def emd(_x, _y, threshold=-1):
+    if (np.sum(_x)==0 or np.sum(_y)==0): return MAXDIST
     x = np.array(_x, dtype=np.float64)
     y = np.array(_y, dtype=np.float64)
-    x = 1./x.sum()*x.flatten()
-    y = 1./y.sum()*y.flatten()
+    x = (1./x.sum() if x.sum() else 1.)*x.flatten()
+    y = (1./y.sum() if y.sum() else 1.)*y.flatten()
     if threshold > 0:
         # only keep entries above 2%, e.g.
@@ -146,19 +159,61 @@
     return ot.emd2(x, y, hexMetric)
 
-def visualize(input_Q,decoded_Q,encoded_Q,index,name='model_X'):
-    if index.size==0:
-        Nevents=8
-        #randomly pick Nevents if index is not specified
-        index = np.random.choice(input_Q.shape[0], Nevents, replace=False)
-    else:
-        Nevents = len(index)
+def d_weighted_mean(x, y):
+    if (np.sum(x)==0 or np.sum(y)==0): return MAXDIST/2.
+    x = (1./x.sum() if x.sum() else 1.)*x.flatten()
+    y = (1./y.sum() if y.sum() else 1.)*y.flatten()
+    dx = hexCoords[:,0].dot(x-y)
+    dy = hexCoords[:,1].dot(x-y)
+    return np.sqrt(dx*dx+dy*dy)
+
+def make_supercells(inQ, shareQ=False):
+    outQ = inQ.copy()
+    inshape = inQ[0].shape
+    mask = np.array([
+        [ 0,  1,  4,  5], # indices for 1 supercell
+        [ 2,  3,  6,  7],
+        [ 8,  9, 12, 13],
+        [10, 11, 14, 15],
+        [16, 17, 20, 21],
+        [18, 19, 22, 23],
+        [24, 25, 28, 29],
+        [26, 27, 30, 31],
+        [32, 33, 36, 37],
+        [34, 35, 38, 39],
+        [40, 41, 44, 45],
+        [42, 43, 46, 47]])
+    for i in range(len(inQ)):
+        inFlat = inQ[i].flatten()
+        outFlat = outQ[i].flatten()
+        for sc in mask:
+            mysum = np.sum( inFlat[sc] )
+            if shareQ:
+                # share the supercell charge equally among its four cells
+                outFlat[sc] = mysum/4.
+            else:
+                # set the max cell to the supercell sum, zero the others
+                ii = np.argmax( inFlat[sc] )
+                outFlat[sc] = 0
+                outFlat[sc[ii]] = mysum
+        outQ[i] = outFlat.reshape(inshape)
+    return outQ
+
+def threshold(_x, norm, cut):
+    x = _x.copy()
+    # reshape to allow broadcasting to all cells
+    norm_shape = norm.reshape((norm.shape[0],)+(1,)*(x.ndim-1))
+    x = np.where(x*norm_shape>=cut, x, 0)
+    return x
+
+def visDisplays(index,input_Q,decoded_Q,encoded_Q=np.array([]),name='model_X'):
+    Nevents = len(index)
 
     inputImg = input_Q[index]
-    encodedImg = encoded_Q[index]
     outputImg = decoded_Q[index]
-
-    fig, axs = plt.subplots(3, Nevents, figsize=(16, 10))
+
+    nrows = 3 if len(encoded_Q) else 2
+    fig, axs = plt.subplots(nrows, Nevents, figsize=(16, 10))
 
     for i in range(0,Nevents):
         if i==0:
@@ -173,59 +228,36 @@
         else:
             axs[1,i].set(xlabel='cell_x',title='CNN Ouput_%i'%i)
         c1=axs[1,i].imshow(outputImg[i])
-
-    for i in range(0,Nevents):
-        if i==0:
-            axs[2,i].set(xlabel='latent dim',ylabel='depth',title='Encoded_%i'%i)
-        else:
-            axs[2,i].set(xlabel='latent dim',title='Encoded_%i'%i)
-        c1=axs[2,i].imshow(encodedImg[i])
-        plt.colorbar(c1,ax=axs[2,i])
+
+    if len(encoded_Q):
+        encodedImg = encoded_Q[index]
+        for i in range(0,Nevents):
+            if i==0:
+                axs[2,i].set(xlabel='latent dim',ylabel='depth',title='Encoded_%i'%i)
+            else:
+                axs[2,i].set(xlabel='latent dim',title='Encoded_%i'%i)
+            c1=axs[2,i].imshow(encodedImg[i])
+            plt.colorbar(c1,ax=axs[2,i])
 
     plt.tight_layout()
     plt.savefig("%s_examples.png"%name)
-
-def visMetric(input_Q,decoded_Q,maxQ,name, skipPlot=False):
-    def plothist(y,xlabel,name):
-        plt.figure(figsize=(6,4))
-        plt.hist(y,50)
-        mu = np.mean(y)
-        std = np.std(y)
-        ax = plt.axes()
-        plt.text(0.1, 0.9, name,transform=ax.transAxes)
-        plt.text(0.1, 0.8, r'$\mu=%.3f,\ \sigma=%.3f$'%(mu,std),transform=ax.transAxes)
-        plt.xlabel(xlabel)
-        plt.ylabel('Entry')
-        plt.title('%s on validation set'%xlabel)
-        plt.savefig("hist_%s.png"%name)
-
-    cross_corr_arr = np.array([cross_corr(input_Q[i],decoded_Q[i]) for i in range(0,len(decoded_Q))] )
-    ssd_arr = np.array([ssd(decoded_Q[i],input_Q[i]) for i in range(0,len(decoded_Q))])
-    emd_arr = np.array([emd(decoded_Q[i],input_Q[i]) for i in range(0,len(decoded_Q))])
-
-    if skipPlot: return cross_corr_arr,ssd_arr,emd_arr
-
-    plothist(cross_corr_arr,'cross correlation',name+"_corr")
-    plothist(ssd_arr,'sum squared difference',name+"_ssd")
-    plothist(emd_arr,'earth movers distance',name+"_emd")
-
+    plt.close()
+
+def visMetric(input_Q,decoded_Q,maxQ,metric,name,odir,skipPlot=False):
+    # per-event values of the chosen metric
+    vals = np.array([metric(input_Q[i],decoded_Q[i]) for i in range(0,len(input_Q))])
+    if skipPlot: return vals
+
+    plotHist(vals,"hist_"+name,odir,xtitle=name)
+
     plt.figure(figsize=(6,4))
     plt.hist([input_Q.flatten(),decoded_Q.flatten()],20,label=['input','output'])
     plt.yscale('log')
     plt.legend(loc='upper right')
     plt.xlabel('Charge fraction')
     plt.savefig("hist_Qfr_%s.png"%name)
-
+    plt.close()
+
     input_Q_abs = np.array([input_Q[i] * maxQ[i] for i in range(0,len(input_Q))])
     decoded_Q_abs = np.array([decoded_Q[i]*maxQ[i] for i in range(0,len(decoded_Q))])
-
-    plt.figure(figsize=(6,4))
-    plt.hist([input_Q_abs.flatten(),decoded_Q_abs.flatten()],20,label=['input','output'])
-    plt.yscale('log')
-    plt.legend(loc='upper right')
-    plt.xlabel('Charge')
-    plt.savefig("hist_Qabs_%s.png"%name)
-
+
     nonzeroQs = np.count_nonzero(input_Q_abs.reshape(len(input_Q_abs),48),axis=1)
     occbins = [0,5,10,20,48]
     fig, axes = plt.subplots(1,len(occbins)-1, figsize=(16, 4))
@@ -241,14 +273,26 @@
     plt.tight_layout()
     #plt.show()
     plt.savefig('corr_vs_occ_%s.png'%name)
-
-    return cross_corr_arr,ssd_arr,emd_arr
+    plt.close()
+
+    return vals
 
-def GetBitsString(In, Accum, Weight):
+def GetBitsString(In, Accum, Weight, Encoded, Dense=False, Conv=False):
     s=""
     s += "Input{}b{}i".format(In['total'],In['integer'])
     s += "_Accum{}b{}i".format(Accum['total'],Accum['integer'])
-    s += "_Weight{}b{}i".format(Weight['total'],Weight['integer'])
+    if Dense:
+        s += "_Dense{}b{}i".format(Dense['total'], Dense['integer'])
+        if Conv:
+            s += "_Conv{}b{}i".format(Conv['total'], Conv['integer'])
+        else:
+            s += "_Conv{}b{}i".format(Weight['total'], Weight['integer'])
+    elif Conv:
+        s += "_Dense{}b{}i".format(Weight['total'], Weight['integer'])
+        s += "_Conv{}b{}i".format(Conv['total'], Conv['integer'])
+    else:
+        s += "_Weight{}b{}i".format(Weight['total'],Weight['integer'])
+    s += "_Encod{}b{}i".format(Encoded['total'], Encoded['integer'])
     return s
 
 def trainCNN(options, args, pam_updates=None):
", tf.test.is_gpu_available()) # default precisions for quantized training - nBits_input = {'total': 16, 'integer': 6} - nBits_accum = {'total': 16, 'integer': 6} - nBits_weight = {'total': 16, 'integer': 6} - nBits_encod = {'total': 16, 'integer': 6} + nBits_input = {'total': 32, 'integer': 4} + nBits_accum = {'total': 32, 'integer': 4} + nBits_weight = {'total': 32, 'integer': 4} + nBits_encod = {'total': 32, 'integer': 4} # model-dependent -- use common weights unless overridden conv_qbits = nBits_weight dense_qbits = nBits_weight @@ -269,8 +313,18 @@ def trainCNN(options, args, pam_updates=None): # from tensorflow.keras import backend # backend.set_image_data_format('channels_first') - - data = pd.read_csv(options.inputFile,dtype=np.float64) ## big 300k file + if os.path.isdir(options.inputFile): + df_arr = [] + for infile in os.listdir(options.inputFile): + infile = os.path.join(options.inputFile,infile) + df_arr.append(pd.read_csv(infile, dtype=np.float64, header=0, usecols=[*range(1, 49)])) + data = pd.concat(df_arr) + data = data.loc[(data.sum(axis=1) != 0)] #drop rows where occupancy = 0 + print(data.shape) + data.describe() + else: + data = pd.read_csv(options.inputFile,dtype=np.float64) + #data = pd.read_csv(options.inputFile,dtype=np.float64) ## big 300k file normdata,maxdata = normalize(data.values.copy(),rescaleInputToMax=options.rescaleInputToMax) arrange8x8 = np.array([ @@ -311,12 +365,22 @@ def trainCNN(options, args, pam_updates=None): 15,31, 47]) models = [ - {'name': '4x4_norm_d10', 'ws': '', - 'pams': {'shape': (4, 4, 3), + #{'name': '4x4_norm_d10', 'ws': '', + # 'pams': {'shape': (4, 4, 3), + # 'channels_first': False, + # 'arrange': arrange443, + # 'encoded_dim': 10, + # 'loss': 'weightedMSE'}}, + {'name': '4x4_norm_v7', 'ws': '', + 'pams': {'shape': (4, 4, 3), 'channels_first': False, - 'arrange': arrange443, - 'encoded_dim': 10, - 'loss': 'weightedMSE'}}, + 'arrange': arrange443, + #'loss': 'weightedMSE', + 'loss': 'sink', + 'CNN_layer_nodes': [4, 4, 4], + 'CNN_kernel_size': [5, 5, 3], + 'CNN_pool': [False, False, False], }}, + ] #{'name':'denseCNN', 'ws':'denseCNN.hdf5', 'pams':{'shape':(1,8,8) } }, @@ -428,9 +492,26 @@ def trainCNN(options, args, pam_updates=None): m['pams'].update(pam_updates) print ('updated parameters for model',m['name']) - summary = pd.DataFrame(columns=['name','en_pams','tot_pams', - 'corr','ssd','emd', - 'corr_err','ssd_err','emd_err',]) + # compression algorithms, autoencoder and more traditional benchmarks + algnames = ['ae','stc1','stc2','thr_lo','thr_hi'] + # metrics to compute on the validation dataset + metrics = {'cross_corr' :cross_corr, + 'SSD' :ssd, + 'EMD' :emd, + 'dMean':d_weighted_mean, + 'zero_frac':(lambda x,y: np.all(y==0)),} + longMetric = {'cross_corr' :'cross correlation', + 'SSD' :'sum of squared differences', + 'EMD' :'earth movers distance', + 'dMean':'difference in energy-weighted mean', + 'zero_frac':'zero fraction',} + summary_entries=['name','en_pams','tot_pams'] + for algname in algnames: + for mname in metrics: + name = mname+"_"+algname + summary_entries.append(mname+"_"+algname) + summary_entries.append(mname+"_"+algname+"_err") + summary = pd.DataFrame(columns=summary_entries) orig_dir = os.getcwd() if not os.path.exists(options.odir): os.mkdir(options.odir) @@ -438,9 +519,10 @@ def trainCNN(options, args, pam_updates=None): for model in models: model_name = model['name'] if options.quantize: - bit_str = GetBitsString(m['pams']['nBits_input'], - m['pams']['nBits_accum'], - m['pams']['nBits_weight']) + bit_str 
@@ -428,9 +492,26 @@
             m['pams'].update(pam_updates)
             print ('updated parameters for model',m['name'])
 
-    summary = pd.DataFrame(columns=['name','en_pams','tot_pams',
-                                    'corr','ssd','emd',
-                                    'corr_err','ssd_err','emd_err',])
+    # compression algorithms, autoencoder and more traditional benchmarks
+    algnames = ['ae','stc1','stc2','thr_lo','thr_hi']
+    # metrics to compute on the validation dataset
+    metrics = {'cross_corr': cross_corr,
+               'SSD': ssd,
+               'EMD': emd,
+               'dMean': d_weighted_mean,
+               'zero_frac': (lambda x,y: np.all(y==0)),}
+    longMetric = {'cross_corr': 'cross correlation',
+                  'SSD': 'sum of squared differences',
+                  'EMD': 'earth movers distance',
+                  'dMean': 'difference in energy-weighted mean',
+                  'zero_frac': 'zero fraction',}
+    summary_entries = ['name','en_pams','tot_pams']
+    for algname in algnames:
+        for mname in metrics:
+            name = mname+"_"+algname
+            summary_entries.append(name)
+            summary_entries.append(name+"_err")
+    summary = pd.DataFrame(columns=summary_entries)
 
     orig_dir = os.getcwd()
     if not os.path.exists(options.odir): os.mkdir(options.odir)
@@ -438,9 +519,10 @@
     for model in models:
         model_name = model['name']
         if options.quantize:
-            bit_str = GetBitsString(m['pams']['nBits_input'],
-                                    m['pams']['nBits_accum'],
-                                    m['pams']['nBits_weight'])
+            bit_str = GetBitsString(model['pams']['nBits_input'], model['pams']['nBits_accum'],
+                                    model['pams']['nBits_weight'], model['pams']['nBits_encod'],
+                                    (model['pams']['nBits_dense'] if 'nBits_dense' in model['pams'] else False),
+                                    (model['pams']['nBits_conv'] if 'nBits_conv' in model['pams'] else False))
             model_name += "_" + bit_str
         if not os.path.exists(model_name): os.mkdir(model_name)
         os.chdir(model_name)
@@ -451,17 +533,22 @@
             m = denseCNN(weights_f=model['ws'])
             m.setpams(model['pams'])
             m.init()
-        shaped_data = m.prepInput(normdata)
-        val_input, train_input = split(shaped_data)
-        m_autoCNN , m_autoCNNen = m.get_models()
+        shaped_data = m.prepInput(normdata)
+        val_input, train_input, val_ind = split(shaped_data)
+        m_autoCNN , m_autoCNNen = m.get_models()
+        val_max = maxdata[val_ind]
+
         if model['ws']=='':
+            if options.quickTrain: train_input = train_input[:5000]
             history = train(m_autoCNN,m_autoCNNen,train_input,val_input,name=model_name,n_epochs = options.epochs)
         else:
             save_models(m_autoCNN,model_name)
-
-        Nevents = 8
-        N_verify = 50
-
+
+        summary_dict = {
+            'name': model_name,
+            'en_pams': m_autoCNNen.count_params(),
+            'tot_pams': m_autoCNN.count_params(),}
+
         input_Q,cnn_deQ ,cnn_enQ = m.predict(val_input)
 
         ## csv files for RTL verification
@@ -470,40 +557,58 @@
         np.savetxt("verify_output.csv",cnn_enQ[0:N_csv].reshape(N_csv,m.pams['encoded_dim']), delimiter=",",fmt='%.12f')
         np.savetxt("verify_decoded.csv",cnn_deQ[0:N_csv].reshape(N_csv,48), delimiter=",",fmt='%.12f')
+        stc1_Q = make_supercells(input_Q)
+        stc2_Q = make_supercells(input_Q,shareQ=True)
+        thr_lo_Q = threshold(input_Q,val_max,47) # 1.35 transverse MIPs
+        thr_hi_Q = threshold(input_Q,val_max,69) # 2.0 transverse MIPs
+        occupancy = np.count_nonzero(input_Q.reshape(len(input_Q),48),axis=1)
+        alg_outs = {'ae'    : cnn_deQ,
+                    'stc1'  : stc1_Q,
+                    'stc2'  : stc2_Q,
+                    'thr_lo': thr_lo_Q,
+                    'thr_hi': thr_hi_Q,
+                    }
+
+        # to generate event displays
+        Nevents = 8
         index = np.random.choice(input_Q.shape[0], Nevents, replace=False)
-        corr_arr, ssd_arr, emd_arr = visMetric(input_Q,cnn_deQ,maxdata,name=model_name, skipPlot=options.skipPlot)
-
-        if not options.skipPlot:
-            hi_corr_index = (np.where(corr_arr>0.9))[0]
-            low_corr_index = (np.where(corr_arr<0.2))[0]
-            visualize(input_Q,cnn_deQ,cnn_enQ,index,name=model_name)
-            if len(hi_corr_index)>0:
-                index = np.random.choice(hi_corr_index, min(Nevents,len(hi_corr_index)), replace=False)
-                visualize(input_Q,cnn_deQ,cnn_enQ,index,name=model_name+"_corr0.9")
+
+        # compute metrics for each alg
+        for algname, alg_out in alg_outs.items():
+            # charge fraction comparison
+            if(not options.skipPlot): plotHist([input_Q.flatten(),alg_out.flatten()],
+                                               algname+"_fracQ",xtitle="charge fraction",ytitle="Cells",
+                                               stats=False,logy=True,leg=['input','output'])
+            input_Q_abs = np.array([input_Q[i]*val_max[i] for i in range(0,len(input_Q))])
+            alg_out_abs = np.array([alg_out[i]*val_max[i] for i in range(0,len(alg_out))])
+            if(not options.skipPlot): plotHist([input_Q_abs.flatten(),alg_out_abs.flatten()],
+                                               algname+"_absQ",xtitle="absolute charge",ytitle="Cells",
+                                               stats=False,logy=True,leg=['input','output'])
+            # event displays
+            if(not options.skipPlot): visDisplays(index, input_Q, alg_out, (cnn_enQ if algname=='ae' else np.array([])), name=algname)
+
+            for mname, metric in metrics.items():
+                name = mname+"_"+algname
+                vals = np.array([metric(input_Q[i],alg_out[i]) for i in range(0,len(input_Q))])
+                vals = np.sort(vals)
+                model[name] = np.round(np.mean(vals), 3)
+                model[name+'_err'] = np.round(np.std(vals), 3)
+                summary_dict[name] = model[name]
+                summary_dict[name+'_err'] = model[name+'_err']
+                if (not options.skipPlot) and (not ('zero_frac' in mname)):
+                    plotHist(vals,"hist_"+name,xtitle=longMetric[mname])
+                    hi_index = (np.where(vals>np.quantile(vals,0.9)))[0]
+                    lo_index = (np.where(vals<np.quantile(vals,0.1)))[0]
+                    # distinct names so the tail displays don't overwrite the nominal ones
+                    if len(hi_index)>0:
+                        hi_index = np.random.choice(hi_index, min(Nevents,len(hi_index)), replace=False)
+                        visDisplays(hi_index, input_Q, alg_out, (cnn_enQ if algname=='ae' else np.array([])), name=name+"_hi")
+                    if len(lo_index)>0:
+                        lo_index = np.random.choice(lo_index, min(Nevents,len(lo_index)), replace=False)
+                        visDisplays(lo_index, input_Q, alg_out, (cnn_enQ if algname=='ae' else np.array([])), name=name+"_lo")
 
-            if len(low_corr_index)>0:
-                index = np.random.choice(low_corr_index,min(Nevents,len(low_corr_index)), replace=False)
-                visualize(input_Q,cnn_deQ,cnn_enQ,index,name=model_name+"_corr0.2")
-
-        model['corr'] = np.round(np.mean(corr_arr),3)
-        model['ssd'] = np.round(np.mean(ssd_arr),3)
-        model['emd'] = np.round(np.mean(emd_arr),3)
-        model['corr_err'] = np.round(np.std(corr_arr),3)
-        model['ssd_err'] = np.round(np.std(ssd_arr),3)
-        model['emd_err'] = np.round(np.std(emd_arr),3)
-
-        summary = summary.append(
-            {'name':model_name,
-             'corr':model['corr'],
-             'ssd':model['ssd'],
-             'emd':model['emd'],
-             'corr_err':model['corr_err'],
-             'ssd_err':model['ssd_err'],
-             'emd_err':model['emd_err'],
-             'en_pams' : m_autoCNNen.count_params(),
-             'tot_pams': m_autoCNN.count_params(),},
-            ignore_index=True)
-
+        print('summary_dict',summary_dict)
+        summary = summary.append(summary_dict, ignore_index=True)
+
         with open(model_name+"_pams.json",'w') as f:
             f.write(json.dumps(m.get_pams(),indent=4))
 
@@ -521,6 +626,7 @@
     parser.add_option("--dryRun", action='store_true', default = False,dest="dryRun", help="dryRun")
     parser.add_option("--epochs", type='int', default = 100, dest="epochs", help="n epoch to train")
     parser.add_option("--skipPlot", action='store_true', default = False,dest="skipPlot", help="skip the plotting step")
+    parser.add_option("--quickTrain", action='store_true', default = False,dest="quickTrain", help="train with only 5k events for testing purposes")
     parser.add_option("--nCSV", type='int', default = 50, dest="nCSV", help="n of validation events to write to csv")
     parser.add_option("--rescaleInputToMax", action='store_true', default = False,dest="rescaleInputToMax", help="recale the input images so the maximum deposit is 1. Else normalize")
     (options, args) = parser.parse_args()
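
A quick look at what the non-autoencoder benchmarks do to an event (a sketch using the train.py helpers above; the single toy event, with all charge in cell 0 and a made-up normalization, is hypothetical):

    import numpy as np
    evt = np.zeros((1, 48))
    evt[0, 0] = 1.0
    print(make_supercells(evt)[0][:8])               # [1. 0. 0. 0. 0. 0. 0. 0.]   max cell keeps the sum
    print(make_supercells(evt, shareQ=True)[0][:8])  # [0.25 0.25 0. 0. 0.25 0.25 0. 0.]   spread over the supercell
    print(threshold(evt, np.array([100.]), 47)[0][:4])  # [1. 0. 0. 0.]   1.0*100 passes the cut of 47
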
Else normalize") (options, args) = parser.parse_args() diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..7fa1517 --- /dev/null +++ b/utils.py @@ -0,0 +1,61 @@ +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +def plotGraph(x, y, name, odir, xtitle, ytitle, leg=None): + plt.figure() + plt.plot(x,y) + plt.title('') + plt.ylabel(ytitle) + plt.xlabel(xtitle) + if leg: plt.legend(leg, loc='upper right') + plt.savefig(odir+"/"+name+".png") + plt.close() + return + +def plotGraphErr(x, y, ye, name, odir, xtitle, ytitle, leg=None): + plt.figure() + plt.errorbar(x,y,ye) + plt.title('') + plt.ylabel(ytitle) + plt.xlabel(xtitle) + if leg: plt.legend(leg, loc='upper right') + plt.savefig(odir+"/"+name+".png") + plt.close() + return + +def plotHist(vals,name,odir,xtitle="",ytitle="",nbins=40): + plt.figure() + plt.hist(vals,nbins) + mu = np.mean(vals) + std = np.std(vals) + ax = plt.axes() + plt.text(0.1, 0.9, name,transform=ax.transAxes) + plt.text(0.1, 0.8, r'$\mu=%.3f,\ \sigma=%.3f$'%(mu,std),transform=ax.transAxes) + plt.xlabel(xtitle) + plt.ylabel(ytitle if ytitle else 'Entries') + plt.savefig(odir+"/"+name+".png") + plt.close() + return + +def decode_ECON(mantissa, exp, n_mantissa=3,n_exp=4): + if exp==0: return mantissa + mantissa += (1<>exp) - (1<<(n_mantissa-1)) + return (mantissa,exp) + +def test_econ(): + for m in range(1<<3): + for e in range(1<<4): + val = decode_ECON(m,e) + m1, e1 = encode_ECON(val) + print(m,e,'-->',val,'-->',m1,e1)