51 changes: 51 additions & 0 deletions ot_tf.py
@@ -0,0 +1,51 @@
import tensorflow as tf
#tf.compat.v1.disable_eager_execution()

def sink(a, b, M, m_size, reg, numItermax=1000, stopThr=1e-9):
# Sinkhorn-Knopp iterations for the entropy-regularized OT cost between the
# marginals a (na, 1) and b (nb, 1), with ground-cost matrix M of shape (na, nb)
# we assume that no distances are zero except those on the diagonal of M

# a = tf.expand_dims(tf.ones(shape=(m_size[0],)) / m_size[0], axis=1) # (na, 1)
# b = tf.expand_dims(tf.ones(shape=(m_size[1],)) / m_size[1], axis=1) # (nb, 1)

# init data
Nini = m_size[0]
Nfin = m_size[1]

u = tf.expand_dims(tf.ones(Nini) / Nini, axis=1) # (na, 1)
v = tf.expand_dims(tf.ones(Nfin) / Nfin, axis=1) # (nb, 1)

K = tf.exp(-M / reg) # (na, nb)

Kp = (1.0 / a) * K # (na, 1) * (na, nb) = (na, nb)

cpt = tf.constant(0)
err = tf.constant(1.0)

c = lambda cpt, u, v, err: tf.logical_and(cpt < numItermax, err > stopThr)

def err_f1():
# we can speed up the process by checking the error only every 10th iteration
transp = u * (K * tf.squeeze(v)) # (na, 1) * ((na, nb) * (nb,)) = (na, nb), i.e. diag(u) K diag(v)
# violation of the second marginal: compare the column sums of the plan with b
err_ = tf.pow(tf.norm(tensor=tf.reduce_sum(input_tensor=transp, axis=0) - tf.squeeze(b), ord=1), 2) # scalar
return err_

def err_f2():
return err

def loop_func(cpt, u, v, err):
KtransposeU = tf.matmul(tf.transpose(a=K, perm=(1, 0)), u) # (nb, na) x (na, 1) = (nb, 1)
v = tf.math.divide(b, KtransposeU) # (nb, 1)
u = 1.0 / tf.matmul(Kp, v) # (na, 1)

err = tf.cond(pred=tf.equal(cpt % 10, 0), true_fn=err_f1, false_fn=err_f2)

cpt = tf.add(cpt, 1)
return cpt, u, v, err

_, u, v, _ = tf.while_loop(cond=c, body=loop_func, loop_vars=[cpt, u, v, err])

# total transport cost <P, M> with plan P = diag(u) K diag(v)
result = tf.reduce_sum(input_tensor=u * K * tf.reshape(v, (1, -1)) * M)

return result
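
# Usage sketch (illustrative): cross-check sink() against POT's reference ot.sinkhorn2 on a
# small random problem. Assumes TF2 eager execution and the POT package ("pip install pot").
if __name__ == "__main__":
    import numpy as np
    import ot

    na, nb, reg = 8, 8, 0.5
    rng = np.random.default_rng(0)
    M_np = ot.dist(rng.random((na, 2)), rng.random((nb, 2)), 'euclidean') # random ground-cost matrix

    a = tf.expand_dims(tf.ones(na) / na, axis=1) # (na, 1) uniform marginal
    b = tf.expand_dims(tf.ones(nb) / nb, axis=1) # (nb, 1) uniform marginal
    loss_tf = sink(a, b, tf.constant(M_np, tf.float32), (na, nb), reg)
    loss_ref = ot.sinkhorn2(np.ones(na) / na, np.ones(nb) / nb, M_np, reg) # POT reference value
    print(float(loss_tf), loss_ref) # the two values should agree closely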


70 changes: 70 additions & 0 deletions qDenseCNN.py
@@ -10,6 +10,71 @@
import numpy as np
import json

# for sinkhorn metric
import ot_tf
import ot

# 2D coordinates of the 48 cells whose pairwise distances define the ground metric for the Sinkhorn loss
hexCoords = np.array([
[0.0, 0.0], [0.0, -2.4168015], [0.0, -4.833603], [0.0, -7.2504044],
[2.09301, -1.2083969], [2.09301, -3.6251984], [2.09301, -6.042], [2.09301, -8.458794],
[4.18602, -2.4168015], [4.18602, -4.833603], [4.18602, -7.2504044], [4.18602, -9.667198],
[6.27903, -3.6251984], [6.27903, -6.042], [6.27903, -8.458794], [6.27903, -10.875603],
[-8.37204, -10.271393], [-6.27903, -9.063004], [-4.18602, -7.854599], [-2.0930138, -6.6461945],
[-8.37204, -7.854599], [-6.27903, -6.6461945], [-4.18602, -5.4377975], [-2.0930138, -4.229393],
[-8.37204, -5.4377975], [-6.27903, -4.229393], [-4.18602, -3.020996], [-2.0930138, -1.8125992],
[-8.37204, -3.020996], [-6.27903, -1.8125992], [-4.18602, -0.6042023], [-2.0930138, 0.6042023],
[4.7092705, -12.386101], [2.6162605, -11.177696], [0.5232506, -9.969299], [-1.5697594, -8.760895],
[2.6162605, -13.594498], [0.5232506, -12.386101], [-1.5697594, -11.177696], [-3.6627693, -9.969299],
[0.5232506, -14.802895], [-1.5697594, -13.594498], [-3.6627693, -12.386101], [-5.7557793, -11.177696],
[-1.5697594, -16.0113], [-3.6627693, -14.802895], [-5.7557793, -13.594498], [-7.848793, -12.386101]])
hexMetric = tf.constant( ot.dist(hexCoords, hexCoords, 'euclidean'), tf.float32) # (48, 48) Euclidean ground-distance matrix

def myfunc(a):
# per-sample Sinkhorn loss: a is a (48, 2) slice holding the true and predicted cell values side by side
reg = 0.5
y_true, y_pred = tf.split(a, num_or_size_splits=2, axis=1)
tf_sinkhorn_loss = ot_tf.sink(y_true, y_pred, hexMetric, (48, 48), reg)
return tf_sinkhorn_loss

def sinkhorn_loss(y_true, y_pred):
y_true = K.cast(y_true, y_pred.dtype)
# reshape both to (batch, 48, 1), stack them along the last axis, and map the per-sample loss over the batch
y_pred = K.reshape(y_pred, (-1, 48, 1))
y_true = K.reshape(y_true, (-1, 48, 1))
cc = tf.concat([y_true, y_pred], axis=2) # (batch, 48, 2)
return K.mean( tf.map_fn(myfunc, cc), axis=(-1) )

# return K.mean( tf.map_fn(myfunc, y_true), axis=(-1) )
# return K.mean( tf.map_fn(myfunc, [y_true, y_pred]), axis=(-1) )
# tf_sinkhorn_loss = K.mean( tf.numpy_function(myfunc, [y_true, y_pred], y_pred.dtype) )
# return tf_sinkhorn_loss
# sy_true = tf.split(y_true,num_or_size_splits=K.shape(y_true)[0],axis=0)
# sy_pred = tf.split(y_pred,num_or_size_splits=K.shape(y_pred)[0],axis=0)
# losses = [ ot_tf.sink(sy_true[i], sy_pred[i], hexMetric, (48, 48), reg) for r in range(len(sy_true))]
# return losses[0]
# tf_sinkhorn_loss = K.mean( ot_tf.sink(y_true, y_pred, hexMetric, (48, 48), reg), axis=(-1) )
# tf_sinkhorn_loss = K.mean( tf.numpy_function(myfunc, [y_true, y_pred], y_pred.dtype) )
# return tf_sinkhorn_loss
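
# Usage sketch for sinkhorn_loss (illustrative; assumes TF2 eager mode and that K is the Keras
# backend imported at the top of this file). Not called anywhere; names are placeholders.
def _sinkhorn_loss_example():
    yt = tf.random.uniform((4, 48)) # toy batch of 4 "true" 48-cell maps
    yp = tf.random.uniform((4, 48)) # toy batch of predictions
    yt = yt / tf.reduce_sum(yt, axis=1, keepdims=True) # normalise each map like a histogram (assumed intended use)
    yp = yp / tf.reduce_sum(yp, axis=1, keepdims=True)
    return sinkhorn_loss(yt, yp) # scalar Sinkhorn loss averaged over the batch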

def other_loss(y_true, y_pred):
# weighted MSE (as in weightedMSE below); the pooled term loss2 is computed but not used in the return value
y_true = K.cast(y_true, y_pred.dtype)
loss1 = K.mean(K.square(y_true - y_pred) * K.maximum(y_pred, y_true), axis=(-1))

# y_pred_rs = K.reshape(y_pred, (-1,48))
# y_true_rs = K.reshape(y_true, (-1,48))
# y_pred_x =

y_pred_pool = tf.nn.pool(y_pred,(2,2),'AVG',strides=[1,1])
y_true_pool = tf.nn.pool(y_true,(2,2),'AVG',strides=[1,1])
loss2 = K.mean(K.square(y_true_pool - y_pred_pool) * K.maximum(y_true_pool, y_pred_pool), axis=(-1))
#return loss1 + loss2
return loss1

# return K.mean( tf.map_fn(myfunc, cc), axis=(-1) )

def weightedMSE(self, y_true, y_pred):
y_true = K.cast(y_true, y_pred.dtype)
loss = K.mean(K.square(y_true - y_pred) * K.maximum(y_pred, y_true), axis=(-1))
return loss
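
# Small worked check of weightedMSE (illustrative; the module-level copy keeps an unused
# `self` argument, so None is passed in its place). Not called anywhere.
def _weightedMSE_example():
    yt = tf.constant([[0.0, 1.0, 2.0]])
    yp = tf.constant([[0.0, 1.0, 1.0]])
    # (yt - yp)^2 * max(yt, yp) = [0, 0, 2], so the mean over the last axis is 2/3
    return weightedMSE(None, yt, yp)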


class qDenseCNN:
def __init__(self, name='', weights_f=''):
@@ -199,6 +264,11 @@ def init(self, printSummary=True): # keep_negitive = 0 on inputs, otherwise for
self.autoencoder.compile(loss=self.weightedMSE, optimizer='adam')
self.encoder.compile(loss=self.weightedMSE, optimizer='adam')

elif self.pams['loss'] == 'sink':
# NOTE: the 'sink' option currently compiles with other_loss; the sinkhorn_loss calls below are disabled
self.autoencoder.compile(loss=other_loss, optimizer='adam')
self.encoder.compile(loss=other_loss, optimizer='adam')
# self.autoencoder.compile(loss=sinkhorn_loss, optimizer='adam')
# self.encoder.compile(loss=sinkhorn_loss, optimizer='adam')
elif self.pams['loss'] != '':
self.autoencoder.compile(loss=self.pams['loss'], optimizer='adam')
self.encoder.compile(loss=self.pams['loss'], optimizer='adam')
51 changes: 23 additions & 28 deletions scan_precision.py
@@ -8,43 +8,38 @@
import json

from train import trainCNN

def plotHist(x,y,ye, name, odir,xtitle, ytitle):
plt.figure()
plt.errorbar(x,y,ye)
plt.title('')
plt.ylabel(ytitle)
plt.xlabel(xtitle)
plt.legend(['others 16,6'], loc='upper right')
plt.savefig(odir+"/"+name+".png")
return
from utils import plotGraphErr

def plotScan(x,outs,name,odir,xtitle="n bits"):
outs = pd.concat(outs)
for metric in ['ssd','corr','emd']:
plotHist(x, outs[metric], outs[metric+'_err'], name+"_"+metric,
plotGraphErr(x, outs[metric], outs[metric+'_err'], name+"_"+metric,
odir,xtitle=xtitle,ytitle=metric)
outs.to_csv(odir+"/"+name+".csv")
return

def BitScan(options, args):

# test inputs
bits = [i+3 for i in range(6)]
bits = [i+3 for i in range(2)]
updates = [{'nBits_input':{'total': b, 'integer': 2}} for b in bits]
outputs = [trainCNN(options,args,u) for u in updates]
plotScan(bits,outputs,"test_input_bits",options.odir,xtitle="total input bits")

exit(0)

# test weights
bits = [i+1 for i in range(8)]
updates = [{'nBits_weight':{'total': 2*b+1, 'integer': b}} for b in bits]
outputs = [trainCNN(options,args,u) for u in updates]
plotScan(bits,outputs,"test_weight_bits",xtitle="total input bits")

emd, emde = zip(*[trainCNN(options,args,u) for u in updates])
plotScan(bits,emd,emde,"test_weight_bits")
if False:
# test inputs
bits = [i+3 for i in range(6)]
updates = [{'nBits_input':{'total': b, 'integer': 2}} for b in bits]
outputs = [trainCNN(options,args,u) for u in updates]
plotScan(bits,outputs,"test_input_bits",options.odir,xtitle="total input bits")

if False:
# test weights
bits = [i+1 for i in range(8)]
updates = [{'nBits_weight':{'total': 2*b+1, 'integer': b}} for b in bits]
outputs = [trainCNN(options,args,u) for u in updates]
plotScan(bits,outputs,"test_weight_bits",options.odir,xtitle="total weight bits")

if True:
# test encoded bits
bits = [4,6,8,10,12,16]
updates = [{'nBits_encod':{'total': b, 'integer': b//2},'encoded_dim':int(64/b)} for b in bits]
outputs = [trainCNN(options,args,u) for u in updates]
plotScan(bits,outputs,"test_encod_bits",options.odir,xtitle="bits per encoded node")

exit(0)

49 changes: 49 additions & 0 deletions tests/test_sinkhorn.py
@@ -0,0 +1,49 @@
import tensorflow as tf
import numpy as np
import sys
sys.path.append("/home/therwig/data/sandbox/hgcal/Ecoder/")
import ot_tf
import ot
#tf.compat.v1.disable_eager_execution()

def main():

na = 48 # number of source cells
nb = 48 # number of target cells
reg = 0.5 # entropic regularization strength
a = tf.expand_dims(tf.ones(shape=(na,)) / na, axis=1) # (na, 1)
b = tf.expand_dims(tf.ones(shape=(nb,)) / nb, axis=1) # (nb, 1)
m = tf.constant( hexMetric(), tf.float32 )
tf_sinkhorn_loss = ot_tf.sink(a, b, m, (na, nb), reg)

# print('finish loads')
# print(a)
# print(b)
# print(m)
print(tf_sinkhorn_loss)

x = tf.ones(shape=(100,48,1,))
y = tf.split(x,num_or_size_splits=100,axis=0)
print(y[0])

return


hexCoords = np.array([
[0.0, 0.0], [0.0, -2.4168015], [0.0, -4.833603], [0.0, -7.2504044],
[2.09301, -1.2083969], [2.09301, -3.6251984], [2.09301, -6.042], [2.09301, -8.458794],
[4.18602, -2.4168015], [4.18602, -4.833603], [4.18602, -7.2504044], [4.18602, -9.667198],
[6.27903, -3.6251984], [6.27903, -6.042], [6.27903, -8.458794], [6.27903, -10.875603],
[-8.37204, -10.271393], [-6.27903, -9.063004], [-4.18602, -7.854599], [-2.0930138, -6.6461945],
[-8.37204, -7.854599], [-6.27903, -6.6461945], [-4.18602, -5.4377975], [-2.0930138, -4.229393],
[-8.37204, -5.4377975], [-6.27903, -4.229393], [-4.18602, -3.020996], [-2.0930138, -1.8125992],
[-8.37204, -3.020996], [-6.27903, -1.8125992], [-4.18602, -0.6042023], [-2.0930138, 0.6042023],
[4.7092705, -12.386101], [2.6162605, -11.177696], [0.5232506, -9.969299], [-1.5697594, -8.760895],
[2.6162605, -13.594498], [0.5232506, -12.386101], [-1.5697594, -11.177696], [-3.6627693, -9.969299],
[0.5232506, -14.802895], [-1.5697594, -13.594498], [-3.6627693, -12.386101], [-5.7557793, -11.177696],
[-1.5697594, -16.0113], [-3.6627693, -14.802895], [-5.7557793, -13.594498], [-7.848793, -12.386101]])
def hexMetric():
return ot.dist(hexCoords, hexCoords, 'euclidean')
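
# Illustrative cross-check (not called by main): POT's reference solver on the same uniform
# marginals and ground metric; its value should closely match the printed tf_sinkhorn_loss.
def reference_loss(na=48, nb=48, reg=0.5):
    a = np.ones(na) / na
    b = np.ones(nb) / nb
    return ot.sinkhorn2(a, b, hexMetric(), reg) # same <P, M> transport cost that ot_tf.sink returns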

if __name__== "__main__":
main()