diff --git a/environment.yml b/environment.yml index 91f7ec2c..800276ff 100644 --- a/environment.yml +++ b/environment.yml @@ -10,6 +10,7 @@ dependencies: - torchvision - cudatoolkit - cartopy + - natsort - pip - pip: - -e . diff --git a/radionets/dl_framework/architecture.py b/radionets/dl_framework/architecture.py index 2129e6e0..9ba1a31c 100644 --- a/radionets/dl_framework/architecture.py +++ b/radionets/dl_framework/architecture.py @@ -1,6 +1,9 @@ -from radionets.dl_framework.architectures.basics import * -from radionets.dl_framework.architectures.unet import * -from radionets.dl_framework.architectures.filter_deep import * +#from radionets.dl_framework.architectures.basics import * +#from radionets.dl_framework.architectures.unet import * +# from radionets.dl_framework.architectures.filter_deep import * from radionets.dl_framework.architectures.superRes import * -from radionets.dl_framework.architectures.res_exp import * -from radionets.dl_framework.architectures.lists import * +from radionets.dl_framework.architectures.superRes import SRResNet_dirtyModel, SRResNet_dirtyModel_pretrainedL1, GANCS_generator, GANCS_generator_test, RIM, RIM_DC +from radionets.dl_framework.architectures.superRes import RIM_DC_noDetach +#from radionets.dl_framework.architectures.superRes import automap +# from radionets.dl_framework.architectures.res_exp import * +#from radionets.dl_framework.architectures.lists import * diff --git a/radionets/dl_framework/architectures/superRes.py b/radionets/dl_framework/architectures/superRes.py index 426c6937..8aef9da4 100644 --- a/radionets/dl_framework/architectures/superRes.py +++ b/radionets/dl_framework/architectures/superRes.py @@ -6,272 +6,39 @@ ResBlock_amp, ResBlock_phase, SRBlock, + EDSRBaseBlock, + RDB, + FBB, + Lambda, + better_symmetry, + fft, + gradFunc2, + manual_grad, + tf_shift, + btf_shift, + CirculationShiftPad, + SRBlockPad, + BetterShiftPad, Lambda, symmetry, + SRBlock_noBias, + HardDC, + SoftDC, + calc_DirtyBeam, + gauss, + ConvGRUCell, + gradFunc, + gradFunc2, + gradFunc_putzky, + fft_conv, + ConvGRUCellBN, ) from functools import partial - - -class superRes_simple(nn.Module): - def __init__(self, img_size): - super().__init__() - self.img_size = img_size - self.conv1_amp = nn.Sequential( - nn.Conv2d(1, 4, stride=2, kernel_size=3, padding=3 // 2), nn.ReLU() - ) - self.conv1_phase = nn.Sequential( - nn.Conv2d(1, 4, stride=2, kernel_size=3, padding=3 // 2), GeneralELU(1 - pi) - ) - self.conv2_amp = nn.Sequential( - nn.Conv2d(4, 8, stride=2, kernel_size=3, padding=3 // 2), nn.ReLU() - ) - self.conv2_phase = nn.Sequential( - nn.Conv2d(4, 8, stride=2, kernel_size=3, padding=3 // 2), GeneralELU(1 - pi) - ) - self.conv3_amp = nn.Sequential( - nn.Conv2d(8, 16, stride=2, kernel_size=3, padding=3 // 2), nn.ReLU() - ) - self.conv3_phase = nn.Sequential( - nn.Conv2d(8, 16, stride=2, kernel_size=3, padding=3 // 2), - GeneralELU(1 - pi), - ) - self.conv4_amp = nn.Sequential( - nn.Conv2d(16, 32, stride=2, kernel_size=3, padding=3 // 2), nn.ReLU() - ) - self.conv4_phase = nn.Sequential( - nn.Conv2d(16, 32, stride=2, kernel_size=3, padding=3 // 2), - GeneralELU(1 - pi), - ) - self.conv5_amp = nn.Sequential( - nn.Conv2d(32, 64, stride=2, kernel_size=3, padding=3 // 2), nn.ReLU() - ) - self.conv5_phase = nn.Sequential( - nn.Conv2d(32, 64, stride=2, kernel_size=3, padding=3 // 2), - GeneralELU(1 - pi), - ) - self.final_amp = nn.Sequential( - nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, img_size ** 2) - ) - self.final_phase = nn.Sequential( - nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, img_size ** 2) - ) - - def forward(self, x): - amp = x[:, 0, :].unsqueeze(1) - phase = x[:, 1, :].unsqueeze(1) - - amp = self.conv1_amp(amp) - phase = self.conv1_phase(phase) - - amp = self.conv2_amp(amp) - phase = self.conv2_phase(phase) - - amp = self.conv3_amp(amp) - phase = self.conv3_phase(phase) - - amp = self.conv4_amp(amp) - phase = self.conv4_phase(phase) - - amp = self.conv5_amp(amp) - phase = self.conv5_phase(phase) - - amp = self.final_amp(amp) - phase = self.final_phase(phase) - - amp = amp.reshape(-1, 1, self.img_size, self.img_size) - phase = phase.reshape(-1, 1, self.img_size, self.img_size) - - comb = torch.cat([amp, phase], dim=1) - return comb - - -class superRes_res18(nn.Module): - def __init__(self, img_size): - super().__init__() - torch.cuda.set_device(1) - self.img_size = img_size - - self.preBlock_amp = nn.Sequential( - nn.Conv2d(1, 64, 7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU() - ) - self.preBlock_phase = nn.Sequential( - nn.Conv2d(1, 64, 7, stride=2, padding=3), - nn.BatchNorm2d(64), - GeneralELU(1 - pi), - ) - - self.maxpool_amp = nn.MaxPool2d(3, 2, 1) - self.maxpool_phase = nn.MaxPool2d(3, 2, 1) - - # first block - self.layer1_amp = nn.Sequential(ResBlock_amp(64, 64), ResBlock_amp(64, 64)) - self.layer1_phase = nn.Sequential( - ResBlock_phase(64, 64), ResBlock_phase(64, 64) - ) - - self.layer2_amp = nn.Sequential( - ResBlock_amp(64, 128, stride=2), ResBlock_amp(128, 128) - ) - self.layer2_phase = nn.Sequential( - ResBlock_phase(64, 128, stride=2), ResBlock_phase(128, 128) - ) - - self.layer3_amp = nn.Sequential( - ResBlock_amp(128, 256, stride=2), ResBlock_amp(256, 256) - ) - self.layer3_phase = nn.Sequential( - ResBlock_phase(128, 256, stride=2), ResBlock_phase(256, 256) - ) - - self.layer4_amp = nn.Sequential( - ResBlock_amp(256, 512, stride=2), ResBlock_amp(512, 512) - ) - self.layer4_phase = nn.Sequential( - ResBlock_phase(256, 512, stride=2), ResBlock_phase(512, 512) - ) - - self.final_amp = nn.Sequential( - nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512, img_size ** 2) - ) - self.final_phase = nn.Sequential( - nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512, img_size ** 2) - ) - - def forward(self, x): - amp = x[:, 0, :].unsqueeze(1) - phase = x[:, 1, :].unsqueeze(1) - - amp = self.preBlock_amp(amp) - phase = self.preBlock_phase(phase) - - amp = self.maxpool_amp(amp) - phase = self.maxpool_phase(phase) - - amp = self.layer1_amp(amp) - phase = self.layer1_phase(phase) - - amp = self.layer2_amp(amp) - phase = self.layer2_phase(phase) - - amp = self.layer3_amp(amp) - phase = self.layer3_phase(phase) - - amp = self.layer4_amp(amp) - phase = self.layer4_phase(phase) - - amp = self.final_amp(amp) - phase = self.final_phase(phase) - - amp = amp.reshape(-1, 1, self.img_size, self.img_size) - phase = phase.reshape(-1, 1, self.img_size, self.img_size) - - comb = torch.cat([amp, phase], dim=1) - return comb - - -class superRes_res34(nn.Module): - def __init__(self, img_size): - super().__init__() - self.img_size = img_size - - self.preBlock_amp = nn.Sequential( - nn.Conv2d(1, 64, 7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU() - ) - self.preBlock_phase = nn.Sequential( - nn.Conv2d(1, 64, 7, stride=2, padding=3), - nn.BatchNorm2d(64), - GeneralELU(1 - pi), - ) - - self.maxpool_amp = nn.MaxPool2d(3, 2, 1) - self.maxpool_phase = nn.MaxPool2d(3, 2, 1) - - # first block - self.layer1_amp = nn.Sequential( - ResBlock_amp(64, 64), ResBlock_amp(64, 64), ResBlock_amp(64, 64) - ) - self.layer1_phase = nn.Sequential( - ResBlock_phase(64, 64), ResBlock_phase(64, 64), ResBlock_phase(64, 64) - ) - - self.layer2_amp = nn.Sequential( - ResBlock_amp(64, 128, stride=2), - ResBlock_amp(128, 128), - ResBlock_amp(128, 128), - ResBlock_amp(128, 128), - ) - self.layer2_phase = nn.Sequential( - ResBlock_phase(64, 128, stride=2), - ResBlock_phase(128, 128), - ResBlock_phase(128, 128), - ResBlock_phase(128, 128), - ) - - self.layer3_amp = nn.Sequential( - ResBlock_amp(128, 256, stride=2), - ResBlock_amp(256, 256), - ResBlock_amp(256, 256), - ResBlock_amp(256, 256), - ResBlock_amp(256, 256), - ResBlock_amp(256, 256), - ) - self.layer3_phase = nn.Sequential( - ResBlock_phase(128, 256, stride=2), - ResBlock_phase(256, 256), - ResBlock_phase(256, 256), - ResBlock_phase(256, 256), - ResBlock_phase(256, 256), - ResBlock_phase(256, 256), - ) - - self.layer4_amp = nn.Sequential( - ResBlock_amp(256, 512, stride=2), - ResBlock_amp(512, 512), - ResBlock_amp(512, 512), - ) - self.layer4_phase = nn.Sequential( - ResBlock_phase(256, 512, stride=2), - ResBlock_phase(512, 512), - ResBlock_phase(512, 512), - ) - - self.final_amp = nn.Sequential( - nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512, img_size ** 2) - ) - self.final_phase = nn.Sequential( - nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(512, img_size ** 2) - ) - - def forward(self, x): - amp = x[:, 0, :].unsqueeze(1) - phase = x[:, 1, :].unsqueeze(1) - - amp = self.preBlock_amp(amp) - phase = self.preBlock_phase(phase) - - amp = self.maxpool_amp(amp) - phase = self.maxpool_phase(phase) - - amp = self.layer1_amp(amp) - phase = self.layer1_phase(phase) - - amp = self.layer2_amp(amp) - phase = self.layer2_phase(phase) - - amp = self.layer3_amp(amp) - phase = self.layer3_phase(phase) - - amp = self.layer4_amp(amp) - phase = self.layer4_phase(phase) - - amp = self.final_amp(amp) - phase = self.final_phase(phase) - - amp = amp.reshape(-1, 1, self.img_size, self.img_size) - phase = phase.reshape(-1, 1, self.img_size, self.img_size) - - comb = torch.cat([amp, phase], dim=1) - return comb +import torchvision +import radionets.evaluation.utils as ut +import numpy as np +import matplotlib.pyplot as plt +import numpy as np class SRResNet(nn.Module): @@ -360,10 +127,10 @@ def forward(self, x): class SRResNet_corr(nn.Module): - def __init__(self): + def __init__(self, img_size): super().__init__() # torch.cuda.set_device(1) - # self.img_size = img_size + self.img_size = img_size self.preBlock = nn.Sequential( nn.Conv2d(2, 64, 9, stride=1, padding=4, groups=2), nn.PReLU() @@ -379,6 +146,14 @@ def __init__(self): SRBlock(64, 64), SRBlock(64, 64), SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), ) self.postBlock = nn.Sequential( @@ -389,119 +164,522 @@ def __init__(self): nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=2), ) - self.symmetry_amp = Lambda(partial(symmetry, mode="real")) - self.symmetry_imag = Lambda(partial(symmetry, mode="imag")) + #new symmetry - def forward(self, x): - s = x.shape[-1] + self.symmetry = Lambda(better_symmetry) + + #pi layer + # self.pi = nn.Tanh() + def forward(self, x): x = self.preBlock(x) x = x + self.postBlock(self.blocks(x)) x = self.final(x) - x0 = self.symmetry_amp(x[:, 0]).reshape(-1, 1, s, s) - x1 = self.symmetry_imag(x[:, 1]).reshape(-1, 1, s, s) + # x[:,0][x[:,0]<0] = 0 + # x[:,0][x[:,0]>2] = 2 + # x[:,1] = np.pi*self.pi(x[:,1]) - return torch.cat([x0, x1], dim=1) + + return self.symmetry(x) -class SRResNet_amp(nn.Module): - def __init__(self): +class SRResNet_sym(nn.Module): + def __init__(self, img_size): super().__init__() # torch.cuda.set_device(1) - # self.img_size = img_size + self.tf = Lambda(tf_shift) self.preBlock = nn.Sequential( - nn.Conv2d(1, 32, 9, stride=1, padding=4, groups=1), nn.PReLU() + nn.Conv2d(2, 64, 9, stride=1, padding=4, groups=2), nn.PReLU() ) - # ResBlock 12 + # ResBlock 16 self.blocks = nn.Sequential( - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), ) self.postBlock = nn.Sequential( - nn.Conv2d(32, 32, 3, stride=1, padding=1), nn.BatchNorm2d(32) + nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.BatchNorm2d(64) ) self.final = nn.Sequential( - nn.Conv2d(32, 1, 9, stride=1, padding=4, groups=1), + nn.Conv2d(64, 2, 9, stride=1, padding=4, groups=2), ) - self.symmetry_amp = Lambda(partial(symmetry, mode="real")) + #new symmetry - def forward(self, x): - x = x[:, 0].unsqueeze(1) + self.btf = Lambda(btf_shift) + def forward(self, x): + x = self.tf(x) x = self.preBlock(x) x = x + self.postBlock(self.blocks(x)) x = self.final(x) - x = self.symmetry_amp(x).reshape(-1, 1, 63, 63) - - return x + return self.btf(x) - -class SRResNet_phase(nn.Module): - def __init__(self): +class SRResNet_sym_pad(nn.Module): + def __init__(self, img_size): super().__init__() # torch.cuda.set_device(1) - # self.img_size = img_size + self.tf = Lambda(tf_shift) self.preBlock = nn.Sequential( - nn.Conv2d(1, 32, 9, stride=1, padding=4, groups=1), nn.PReLU() + BetterShiftPad((4,4,4,4)), + nn.Conv2d(2, 64, 9, stride=1, padding=0, groups=2), nn.PReLU() ) - # ResBlock 12 + # ResBlock 16 self.blocks = nn.Sequential( - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), - SRBlock(32, 32), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), + SRBlockPad(64, 64), ) self.postBlock = nn.Sequential( - nn.Conv2d(32, 32, 3, stride=1, padding=1), nn.BatchNorm2d(32) + BetterShiftPad((1,1,1,1)), + nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.BatchNorm2d(64) ) self.final = nn.Sequential( - nn.Conv2d(32, 1, 9, stride=1, padding=4, groups=1), + BetterShiftPad((4,4,4,4)), + nn.Conv2d(64, 2, 9, stride=1, padding=0, groups=2), ) - self.symmetry_phase = Lambda(partial(symmetry, mode="imag")) + #new symmetry - def forward(self, x): - x = x[:, 1].unsqueeze(1) + self.btf = Lambda(btf_shift) + def forward(self, x): + x = self.tf(x) x = self.preBlock(x) x = x + self.postBlock(self.blocks(x)) x = self.final(x) - x = self.symmetry_phase(x).reshape(-1, 1, 63, 63) + return self.btf(x) + - return x +class discriminator(nn.Module): + def __init__(self): + super().__init__() + torch.cuda.set_device(0) + + self.preBlock = nn.Sequential(nn.Conv2d(1, 64, 3, stride=1, padding=1), nn.LeakyReLU(0.2)) + + self.block1 = nn.Sequential(nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2)) + self.block2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) + self.block3 = nn.Sequential(nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) + self.block4 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) + self.block5 = nn.Sequential(nn.Conv2d(256, 256, 3, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) + self.block6 = nn.Sequential(nn.Conv2d(256, 512, 3, stride=1, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) + self.block7 = nn.Sequential(nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) + + self.main = nn.Sequential(self.block1, self.block2, self.block3, self.block4, self.block5, self.block6, self.block7) + + # self.postBlock = nn.Sequential(nn.Linear(512*4*4, 1024), nn.LeakyReLU(0.2), nn.Linear(1024,1), nn.Sigmoid()) #GAN + self.postBlock = nn.Sequential(nn.Linear(512*4*4, 1024), nn.LeakyReLU(0.2), nn.Linear(1024,1)) #WGAN + + def forward(self, x): + if isinstance(x, tuple) or isinstance(x, list): + if len(x) == 2: + x = x[1] + else: + x = x[0] + + if x.shape[1] == 2: + amp_x = (10 ** (10 * x[:,0]) - 1) / 10 ** 10 + phase_x = x[:,1] + compl_x = amp_x * torch.exp(1j * phase_x) + ifft_x = torch.fft.ifft2(compl_x) + img_x = torch.absolute(ifft_x) + shift_x = torch.fft.ifftshift(img_x).unsqueeze(1) + else: + shift_x = x + # shift_x[torch.isnan(shift_x)] = 0 + pred = self.preBlock(shift_x) + pred = self.main(pred) + pred = torch.flatten(pred, 1) + pred = self.postBlock(pred) + return pred + +class GANCS_generator(nn.Module): + def __init__(self): + super().__init__() + torch.cuda.set_device(1) + self.blocks = nn.Sequential( + SRBlock(2, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + SRBlock(64, 64), + ) + + self.post = nn.Sequential( + nn.Conv2d(64, 64, 3, stride=1, padding=1), nn.ReLU(), + nn.Conv2d(64, 64, 1, stride=1, padding=0), nn.ReLU(), + nn.Conv2d(64, 2, 1, stride=1, padding=0) + ) + + self.DC = HardDC(45, 10) + + def forward(self, x): + ap = x[0] + base_mask = x[1] + A = x[2] + + + amp = ap[:,0].clone().detach() + phase = ap[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) #k measured + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft) + # change to two channels real/imag + input = torch.zeros(ap.shape).to('cuda') + input[:,0] = spatial.real + input[:,1] = spatial.imag + # dirty = input.clone().detach() + + + pred = self.blocks(input) + + pred = self.post(pred) + + pred = self.DC(pred, compl.unsqueeze(1), A, base_mask) + + + return pred + + +class GANCS_critic(nn.Module): + def __init__(self): + super().__init__() + self.block1 = nn.Sequential(nn.Conv2d(2, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2)) + self.block2 = nn.Sequential(nn.Conv2d(64, 128, 3, stride=1, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) + self.block3 = nn.Sequential(nn.Conv2d(128, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2)) + self.block4 = nn.Sequential(nn.Conv2d(128, 256, 3, stride=1, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) + self.block5 = nn.Sequential(nn.Conv2d(256, 256, 3, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2)) + self.block6 = nn.Sequential(nn.Conv2d(256, 512, 3, stride=1, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) + self.block7 = nn.Sequential(nn.Conv2d(512, 512, 3, stride=2, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2)) + + self.blocks = nn.Sequential(self.block1, self.block2, self.block3, self.block4, self.block5, self.block6, self.block7, nn.AdaptiveAvgPool2d(1)) + + def forward(self, x): + if x.shape[1] == 2: + amp = x[:,0].clone().detach() + phase = x[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + x = torch.fft.ifftshift(ifft).unsqueeze(1) + input = torch.zeros((x.shape[0],2,x.shape[2], x.shape[3])).to('cuda') + input[:,0] = x.real.squeeze(1) + input[:,1] = x.imag.squeeze(1) + return self.blocks(input) + +class ConvRNN_deepClean(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(4, 64, 11, stride=3, dilation=1, padding=2), # use stride=4 for 63 px images + nn.Tanh(), + ) + self.GRU1 = ConvGRUCell(64, 64, 11) + self.conv2 = nn.Sequential( + nn.ConvTranspose2d(64, 64, 11, stride=3, dilation=1, padding=2), + nn.Tanh(), + ) + self.GRU2 = ConvGRUCell(64, 64, 11) + self.conv3 = nn.Sequential( + nn.Conv2d(64, 2, 11, stride=1, dilation=1, padding=5, bias=False) + ) + # self.weight = nn.Parameter(torch.tensor([0.25])) + + def forward(self, x, hx=None): + if not hx: + hx = [None]*2 + + complex2channels = torch.cat((x[:,0].real.unsqueeze(1),x[:,0].imag.unsqueeze(1),x[:,1].real.unsqueeze(1),x[:,1].imag.unsqueeze(1)), dim=1) + # print(complex2channels.shape) + # print(complex2channels.dtype) + + c1 = self.conv1(complex2channels) + g1 = self.GRU1(c1, hx[0]) + c2 = self.conv2(g1) + g2 = self.GRU2(c2, hx[1]) + c3 = self.conv3(g2) + + # plt.imshow(torch.absolute(c2[0,0]+1j*c2[0,1]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + channels2complex = (c3[:,0]+1j*c3[:,1]).unsqueeze(1) + + return channels2complex, [g1.detach(), g2.detach()] # ??? detach() ??? + +class RIM_DC(nn.Module): + def __init__(self, n_steps=10): + super().__init__() + torch.cuda.set_device(0) + # torch.set_default_dtype(torch.float64) ## this is really important since we do a lot of ffts. otherwise torch.zeros is float32 and we can't save complex128 into it! + self.n_steps = n_steps + + self.cRNN = ConvRNN_deepClean() + # self.type(torch.complex64) + # torch.backends.cudnn.enabled = False + + + def forward(self, x, hx=None): + ap = x[0] + amp = ap[:,0] + phase = ap[:,1] + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) #k measured + + + data = compl.clone().detach().unsqueeze(1) + compl_shift = torch.fft.fftshift(compl) # shift low freq to corner + ifft = torch.fft.ifft2(compl_shift, norm="forward") + eta = torch.fft.ifftshift(ifft).unsqueeze(1) # shift low freq to center + # print(eta.shape) + # eta = torch.zeros(ap.shape, dtype=torch.float64).to('cuda') + # eta[:,0] = ifft_shift.real + # eta[:,1] = ifft_shift.imag + + + + etas = [] + # plt.imshow(torch.abs(eta[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + for i in range(self.n_steps): + + grad = gradFunc_putzky(eta.detach(), [data, x[1]]).detach() + # grad = manual_grad(eta.detach(), [data, x[1]]).detach() + # plt.imshow(torch.absolute(grad[0,0]+1j*grad[0,1]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + # break + + input = torch.cat((eta.detach(),grad), dim=1) + delta, hx = self.cRNN(input, hx) + # plt.imshow(torch.abs(grad[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + # plt.imshow(torch.absolute(grad[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + # print(hx[0].requires_grad) + # plt.imshow(torch.abs(delta[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + eta = eta.detach() + delta + # plt.imshow(torch.abs(eta[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + etas.append(eta) + + + + return [eta/eta.shape[2]**2 for eta in etas] + + +class ConvRNN_deepClean_noDetach(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(4, 64, 11, stride=4, dilation=1, padding=2), # use stride=4 for 63 px images # for blackhole model use 4, 64, 11, stride=3, dilation=1, padding=2 + nn.Tanh(), + ) + self.GRU1 = ConvGRUCell(64, 64, 11) + self.conv2 = nn.Sequential( + nn.ConvTranspose2d(64, 64, 11, stride=4, dilation=1, padding=2), + nn.Tanh(), + ) + self.GRU2 = ConvGRUCell(64, 64, 11) + self.conv3 = nn.Sequential( + nn.Conv2d(64, 2, 11, stride=1, dilation=1, padding=5, bias=False) + ) + # self.weight = nn.Parameter(torch.tensor([0.25])) + + def forward(self, x, hx=None): + if not hx: + hx = [None]*2 + + complex2channels = torch.cat((x[:,0].real.unsqueeze(1),x[:,0].imag.unsqueeze(1),x[:,1].real.unsqueeze(1),x[:,1].imag.unsqueeze(1)), dim=1) + # print(complex2channels.shape) + # print(complex2channels.dtype) + + c1 = self.conv1(complex2channels) + g1 = self.GRU1(c1, hx[0]) + c2 = self.conv2(g1) + g2 = self.GRU2(c2, hx[1]) + c3 = self.conv3(g2) + + # plt.imshow(torch.absolute(c2[0,0]+1j*c2[0,1]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + channels2complex = (c3[:,0]+1j*c3[:,1]).unsqueeze(1) + + return channels2complex, [g1, g2] # ??? detach() ??? + +class RIM_DC_noDetach(nn.Module): + def __init__(self, n_steps=10): + super().__init__() + torch.cuda.set_device(1) + # torch.set_default_dtype(torch.float64) ## this is really important since we do a lot of ffts. otherwise torch.zeros is float32 and we can't save complex128 into it! + self.n_steps = n_steps + + # self.cRNN = ConvRNN_deepClean_noDetach() + self.cRNN = ConvRNN_deepClean_noDetach_smallKernel() + # self.type(torch.complex64) + # torch.backends.cudnn.enabled = False + + + def forward(self, x, hx=None, factor=1): + ap = x[0] + amp = ap[:,0] + uv_cov = amp.unsqueeze(1).clone().detach() + phase = ap[:,1] + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) #k measured + + + data = compl.clone().detach().unsqueeze(1) + compl_shift = torch.fft.fftshift(compl) # shift low freq to corner + ifft = torch.fft.ifft2(compl_shift, norm="forward") + eta = torch.fft.ifftshift(ifft).unsqueeze(1)*factor # shift low freq to center + #calc beam + # uv_cov[uv_cov!=0] = 1 + # beam = abs(torch.fft.ifftshift(torch.fft.fft2(torch.fft.fftshift(uv_cov)))) + + # beam = beam/torch.max(torch.max(beam,2)[0],2)[0][:,:,None,None] + # plt.imshow(beam[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + + + etas = [] + + for i in range(self.n_steps): + + grad = gradFunc_putzky(eta, [data, x[1]]) + + # if i == 0: + # eta = fft_conv(eta,beam) + input = torch.cat((eta,grad), dim=1) + delta, hx = self.cRNN(input, hx) + eta = eta + delta + # plt.imshow(abs(eta[0,0].cpu().detach().numpy())/(64**2), cmap='hot') + # plt.colorbar() + # plt.show() + # plt.imshow(abs(grad[0,0].cpu().detach().numpy()), cmap='hot') + # plt.colorbar() + # plt.show() + etas.append(eta) + + # plt.imshow(abs(fft_conv(eta,beam)[0,0].cpu().detach().numpy())/(64**2), cmap='hot') + # plt.colorbar() + # plt.show() + # return [fft_conv(eta,beam)/eta.shape[2]**2 for eta in etas] + return [eta/eta.shape[2]**2 for eta in etas] + + + +class ConvRNN_deepClean_noDetach_smallKernel(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(4, 64, 3, stride=1, dilation=2, padding=2), + nn.Tanh(), + ) + self.conv1b = nn.Sequential( + nn.Conv2d(4, 64, 1, stride=1, dilation=2, padding=0), + nn.Tanh(), + ) + self.conv1c = nn.Sequential( + nn.Conv2d(4, 64, 5, stride=1, dilation=2, padding=4), + nn.Tanh(), + ) + self.GRU1 = ConvGRUCell(64*3, 64*3, 3) + self.conv2 = nn.Sequential( + nn.ConvTranspose2d(64, 64, 3, stride=1, dilation=2, padding=2), + nn.Tanh(), + ) + self.conv2b = nn.Sequential( + nn.ConvTranspose2d(64, 64, 1, stride=1, dilation=2, padding=0), + nn.Tanh(), + ) + self.conv2c = nn.Sequential( + nn.ConvTranspose2d(64, 64, 5, stride=1, dilation=2, padding=4), + nn.Tanh(), + ) + self.GRU2 = ConvGRUCell(64*3, 64*3, 3) + self.conv3 = nn.Sequential( + nn.Conv2d(64*3, 2, 3, stride=1, dilation=2, padding=2, bias=False) + ) + # self.weight = nn.Parameter(torch.tensor([0.25])) + + def forward(self, x, hx=None): + if not hx: + hx = [None]*2 + + complex2channels = torch.cat((x[:,0].real.unsqueeze(1),x[:,0].imag.unsqueeze(1),x[:,1].real.unsqueeze(1),x[:,1].imag.unsqueeze(1)), dim=1) + # print(complex2channels.shape) + # print(complex2channels.dtype) + + c1 = self.conv1(complex2channels) + c1b = self.conv1b(complex2channels) + c1c = self.conv1c(complex2channels) + comb = torch.cat((c1,c1b,c1c),dim=1) + g1 = self.GRU1(comb, hx[0]) + g1abc = torch.split(g1,64,dim=1) + c2 = self.conv2(g1abc[0]) + c2b = self.conv2(g1abc[1]) + c2c = self.conv2(g1abc[2]) + comb2 = torch.cat((c2,c2b,c2c),dim=1) + g2 = self.GRU2(comb2, hx[1]) + c3 = self.conv3(g2) + + # plt.imshow(torch.absolute(c2[0,0]+1j*c2[0,1]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + channels2complex = (c3[:,0]+1j*c3[:,1]).unsqueeze(1) + + return channels2complex, [g1, g2] # ??? detach() ??? diff --git a/radionets/dl_framework/callbacks.py b/radionets/dl_framework/callbacks.py index bead13da..ae000e8f 100644 --- a/radionets/dl_framework/callbacks.py +++ b/radionets/dl_framework/callbacks.py @@ -91,6 +91,8 @@ def plot_lrs(self): plt.tight_layout() + + class NormCallback(Callback): _order = 2 @@ -119,29 +121,75 @@ def before_fit(self): class DataAug(Callback): _order = 3 + def __init__(self, vgg, physics_informed): + self.vgg = vgg + self.physics_informed = physics_informed + def before_batch(self): - x = self.xb[0].clone() + # x = self.xb[0].clone() y = self.yb[0].clone() + if self.physics_informed: + # y = self.yb[0][0].clone() + # base_mask = self.yb[0][1].clone() + # A = self.yb[0][2].clone() + x = self.xb[0][0].clone() + base_mask = self.xb[0][1].clone() + A = self.xb[0][2].clone() + else: + y = self.yb[0].clone() + randint = np.random.randint(0, 4, x.shape[0]) for i in range(x.shape[0]): - x[i, 0] = torch.rot90(x[i, 0], int(randint[i])) - x[i, 1] = torch.rot90(x[i, 1], int(randint[i])) - y[i, 0] = torch.rot90(y[i, 0], int(randint[i])) - y[i, 1] = torch.rot90(y[i, 1], int(randint[i])) - self.learn.xb = [x] + if x.shape[1] == 2: + x[i, 0] = torch.rot90(x[i, 0], int(randint[i])) + x[i, 1] = torch.rot90(x[i, 1], int(randint[i])) + else: + x[i, 0] = torch.rot90(x[i, 0], int(randint[i])) + + if not self.vgg: + y[i, 0] = torch.rot90(y[i, 0], int(randint[i])) + y[i, 1] = torch.rot90(y[i, 1], int(randint[i])) + if self.physics_informed: + base_mask[i] = torch.rot90(base_mask[i], int(randint[i]), dims=[0,1]) + A[i] = torch.rot90(A[i], int(randint[i]), dims=[0,1]) + # self.learn.xb = [x] self.learn.yb = [y] + if self.physics_informed: + # self.learn.yb = [(y, base_mask, A)] + self.learn.xb = [(x, base_mask, A)] + else: + self.learn.yb = [y] class SaveTempCallback(Callback): _order = 95 - def __init__(self, model_path): + def __init__(self, model_path, gan=False): self.model_path = model_path + self.gan = gan def after_epoch(self): p = Path(self.model_path).parent p.mkdir(parents=True, exist_ok=True) if (self.epoch + 1) % 10 == 0: out = p / f"temp_{self.epoch + 1}.model" - save_model(self, out) + save_model(self, out, self.gan) print(f"\nFinished Epoch {self.epoch + 1}, model saved.\n") + +# Best callback ever +class OverwriteOneBatch_CLEAN(Callback): + _order = 4 + def __init__(self, n_iter): + self.n_iter = n_iter + + def before_batch(self): + input = self.xb[0] + M = torch.zeros(input[0].shape).to('cuda') + self.learn.xb = [(self.xb[0][0],self.xb[0][1],self.xb[0][2],M)] + + for i in range(self.n_iter-1): + # self.model.zero_grad() + self._do_one_batch() + self.learn.xb = [(self.pred[0].clone().detach(),self.xb[0][1],self.xb[0][2],self.pred[1].clone().detach())] + + \ No newline at end of file diff --git a/radionets/dl_framework/data.py b/radionets/dl_framework/data.py index 63bc97cb..32fc476a 100644 --- a/radionets/dl_framework/data.py +++ b/radionets/dl_framework/data.py @@ -32,7 +32,7 @@ def do_normalisation(x, norm): class h5_dataset: - def __init__(self, bundle_paths, tar_fourier, amp_phase=None, source_list=False): + def __init__(self, bundle_paths, tar_fourier, amp_phase=None, source_list=False, vgg=False, physics_informed=False): """ Save the bundle paths and the number of bundles in one file. """ @@ -41,6 +41,8 @@ def __init__(self, bundle_paths, tar_fourier, amp_phase=None, source_list=False) self.tar_fourier = tar_fourier self.amp_phase = amp_phase self.source_list = source_list + self.vgg = vgg + self.physics_informed = physics_informed def __call__(self): return print("This is the h5_dataset class.") @@ -53,10 +55,14 @@ def __len__(self): def __getitem__(self, i): if self.source_list: + x = self.open_image("x", i) + y = self.open_image("z", i) + elif self.physics_informed: x = self.open_image("x", i) y = self.open_image("y", i) - z = self.open_image("z", i) - return x, y, z + base_mask = self.open_image("base_mask", i) + A = self.open_image("A", i) + return (x, base_mask.squeeze(0), A.squeeze(0)), y #x, (y, base_mask.squeeze(0), A.squeeze(0)) else: x = self.open_image("x", i) y = self.open_image("y", i) @@ -80,31 +86,26 @@ def open_image(self, var, i): h5py.File(self.bundles[bundle], "r") for bundle in bundle_unique ] bundle_paths_str = list(map(str, bundle_paths)) - if not var == "z": - data = torch.tensor( - np.array( - [ - bund[var][img] - for bund, bund_str in zip(bundle_paths, bundle_paths_str) - for img in image[ - bundle == bundle_unique[bundle_paths_str.index(bund_str)] - ] - ] - ) - ) + + if var == 'base_mask' or var == 'A' or var == "z": # Baselines and response matrices are the same for every single src_position/bundle + image[image != 0] = 0 + - else: - data = [ - np.array(bund[var + str(int(img))]) + + data = torch.tensor( + [ + bund[var][img] for bund, bund_str in zip(bundle_paths, bundle_paths_str) for img in image[ bundle == bundle_unique[bundle_paths_str.index(bund_str)] ] ] - return data + ) + if var == "x" or var == 'y': + if data.shape[1] == 1: - if var == "x" or self.tar_fourier is True: - if len(i) == 1: + data_channel = data + elif len(i) == 1: data_amp, data_phase = data[:, 0], data[:, 1] data_channel = torch.cat([data_amp, data_phase], dim=0) @@ -113,15 +114,24 @@ def open_image(self, var, i): data_channel = torch.cat([data_amp, data_phase], dim=1) else: - if data.shape[1] == 2: - raise ValueError( - "Two channeled data is used despite Fourier being False.\ - Set Fourier to True!" - ) - if len(i) == 1: - data_channel = data.reshape(data.shape[-1] ** 2) + if self.source_list: + data_channel = data + elif self.vgg: + data_channel = data + elif self.physics_informed: + data_channel = data else: - data_channel = data.reshape(-1, data.shape[-1] ** 2) + if data.shape[1] == 2: + raise ValueError( + "Two channeled data is used despite Fourier being False.\ + Set Fourier to True!" + ) + if len(i) == 1: + data_channel = data.reshape(data.shape[-1] ** 2) + else: + data_channel = data.reshape(-1, data.shape[-1] ** 2) + # if var == 'base_mask' or var == 'A': + # print(data_channel.shape) return data_channel.float() @@ -215,6 +225,18 @@ def save_fft_pair(path, x, y, z=None, name_x="x", name_y="y", name_z="z"): [hf.create_dataset(name_z + str(i), data=z[i]) for i in range(len(z))] hf.close() +def save_fft_pair_with_response(path, x, y, base_mask, A, name_x="x", name_y="y", name_base_mask="base_mask", name_A='A'): + """ + write fft_pairs created in second analysis step to h5 file + write response matrices & baselines + """ + with h5py.File(path, "w") as hf: + hf.create_dataset(name_x, data=x) + hf.create_dataset(name_y, data=y) + hf.create_dataset(name_base_mask, data=base_mask) + hf.create_dataset(name_A, data=A) + hf.close() + def open_fft_pair(path): """ @@ -247,7 +269,7 @@ def mean_and_std(array): return array.mean(), array.std() -def load_data(data_path, mode, fourier=False, source_list=False): +def load_data(data_path, mode, fourier=False, source_list=False, vgg=False, physics_informed=False): """ Load data set from a directory and return it as h5_dataset. @@ -266,9 +288,6 @@ def load_data(data_path, mode, fourier=False, source_list=False): dataset containing x and y images """ bundle_paths = get_bundles(data_path) - data = np.sort( - [path for path in bundle_paths if re.findall("samp_" + mode, path.name)] - ) - data = sorted(data, key=lambda f: int("".join(filter(str.isdigit, str(f))))) - ds = h5_dataset(data, tar_fourier=fourier, source_list=source_list) + data = [path for path in bundle_paths if re.findall("samp_" + mode, path.name)] + ds = h5_dataset(data, tar_fourier=fourier, source_list=source_list, vgg=vgg, physics_informed=physics_informed) return ds diff --git a/radionets/dl_framework/learner.py b/radionets/dl_framework/learner.py index e5c89e95..01912c72 100644 --- a/radionets/dl_framework/learner.py +++ b/radionets/dl_framework/learner.py @@ -8,11 +8,17 @@ AvgLossCallback, CudaCallback, ) -from fastai.optimizer import Adam +from fastai.optimizer import Adam, RMSProp from fastai.learner import Learner from fastai.data.core import DataLoaders from fastai.callback.schedule import ParamScheduler, combined_cos import radionets.dl_framework.loss_functions as loss_functions +from fastai.vision import models +# from radionets.dl_framework.architectures import superRes +import torchvision +from radionets.dl_training.utils import define_arch +from fastai.vision.gan import GANLearner, FixedGANSwitcher, _tk_diff, GANDiscriminativeLR +from fastai.callback.mixup import MixUp def get_learner( @@ -34,6 +40,7 @@ def define_learner( cbfs=[], test=False, lr_find=False, + gan=False, ): model_path = train_conf["model_path"] model_name = ( @@ -56,6 +63,13 @@ def define_learner( train_conf["lr_stop"], ) } + # lr_max = train_conf["lr_max"] + # div = 25. + # div_final = 1e5 + # pct_start = 0.25 + # moms = (0.95, 0.85) + # sched = {'lr': combined_cos(pct_start, lr_max/div, lr_max, lr_max/div_final), + # 'mom': combined_cos(pct_start, moms[0], moms[1], moms[0])} cbfs.extend([ParamScheduler(sched)]) if train_conf["gpu"]: cbfs.extend( @@ -63,12 +77,25 @@ def define_learner( CudaCallback, ] ) - if not test: + if not test and not gan: cbfs.extend( [ SaveTempCallback(model_path=model_path), AvgLossCallback, - DataAug, + DataAug(vgg=train_conf["vgg"], physics_informed=train_conf["physics_informed"]), + # OverwriteOneBatch_CLEAN(5), + # OverwriteOneBatch_CLEAN(10), + # MixUp(), + ] + ) + if gan: + cbfs.extend( + [ + SaveTempCallback(model_path=model_path, gan=gan), + AvgLossCallback, + DataAug(vgg=train_conf["vgg"], physics_informed=train_conf["physics_informed"]), + # WGANL1Callback, + # GANDiscriminativeLR, ] ) if train_conf["telegram_logger"] and not lr_find: @@ -84,6 +111,28 @@ def define_learner( else: loss_func = getattr(loss_functions, train_conf["loss_func"]) + if gan: + # gen_loss_func = getattr(loss_functions, 'gen_loss_func') #non physics informed + gen_loss_func = getattr(loss_functions, 'l1_wgan_GANCS') + crit_loss_func = getattr(loss_functions, 'crit_loss_func') + + generator = arch + critic = define_arch( + arch_name='GANCS_critic', img_size=train_conf["image_size"] + ) + # init_cnn(generator) + init_cnn(critic) + dls = DataLoaders.from_dsets( + data.train_ds, + data.valid_ds, + bs=data.train_dl.batch_size, + ) + switcher = FixedGANSwitcher(n_crit=1, n_gen=1) #GAN + # learn = GANLearner(dls, generator, critic, gen_loss_func, crit_loss_func, lr=lr, cbs=cbfs, opt_func=opt_func, switcher=switcher) #GAN + # learn = GANLearner.wgan(dls, generator, critic, lr=lr, cbs=cbfs, opt_func=RMSProp) #WGAN + learn = GANLearner(dls, generator, critic, gen_loss_func, _tk_diff, clip=0.01, switch_eval=False, lr=lr, cbs=cbfs, opt_func=RMSProp) #WGAN-l1 + return learn + # Combine model and data in learner learn = get_learner( data, arch, lr=lr, opt_func=opt_func, cb_funcs=cbfs, loss_func=loss_func diff --git a/radionets/dl_framework/loss_functions.py b/radionets/dl_framework/loss_functions.py index 6fe69051..e96e2bd3 100644 --- a/radionets/dl_framework/loss_functions.py +++ b/radionets/dl_framework/loss_functions.py @@ -5,6 +5,9 @@ import torch.nn.functional as F from pytorch_msssim import MS_SSIM from scipy.optimize import linear_sum_assignment +from radionets.dl_framework.architectures import superRes +from fastai.vision.gan import _tk_mean +import matplotlib.pyplot as plt class FeatureLoss(nn.Module): @@ -120,6 +123,301 @@ def l1(x, y): loss = l1(x, y) return loss +def l1_phyinfo(x, y): + l1 = nn.L1Loss() + return l1(x[1],y[0]) + +def l1_GANCS(x,y): + amp = y[:,0].clone().detach() + phase = y[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft).unsqueeze(1) + # change to two channels real/imag + # input = torch.zeros(y.shape) + # input[:,0] = spatial.real + # input[:,1] = spatial.imag + + l1 = nn.L1Loss() + + return l1(x,spatial) + +def l1_CLEAN(x,y): + amp = y[:,0].clone().detach() + phase = y[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft).unsqueeze(1) + + + l1 = nn.L1Loss() + + return l1((x[1][:,0]+1j*x[1][:,1]).unsqueeze(1),spatial) + +def l1_RIM(x,y): + amp = y[:,0].clone().detach() + phase = y[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft).unsqueeze(1) + + + l1 = nn.L1Loss() + loss = 0 + for eta in x: + loss += l1((eta[:,0]+1j*eta[:,1]).unsqueeze(1),spatial) + + loss = loss/len(x) + return loss + +def mse_RIM(x,y): + amp = y[:,0].clone().detach() + phase = y[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + compl_shift = torch.fft.fftshift(compl) + ifft = torch.fft.ifft2(compl_shift, norm="forward") + true = torch.fft.ifftshift(ifft).unsqueeze(1) + + + complex2channels_y = torch.cat((true.real,true.imag), dim=1) + + mse = nn.MSELoss() + loss = 0 + for eta in x: + complex2channels_x = torch.cat((eta.real,eta.imag), dim=1) + loss += mse(complex2channels_x*eta.shape[2]**2,complex2channels_y) + + loss = loss/len(x) + return loss + +def l1_wgan_GANCS(fake_pred,x,y): + amp = y[:,0].clone().detach() + phase = y[:,1].clone().detach() + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + spatial = torch.fft.ifftshift(ifft).unsqueeze(1) + + l1 = nn.L1Loss() + lamb = 1e-5 + + return l1(x,spatial)+lamb*_tk_mean(fake_pred, x, spatial) + +def dirty_model(x, y): + amp = x[1][:,0] + phase = x[1][:,1] + amp_rescaled = (10 ** (10 * amp) - 1) / 10 ** 10 + compl = amp_rescaled * torch.exp(1j * phase) + ifft = torch.fft.ifft2(compl) + pred = torch.fft.ifftshift(torch.absolute(ifft)).unsqueeze(1) + + # amp_t = y[0][:,0] + # phase_t = y[0][:,1] + # amp_rescaled_t = (10 ** (10 * amp_t) - 1) / 10 ** 10 + # compl_t = amp_rescaled_t * torch.exp(1j * phase_t) + # ifft_t = torch.fft.ifft2(compl_t) + # true = torch.fft.ifftshift(torch.absolute(ifft_t)).unsqueeze(1) + + + base_nums = torch.zeros(45) #hard code + n_tel = 10 #hardcode + c = 0 + for i in range(n_tel): + for j in range(n_tel): + if j<=i: + continue + base_nums[c] = 256 * (i + 1) + j + 1 + c += 1 + + base_mask = y[1] + A = y[2] + MD = torch.zeros(pred.shape, dtype=torch.complex64).to('cuda') + + + for idx, bn in enumerate(base_nums): + s_uv = torch.sum((base_mask == bn),3) + if not (base_mask == bn).any(): + continue + AI = torch.einsum('blm,bclm->bclm',A[...,idx],pred) + MD += torch.einsum('blm,bclm->bclm',s_uv,torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(AI)))) #spatial + + points = base_mask.clone() + points[points != 0] = 1 + points = torch.sum(points,3) + points[points == 0] = 1 + + MD = torch.fft.ifftshift(torch.absolute(torch.fft.ifft2(MD/points.unsqueeze(1)))) + + l1 = nn.L1Loss() + mse = nn.MSELoss() + loss = l1(pred, x[0]) + # loss = vgg19_feature_loss(MD,x[0]) + return loss + +# vgg19 = superRes.vgg19_feature_maps(5,4).eval().to('cuda:1') +def vgg19_feature_loss(x, y): + print() + if 'vgg19_feature_model_12' not in globals(): + global vgg19_feature_model_12 + # global vgg19_feature_model_22 + # global vgg19_feature_model_34 + # global vgg19_feature_model_44 + # global vgg19_feature_model_54 + vgg19_feature_model_12 = superRes.vgg19_feature_maps(1,2).eval().to('cuda:1') + # vgg19_feature_model_22 = superRes.vgg19_feature_maps(2,2).eval().to('cuda:1') + # vgg19_feature_model_34 = superRes.vgg19_feature_maps(3,4).eval().to('cuda:1') + # vgg19_feature_model_44 = superRes.vgg19_feature_maps(4,4).eval().to('cuda:1') + # vgg19_feature_model_54 = superRes.vgg19_feature_maps(5,4).eval().to('cuda:1') + + + mse = nn.MSELoss() + l1 = nn.L1Loss() + + # up1 = nn.Upsample(size=7, mode='nearest').to('cuda:1') + # up2 = nn.Upsample(size=15, mode='nearest').to('cuda:1') + # up3 = nn.Upsample(size=31, mode='nearest').to('cuda:1') + # up4 = nn.Upsample(size=63, mode='nearest').to('cuda:1') + # c34 = nn.Conv2d(256, 512, 1).to('cuda:1') + # c22 = nn.Conv2d(128, 512, 1).to('cuda:1') + # c12 = nn.Conv2d(64, 512, 1).to('cuda:1') + + # upx1 = up1(vgg19_feature_model_54(x)) + # upy1 = up1(vgg19_feature_model_54(y)) + + # upx2 = up2(vgg19_feature_model_44(x) + upx1) + # upy2 = up2(vgg19_feature_model_44(y) + upy1) + + # upx3 = up3(c34(vgg19_feature_model_34(x)) + upx2) + # upy3 = up3(c34(vgg19_feature_model_34(y)) + upy2) + + # upx4 = up4(c22(vgg19_feature_model_22(x)) + upx3) + # upy4 = up4(c22(vgg19_feature_model_22(y)) + upy3) + + # upx5 = (c12(vgg19_feature_model_12(x)) + upx4) + # upy5 = (c12(vgg19_feature_model_12(y)) + upy4) + + + # mix_x = (0.5*x[:,0]+0.5*x[:,1]).unsqueeze(1) + # mix_y = (0.5*y[:,0]+0.5*y[:,1]).unsqueeze(1) + # ones = torch.ones((x.shape[0], 1, x.shape[2], x.shape[3])).to('cuda:1') + # x_3c = torch.cat((x, ones), dim=1) + # y_3c = torch.cat((y, ones), dim=1) + + # loss = l1(vgg19_feature_model_22(x), vgg19_feature_model_22(y))# + l1(vgg19_feature_model_12(x), vgg19_feature_model_12(y)) + l1(vgg19_feature_model_34(x), vgg19_feature_model_34(y)) + l1(vgg19_feature_model_44(x), vgg19_feature_model_44(y)) + l1(vgg19_feature_model_54(x), vgg19_feature_model_54(y)) + loss = l1(vgg19_feature_model_22(x), vgg19_feature_model_22(y)) + return loss + +def gen_loss_func(fake_pred, x, y): + l1 = nn.L1Loss() + bce = nn.BCELoss() + + # mask = torch.zeros(x.shape).to('cuda:1') + # mask[:,:,31-10:31+10,31-10:31+10]=1 + # xm = torch.einsum('bcij,bcjk->bcik',x,mask) + # ym = torch.einsum('bcij,bcjk->bcik',y,mask) + + content_loss = l1(x,y) + # content_loss = automap_l2(x,y) + # content_loss = vgg19_feature_loss(x,y) + adv_loss = bce(fake_pred, torch.ones_like(fake_pred)) + + return content_loss + 1e-3*adv_loss + +def gen_loss_wgan_l1(fake_pred, x, y): + l1 = nn.L1Loss() + content_loss = l1(x[1], y[0]) + + + adv_loss = _tk_mean(fake_pred, x, y) + lamb = 1.5 # first:0.9 + + return lamb*content_loss + (1-lamb)*adv_loss + + +def gen_loss_func_physics_informed(fake_pred, x, y): + + + + bce = nn.BCELoss() + l1 = nn.L1Loss() + ######### physics informed stuff + # base_nums = torch.zeros(45) #hard code + # n_tel = 10 #hardcode + # c = 0 + # for i in range(n_tel): + # for j in range(n_tel): + # if j<=i: + # continue + # base_nums[c] = 256 * (i + 1) + j + 1 + # c += 1 + + # base_mask = y[1] + # A = y[2] + # MD = torch.zeros(x[1].shape, dtype=torch.complex64).to('cuda') + + + + + # for idx, bn in enumerate(base_nums): + # s_uv = torch.sum((base_mask == bn),3) + # if not (base_mask == bn).any(): + # continue + # AI = torch.einsum('blm,bclm->bclm',A[...,idx],x[1]) + # MD += torch.einsum('blm,bclm->bclm',s_uv,torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(AI)))) #spatial + + # points = base_mask.clone() + # points[points != 0] = 1 + # points = torch.sum(points,3) + # points[points == 0] = 1 + + # MD = torch.fft.ifftshift(torch.absolute(torch.fft.ifft2(MD/points.unsqueeze(1)))) + + + content_loss = l1(x[1], y[0]) + # print(fake_pred.requires_grad) + adv_loss = bce(fake_pred, torch.ones_like(fake_pred)) + + return 1e-3*adv_loss +content_loss + + + +def crit_loss_func(real_pred, fake_pred): + bce = nn.BCELoss() + loss = bce(real_pred, torch.ones_like(real_pred)) + bce(fake_pred, torch.zeros_like(fake_pred)) + # print(fake_pred.requires_grad) + return loss + +def cross_entropy(x,y): + loss = nn.CrossEntropyLoss() + return loss(x, y.squeeze().long()) + + +def l1_rnn(x, y): + l1 = nn.L1Loss() + x = torch.chunk(x, 4, dim=0) + + l = 0 + for i in range(4): + l += l1(x[i], y) + return l/4 + +def splitted_l1(x, y): + l1 = nn.L1Loss() + l = (10*l1(x[:,0], y[:,0]) + l1(x[:,1], y[:,1]))/2 + return l + +def l1_ssim(x,y): + fft_x, fft_y = inspec.fft_pred_torch(x,y) + l1 = nn.L1Loss() + print(inspec.ssim_torch(fft_x, fft_y).shape) + l = (l1(fft_x, fft_y) + (1-inspec.ssim_torch(fft_x, fft_y)))/2 + return l + + def l1_amp(x, y): l1 = nn.L1Loss() diff --git a/radionets/dl_framework/model.py b/radionets/dl_framework/model.py index c471e21a..3e752811 100644 --- a/radionets/dl_framework/model.py +++ b/radionets/dl_framework/model.py @@ -1,9 +1,14 @@ +from numpy.lib.function_base import diff import torch from torch import nn import torch.nn.functional as F from torch.nn.modules.utils import _pair from pathlib import Path from math import sqrt, pi +from fastcore.foundation import L +from torch.nn.common_types import _size_4_t +import numpy as np +import radionets.simulations.utils as utils class Lambda(nn.Module): @@ -91,6 +96,39 @@ def symmetry(x, mode="real"): x[:, i, j] = -x[:, u, v] return torch.rot90(x, 3, dims=(1, 2)) +def better_symmetry(x): + # rotation + x = torch.flip(x, [3]) + + # indices of upper and lower triangle + triu = torch.triu_indices(x.shape[2], x.shape[3], 1) + tril = torch.tril_indices(x.shape[2], x.shape[3], -1) + triu = torch.flip(triu, [1]) + + # sym amp and phase + x[:,0,tril[0], tril[1]] = x[:,0, triu[0], triu[1]] + x[:,1,tril[0], tril[1]] = -x[:,1, triu[0], triu[1]] + + # rotation + x = torch.flip(x, [3]) + + return x + +def tf_shift(x): + triu = torch.triu_indices(x.shape[2], x.shape[2], 0) + tf = torch.flip(x, [3])[:,:,triu[0], triu[1]].reshape(x.shape[0],x.shape[1],x.shape[2],int(x.shape[3]/2)+1) + + return tf + +def btf_shift(x): + btf = torch.zeros((x.shape[0],x.shape[1],x.shape[2], x.shape[3]*2-1)).cuda() + triu = torch.triu_indices(x.shape[2], x.shape[2], 0) + + btf[:,:,triu[0], triu[1]] = x[:,:].reshape(x.shape[0], x.shape[1], -1) + btf = torch.flip(btf, [3]) + + btf = better_symmetry(btf) + return btf class GeneralRelu(nn.Module): def __init__(self, leak=None, sub=None, maxv=None): @@ -254,7 +292,7 @@ def deconv(ni, nc, ks, stride, padding, out_padding): return layers -def load_pre_model(learn, pre_path, visualize=False): +def load_pre_model(learn, pre_path, visualize=False, gan=False): """ :param learn: object of type learner :param pre_path: string wich contains the path of the model @@ -280,20 +318,47 @@ def load_pre_model(learn, pre_path, visualize=False): learn.recorder.values = checkpoint["vals"] -def save_model(learn, model_path): - torch.save( - { - "model": learn.model.state_dict(), - "opt": learn.opt.state_dict(), - "epoch": learn.epoch, - "iters": learn.recorder.iters, - "vals": learn.recorder.values, - "train_loss": learn.avg_loss.loss_train, - "valid_loss": learn.avg_loss.loss_valid, - "lrs": learn.avg_loss.lrs, - }, - model_path, - ) +def save_model(learn, model_path, gan=False): + # print(learn.model.generator) + if not gan: + torch.save( + { + "model": learn.model.state_dict(), + "opt": learn.opt.state_dict(), + "epoch": learn.epoch, + "loss": learn.loss, + "iters": learn.recorder.iters, + "vals": learn.recorder.values, + "train_loss": learn.avg_loss.loss_train, + "valid_loss": learn.avg_loss.loss_valid, + "lrs": learn.avg_loss.lrs, + "recorder_train_loss": L(learn.recorder.values[0:]).itemgot(0), + "recorder_valid_loss": L(learn.recorder.values[0:]).itemgot(1), + "recorder_losses": learn.recorder.losses, + "recorder_lrs": learn.recorder.lrs, + }, + model_path, + ) + else: + torch.save( + { + "model": learn.model.generator.state_dict(), + "opt": learn.opt.state_dict(), + "epoch": learn.epoch, + "loss": learn.loss, + "iters": learn.recorder.iters, + "vals": learn.recorder.values, + "train_loss": learn.avg_loss.loss_train, + "valid_loss": learn.avg_loss.loss_valid, + "lrs": learn.avg_loss.lrs, + "recorder_train_loss": L(learn.recorder.values[0:]).itemgot(0), + "recorder_valid_loss": L(learn.recorder.values[0:]).itemgot(1), + "recorder_losses": learn.recorder.losses, + "recorder_lrs": learn.recorder.lrs, + }, + model_path, + ) + class LocallyConnected2d(nn.Module): @@ -397,3 +462,616 @@ def _conv_block(self, ni, nf, stride): nn.Conv2d(nf, nf, 3, stride=1, padding=1, bias=False), nn.BatchNorm2d(nf), ) + +class SRBlock_noBias(nn.Module): + def __init__(self, ni, nf, stride=1): + super().__init__() + self.convs = self._conv_block(ni, nf, stride) + self.idconv = nn.Identity() if ni == nf else nn.Conv2d(ni, nf, 1,bias=False) + self.pool = ( + nn.Identity() if stride == 1 else nn.AvgPool2d(2, ceil_mode=True) + ) # nn.AvgPool2d(8, 2, ceil_mode=True) + + def forward(self, x): + return self.convs(x) + self.idconv(self.pool(x)) + + def _conv_block(self, ni, nf, stride): + return nn.Sequential( + nn.Conv2d(ni, nf, 3, stride=stride, padding=1,bias=False), + nn.BatchNorm2d(nf), + nn.PReLU(), + nn.Conv2d(nf, nf, 3, stride=1, padding=1,bias=False), + nn.BatchNorm2d(nf), + ) + +class SRBlockPad(nn.Module): + def __init__(self, ni, nf, stride=1): + super().__init__() + self.convs = self._conv_block(ni, nf, stride) + self.idconv = nn.Identity() if ni == nf else nn.Conv2d(ni, nf, 1) + self.pool = ( + nn.Identity() if stride == 1 else nn.AvgPool2d(2, ceil_mode=True) + ) # nn.AvgPool2d(8, 2, ceil_mode=True) + + def forward(self, x): + return self.convs(x) + self.idconv(self.pool(x)) + + def _conv_block(self, ni, nf, stride): + return nn.Sequential( + BetterShiftPad((1,1,1,1)), + nn.Conv2d(ni, nf, 3, stride=stride, padding=0), + nn.BatchNorm2d(nf), + nn.PReLU(), + BetterShiftPad((1,1,1,1)), + nn.Conv2d(nf, nf, 3, stride=1, padding=0), + nn.BatchNorm2d(nf), + ) + +class EDSRBaseBlock(nn.Module): + def __init__(self, ni, nf, stride=1): + super().__init__() + self.convs = self._conv_block(ni,nf,stride) + self.idconv = nn.Identity() if ni == nf else nn.Conv2d(ni, nf, 1) + self.pool = nn.Identity() if stride == 1 else nn.AvgPool2d(2, ceil_mode=True)#nn.AvgPool2d(8, 2, ceil_mode=True) + + def forward(self, x): + return self.convs(x) + self.idconv(self.pool(x)) + + def _conv_block(self, ni, nf, stride): + return nn.Sequential( + nn.Conv2d(ni, nf, 3, stride=stride, padding=1), + nn.PReLU(), + nn.Conv2d(nf, nf, 3, stride=1, padding=1) + ) + +class RDB(nn.Module): + def __init__(self, ni, nf, stride=1): + super().__init__() + self.conv1 = self._conv_block(ni,nf,stride) + self.conv2 = self._conv_block(ni+nf,nf,stride) + self.conv3 = self._conv_block(ni+2*nf,nf,stride) + self.conv4 = self._conv_block(ni+3*nf,nf,stride) + self.conv5 = self._conv_block(ni+4*nf,nf,stride) + self.conv6 = self._conv_block(ni+5*nf,nf,stride) + + self.convFusion = nn.Conv2d(ni+6*nf, ni, 1, stride=1, padding=0, groups=2, bias=False) + + def forward(self, x): + x1_c = self.conv1(x) + cat = self._cat_split(x, x1_c) + x2_c = self.conv2(cat) + cat = self._cat_split(cat, x2_c) + x3_c = self.conv3(cat) + cat = self._cat_split(cat, x3_c) + x4_c = self.conv4(cat) + cat = self._cat_split(cat, x4_c) + x5_c = self.conv5(cat) + cat = self._cat_split(cat, x5_c) + x6_c = self.conv6(cat) + cat = self._cat_split(cat, x6_c) + + + return self.convFusion(cat) + x + + def _conv_block(self, ni, nf, stride): + return nn.Sequential( + nn.Conv2d(ni, nf, 3, stride=stride, padding=1, bias=False), + nn.PReLU() + ) + + def _cat_split(self, x, y): + x1, x2 = torch.chunk(x,2, dim=1) + y1, y2 = torch.chunk(y,2, dim=1) + return torch.cat((x1,y1,x2,y2), dim=1) + + +class FBB(nn.Module): + def __init__(self, ni, nf, stride=1, first=False): + super().__init__() + self.first = first + if first: + self.convCat = nn.Conv2d(ni*2, ni, 1, stride=1, padding=0, groups=2, bias=False) + else: + self.convCat = nn.Identity() + self.conv1 = self._conv_block(ni,nf,stride) + self.conv2 = self._conv_block(ni+nf,nf,stride) + self.conv3 = self._conv_block(ni+2*nf,nf,stride) + self.conv4 = self._conv_block(ni+3*nf,nf,stride) + self.conv5 = self._conv_block(ni+4*nf,nf,stride) + self.conv6 = self._conv_block(ni+5*nf,nf,stride) + + self.convFusion = nn.Conv2d(ni+6*nf, ni, 1, stride=1, padding=0, groups=2, bias=False) + + def forward(self, x): + # if self.first: + # comb = torch.chunk(x,2, dim=1) + # skip = comb[0] + # x = self._cat_split(comb[0], comb[1]) + # # x = self._cat_split(x, comb[2]) + # # x = self._cat_split(x, comb[3]) + # else: + # skip = x + + x_cc = self.convCat(x) + x1_c = self.conv1(x_cc) + cat = self._cat_split(x_cc, x1_c) + x2_c = self.conv2(cat) + cat = self._cat_split(cat, x2_c) + x3_c = self.conv3(cat) + cat = self._cat_split(cat, x3_c) + x4_c = self.conv4(cat) + cat = self._cat_split(cat, x4_c) + x5_c = self.conv5(cat) + cat = self._cat_split(cat, x5_c) + x6_c = self.conv6(cat) + cat = self._cat_split(cat, x6_c) + + + return self.convFusion(cat) + x_cc + + def _conv_block(self, ni, nf, stride): + return nn.Sequential( + nn.Conv2d(ni, nf, 3, stride=stride, padding=1, bias=False), + nn.PReLU() + ) + + def _cat_split(self, x, y): + x1, x2 = torch.chunk(x,2, dim=1) + y1, y2 = torch.chunk(y,2, dim=1) + return torch.cat((x1,y1,x2,y2), dim=1) + + +class _CirculationPadNd(nn.Module): + __constants__ = ['padding'] + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.pad(input, self.padding, 'circular') + + def extra_repr(self) -> str: + return '{}'.format(self.padding) + +class CirculationPad2d(_CirculationPadNd): + padding: _size_4_t + + def __init__(self, padding: _size_4_t) -> None: + super(CirculationPad2d, self).__init__() + self.padding = _pair(padding) + +class CirculationShiftPad(nn.Module): + padding: _size_4_t + + def __init__(self, padding: _size_4_t) -> None: + super(CirculationShiftPad, self).__init__() + self.padding = _pair(padding) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + x = F.pad(input, self.padding, 'circular') + x[...,:self.padding[2],:] = 0 + x[...,-self.padding[3]:,:] = 0 + x[...,:,:self.padding[0]] = torch.roll(x[...,:,:self.padding[0]],1,2) + x[...,:,-self.padding[1]:] = torch.roll(x[...,:,-self.padding[1]:],-1,2) + x[...,:self.padding[2],:] = 0 + x[...,-self.padding[3]:,:] = 0 + return x + +def better_padding(input, padding): + in_shape = input.shape + paddable_shape = in_shape[2:] + + out_shape = in_shape[:2] + for idx, size in enumerate(paddable_shape): + out_shape += (size + padding[-(idx * 2 + 1)] + padding[-(idx * 2 + 2)],) + + # fill empty tensor of new shape with input + out = torch.zeros(out_shape, dtype=input.dtype, layout=input.layout, + device=input.device) + + out[..., padding[-2]:(out_shape[2]-padding[-1]), padding[-4]:(out_shape[3]-padding[-3])] = input + + # pad left + i0 = out_shape[3] - padding[-4] - padding[-3] + i1 = out_shape[3] - padding[-3] + o0 = 0 + o1 = padding[-4] + out[:, :, padding[-2]:out_shape[2]-padding[-1], o0:o1] = out[:, :, padding[-2]-1:out_shape[2]-padding[-1]-1, i0:i1] + + # pad right + i0 = padding[-4] + i1 = padding[-4] + padding[-3] + o0 = out_shape[3] - padding[-3] + o1 = out_shape[3] + out[:, :, padding[-2]:out_shape[2]-padding[-1], o0:o1] = out[:, :, padding[-2]+1:out_shape[2]-padding[-1]+1, i0:i1] + + return out + +class BetterShiftPad(nn.Module): + padding: _size_4_t + + def __init__(self, padding: _size_4_t) -> None: + super(BetterShiftPad, self).__init__() + self.padding = _pair(padding) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + x = better_padding(input, self.padding) + return x + +class HardDC(nn.Module): + def __init__(self, base_nums, n_tel): + super().__init__() + self.base_nums = torch.zeros(base_nums) + self.n_tel = n_tel + self.weights = nn.Parameter(torch.tensor(1).float()) + + def forward(self, x, input, A, base_mask): + c = 0 + for i in range(self.n_tel): + for j in range(self.n_tel): + if j<=i: + continue + self.base_nums[c] = 256 * (i + 1) + j + 1 + c += 1 + + + pred = torch.zeros((x.shape[0],1,x.shape[2],x.shape[3]), dtype=torch.complex64).to('cuda') + c = 0 + for idx, bn in enumerate(self.base_nums): + s_uv = torch.sum((base_mask == bn),3) + if not (base_mask == bn).any(): + continue + + xA = torch.einsum('bclm,blm->bclm',x,A[...,idx]) + x_prime = xA[:,0] + 1j*xA[:,1] #from 2 channels to complex for fft + k_prime = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(x_prime))) + y_prime = torch.einsum('blm,bclm->bclm',(1-s_uv),k_prime.unsqueeze(1)) + + Y = torch.einsum('blm,bclm->bclm', s_uv, input) + + full_k_space = Y + self.weights*y_prime + + pred += full_k_space # maybe a conj(A) missing, see paper 1910.07048 + c += 1 + + points = base_mask.clone() + points[points != 0] = 1 + points = torch.sum(points,3) + points[points == 0] = 1 + + + return torch.fft.fftshift(torch.fft.ifft2(torch.fft.fftshift(pred/c))) #divide by c because we summed c fully sampled maps in pred??? + +class SoftDC(nn.Module): + def __init__(self, base_nums, n_tel): + super().__init__() + self.base_nums = torch.zeros(base_nums) + self.n_tel = n_tel + self.weights = nn.Parameter(torch.tensor(1).float()) + + def forward(self, x, measured, A, base_mask): + c = 0 + for i in range(self.n_tel): + for j in range(self.n_tel): + if j<=i: + continue + self.base_nums[c] = 256 * (i + 1) + j + 1 + c += 1 + + + sum = torch.zeros((x.shape[0],1,x.shape[2],x.shape[3]), dtype=torch.complex64).to('cuda') + c = 0 + for idx, bn in enumerate(self.base_nums): + s_uv = torch.sum((base_mask == bn),3) + if not (base_mask == bn).any(): + continue + + xA = torch.einsum('bclm,blm->bclm',x,A[...,idx]) + x_prime = xA[:,0] + 1j*xA[:,1] #from 2 channels to complex for fft + k_prime = torch.fft.fftshift(torch.fft.fft2(torch.fft.fftshift(x_prime))) + y_prime = torch.einsum('blm,bclm->bclm',s_uv,k_prime.unsqueeze(1)) + + # Y = torch.einsum('blm,bclm->bclm', s_uv, input) + + diff = torch.fft.fftshift(torch.fft.ifft2(torch.fft.fftshift(y_prime))) + + sum += diff # maybe a conj(A) missing, see paper 1910.07048 + c += 1 + + sum = sum/c + + pred = torch.zeros(x.shape).to('cuda') + + pred[:,0] = sum.real.squeeze(1) + pred[:,1] = sum.imag.squeeze(1) + + + + + + return x + self.weights*(pred-measured) #divide by c because we summed c fully sampled maps in pred??? + + +def calc_DirtyBeam(base_mask): + s_uv = torch.sum(base_mask,3) + s_uv[s_uv != 0] = 1 + + b = torch.fft.fftshift(torch.fft.ifft2(torch.fft.fftshift(s_uv))) + beam = torch.zeros((b.shape[0],2, b.shape[1], b.shape[2])).to('cuda') + beam[:,0] = b.real.squeeze(1) + beam[:,1] = b.imag.squeeze(1) + return beam + + +def gauss(kernel_size, sigma): + # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2) + x_cord = torch.arange(kernel_size).to('cuda') + x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size) + y_grid = x_grid.t() + xy_grid = torch.stack([x_grid, y_grid], dim=-1) + + mean = (kernel_size - 1)/2. + variance = sigma**2. + + # Calculate the 2-dimensional gaussian kernel which is + # the product of two gaussian distributions for two different + # variables (in this case called x and y) + gaussian_kernel = (1./(2.*np.pi*variance)) *\ + torch.exp( + -torch.sum((xy_grid - mean)**2., dim=-1) /\ + (2*variance) + ) + # Make sure sum of values in gaussian kernel equals 1. + gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) + return gaussian_kernel + + +class ConvGRUCell(nn.Module): + def __init__(self, input_size, hidden_size, kernel_size, dilation=1, bias=True): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.kernel_size = kernel_size + self.dilation = dilation + self.bias = bias + + self.Wih = nn.Conv2d(input_size, 3*hidden_size, kernel_size, dilation=dilation, bias=bias, padding=dilation*(kernel_size-1)//2) + self.Whh = nn.Conv2d(input_size, 3*hidden_size, kernel_size, dilation=dilation, bias=bias, padding=dilation*(kernel_size-1)//2) + + def forward(self, x, hx=None): + if hx is None: + hx = torch.zeros((x.size(0), self.hidden_size) + x.size()[2:], requires_grad=False).to('cuda') + + ih = self.Wih(x).chunk(3, dim=1) + hh = self.Whh(hx).chunk(3, dim=1) + + z = torch.sigmoid(ih[0] + hh[0]) + r = torch.sigmoid(ih[1] + hh[1]) + n = torch.tanh(ih[2]+ r*hh[2]) + + # import matplotlib.pyplot as plt + # plt.imshow(torch.abs(hx[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + hx = (1-z)*hx + z*n + + return hx + +class ConvGRUCellBN(nn.Module): + def __init__(self, input_size, hidden_size, kernel_size, dilation=1, bias=True): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.kernel_size = kernel_size + self.dilation = dilation + self.bias = bias + + self.Wih = nn.Conv2d(input_size, 3*hidden_size, kernel_size, dilation=dilation, bias=bias, padding=dilation*(kernel_size-1)//2) + self.Whh = nn.Conv2d(input_size, 3*hidden_size, kernel_size, dilation=dilation, bias=bias, padding=dilation*(kernel_size-1)//2) + + self.bn1 = nn.BatchNorm2d(3*hidden_size) + self.bn2 = nn.BatchNorm2d(3*hidden_size) + + def forward(self, x, hx=None): + if hx is None: + hx = torch.zeros((x.size(0), self.hidden_size) + x.size()[2:], requires_grad=False).to('cuda') + + ih = self.bn1(self.Wih(x)).chunk(3, dim=1) + hh = self.bn2(self.Whh(hx)).chunk(3, dim=1) + + z = torch.sigmoid(ih[0] + hh[0]) + r = torch.sigmoid(ih[1] + hh[1]) + n = torch.tanh(ih[2]+ r*hh[2]) + + # import matplotlib.pyplot as plt + # plt.imshow(torch.abs(hx[0,0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + hx = (1-z)*hx + z*n + + return hx + +def gradFunc(x, y, A, base_mask, n_tel, base_nums): + does_require_grad = x.requires_grad + with torch.enable_grad(): + x.requires_grad_(True) + + mask = torch.sum(base_mask, 3) + mask[mask != 0] = 1 + + fx = torch.fft.fft2(torch.fft.fftshift(x[:,0]+1j*x[:,1])) # shift x low freq to corner & fft + pfx = torch.einsum('blm,blm->blm', mask, torch.fft.ifftshift(fx)) # shift low freq to center + py = torch.einsum('blm,blm->blm', mask, y) + + # import matplotlib.pyplot as plt + # plt.imshow(torch.absolute(py[0]-pfx[0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + diff = (py-pfx)**2 + # import matplotlib.pyplot as plt + # plt.imshow(torch.absolute(py[0]-pfx[0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + # diff_shift = torch.fft.fftshift(diff) # shift low freq to corner + + # error = torch.sum(torch.fft.ifftshift(torch.fft.ifft2(diff_shift))) # ifft & shift low freq to center + + # grad = torch.zeros((ift.size(0), 2) + ift.size()[1:]).to('cuda') + # grad[:,0] = ift.real.squeeze(1) + # grad[:,1] = ift.imag.squeeze(1) + + grad_x = torch.autograd.grad(torch.sum(diff), inputs=x, retain_graph=does_require_grad, + create_graph=does_require_grad)[0] + + + # import matplotlib.pyplot as plt + # plt.imshow(np.absolute((grad_x[:,0]+1j*grad_x[:,1])[0].cpu().detach().numpy())) + # plt.colorbar() + # plt.show() + # import matplotlib.pyplot as plt + # plt.imshow(np.absolute(torch.fft.ifftshift(torch.fft.ifft(torch.fft.fftshift(grad_x)))[0].cpu().detach().numpy())) + # plt.colorbar() + # plt.show() + + + + # import matplotlib.pyplot as plt + # # print(grad_x.shape) + # plt.imshow(torch.absolute(diff[0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + x.requires_grad_(does_require_grad) + + return grad_x + + +def gradFunc2(x, y, A, base_mask, n_tel, base_nums): + + + mask = torch.sum(base_mask, 3) + mask[mask != 0] = 1 + + fx = torch.fft.fft2(torch.fft.fftshift(x[:,0]+1j*x[:,1])) # shift x low freq to corner & fft + pfx = torch.einsum('blm,blm->blm', mask, torch.fft.ifftshift(fx)) # shift low freq to center + py = torch.einsum('blm,blm->blm', mask, y) # mask y otherwise diff is not zero if x=y since we do a lot of ffts + + diff = pfx-py + diff_shift = torch.fft.fftshift(diff) # shift low freq to corner + error = torch.fft.ifftshift(torch.fft.ifft2(diff_shift)) + + + grad = torch.zeros((error.size(0), 2) + error.size()[1:]).to('cuda') + grad[:,0] = error.real.squeeze(1) + grad[:,1] = error.imag.squeeze(1) + + # import matplotlib.pyplot as plt + # # # print(grad_x.shape) + # plt.imshow(torch.absolute(error[0]).cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + return grad + +def gradFunc_putzky(x, y): + base_mask = y[1] + data = y[0] + does_require_grad = x.requires_grad + with torch.enable_grad(): + x.requires_grad_(True) + + mask = torch.sum(base_mask, 3) + mask[mask != 0] = 1 + + fx = torch.fft.fft2(torch.fft.fftshift(x), norm="forward") + + # import matplotlib.pyplot as plt + # plt.imshow((torch.abs(torch.fft.ifftshift(torch.fft.ifft2(fx))))[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + pfx = torch.einsum('blm,bclm->bclm', torch.flip(mask, [1]), torch.fft.ifftshift(fx)) # shift low freq to center + # py = torch.einsum('blmforward,bclm->bclm', mask, data) + # plt.imshow(torch.abs(pfx-data)[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + difference = pfx-data + # import matplotlib.pyplot as plt + # plt.imshow(abs(difference[0,0].cpu().detach().numpy())) + # plt.colorbar() + # plt.show() + + # import matplotlib.pyplot as plt + # plt.imshow((torch.abs(torch.fft.ifftshift(torch.fft.ifft2(data))-torch.fft.ifftshift(torch.fft.ifft2(pfx))))[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + #import matplotlib.pyplot as plt + #plt.imshow((torch.abs(torch.fft.ifftshift(torch.fft.ifft2(data))))[0,0].cpu().detach().numpy()) + #plt.colorbar() + #plt.show() + + #import matplotlib.pyplot as plt + #plt.imshow((torch.abs(torch.fft.ifftshift(torch.fft.ifft2(fx))))[0,0].cpu().detach().numpy()) + #plt.colorbar() + #plt.show() + + + chi2 = torch.sum(torch.square(torch.abs(difference))) + + + grad_x = torch.autograd.grad(chi2, inputs=x, retain_graph=does_require_grad, + create_graph=does_require_grad)[0] + + # import matplotlib.pyplot as plt + # # # print(grad_x.shape) + # # test[test==0] = 1 + # # plt.figure(figsize=(12,8)) + # plt.imshow((torch.abs(grad_x))[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + x.requires_grad_(does_require_grad) + + return grad_x + +def manual_grad(x, y): + base_mask = y[1] + data = y[0] + + mask = torch.sum(base_mask, 3) + mask[mask != 0] = 1 + + fx = torch.fft.fft2(torch.fft.fftshift(x), norm="forward") + + # import matplotlib.pyplot as plt + # plt.imshow((torch.abs(torch.fft.ifftshift(torch.fft.ifft2(fx))))[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + pfx = torch.einsum('blm,bclm->bclm', torch.flip(mask, [1]), torch.fft.ifftshift(fx)) # shift low freq to center + # py = torch.einsum('blmforward,bclm->bclm', mask, data) + # plt.imshow(torch.abs(pfx-data)[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + difference = pfx-data + + # import matplotlib.pyplot as plt + # plt.imshow((torch.abs(torch.fft.ifftshift(torch.fft.ifft2(data))-torch.fft.ifftshift(torch.fft.ifft2(pfx))))[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + grad_x = torch.fft.ifftshift(torch.fft.ifft2(torch.fft.fftshift(difference), norm='forward')) + + # import matplotlib.pyplot as plt + # # # print(grad_x.shape) + # # test[test==0] = 1 + # # plt.figure(figsize=(12,8)) + # plt.imshow((torch.abs(grad_x))[0,0].cpu().detach().numpy()) + # plt.colorbar() + # plt.show() + + return grad_x + +def fft_conv(a,b): + multiply = (torch.fft.fft2(torch.fft.fftshift(a))*torch.fft.fft2(torch.fft.fftshift(b), norm="ortho")) + ifft =torch.fft.ifftshift(torch.fft.ifft2(multiply)) + import matplotlib.pyplot as plt + # plt.imshow(abs(ifft[0,0].cpu().detach().numpy())) + # plt.colorbar() + # plt.show() + return ifft diff --git a/radionets/dl_training/scripts/start_training.py b/radionets/dl_training/scripts/start_training.py index f394945b..09193715 100644 --- a/radionets/dl_training/scripts/start_training.py +++ b/radionets/dl_training/scripts/start_training.py @@ -18,6 +18,8 @@ ) from radionets.evaluation.train_inspection import after_training_plots from pathlib import Path +import os +os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" @click.command() @@ -29,6 +31,7 @@ "train", "lr_find", "plot_loss", + "gan", ], case_sensitive=False, ), @@ -61,6 +64,8 @@ def main(configuration_path, mode): fourier=train_conf["fourier"], batch_size=train_conf["bs"], source_list=train_conf["source_list"], + vgg=train_conf["vgg"], + physics_informed=train_conf["physics_informed"] ) # get image size @@ -87,11 +92,12 @@ def main(configuration_path, mode): # load pretrained model if train_conf["pre_model"] != "none": learn.create_opt() - load_pre_model(learn, train_conf["pre_model"]) + load_pre_model(learn, train_conf["pre_model"], gan=False) # Train the model, except interrupt try: - learn.fit(train_conf["num_epochs"]) + # learn.fine_tune(train_conf["num_epochs"]) + learn.fit(train_conf["num_epochs"]) except KeyboardInterrupt: pop_interrupt(learn, train_conf) @@ -99,6 +105,37 @@ def main(configuration_path, mode): if train_conf["inspection"]: after_training_plots(train_conf, rand=True) + + if mode == "gan": + # check out path and look for existing model files + check_outpath(train_conf["model_path"], train_conf) + + click.echo("Start training of the GAN model.\n") + + # define_learner + learn = define_learner( + data, + arch, + train_conf, + gan=True + ) + + # load pretrained model + if train_conf["pre_model"] != "none": + learn.create_opt() + load_pre_model(learn, train_conf["pre_model"]) + + # Train the model, except interrupt + try: + # learn.fine_tune(train_conf["num_epochs"]) + learn.fit(train_conf["num_epochs"]) + except KeyboardInterrupt: + pop_interrupt(learn, train_conf, True) + + end_training(learn, train_conf, True) + + if train_conf["inspection"]: + after_training_plots(train_conf, rand=True) if mode == "lr_find": click.echo("Start lr_find.\n") diff --git a/radionets/dl_training/utils.py b/radionets/dl_training/utils.py index 8e2a0523..248a5ac8 100644 --- a/radionets/dl_training/utils.py +++ b/radionets/dl_training/utils.py @@ -8,10 +8,10 @@ from radionets.evaluation.train_inspection import create_inspection_plots -def create_databunch(data_path, fourier, source_list, batch_size): +def create_databunch(data_path, fourier, source_list, batch_size, vgg, physics_informed): # Load data sets - train_ds = load_data(data_path, "train", source_list=source_list, fourier=fourier) - valid_ds = load_data(data_path, "valid", source_list=source_list, fourier=fourier) + train_ds = load_data(data_path, "train", source_list=source_list, fourier=fourier, vgg=vgg, physics_informed=physics_informed) + valid_ds = load_data(data_path, "valid", source_list=source_list, fourier=fourier, vgg=vgg, physics_informed=physics_informed) # Create databunch with defined batchsize bs = batch_size @@ -34,6 +34,8 @@ def read_config(config): train_conf["lr"] = config["hypers"]["lr"] train_conf["fourier"] = config["general"]["fourier"] + train_conf["vgg"] = config["general"]["vgg"] + train_conf["physics_informed"] = config["general"]["physics_informed"] train_conf["amp_phase"] = config["general"]["amp_phase"] train_conf["arch_name"] = config["general"]["arch_name"] train_conf["loss_func"] = config["general"]["loss_func"] @@ -67,16 +69,16 @@ def check_outpath(model_path, train_conf): def define_arch(arch_name, img_size): - if "filter_deep" in arch_name or "resnet" in arch_name: + if "filter_deep" in arch_name or "resnet" in arch_name or "Net" in arch_name: arch = getattr(architecture, arch_name)(img_size) else: arch = getattr(architecture, arch_name)() return arch -def pop_interrupt(learn, train_conf): +def pop_interrupt(learn, train_conf, gan=False): if click.confirm("KeyboardInterrupt, do you want to save the model?", abort=False): - model_path = train_conf["model_path"] + model_path = Path(train_conf["model_path"]) # save model print(f"Saving the model after epoch {learn.epoch}") save_model(learn, model_path) @@ -92,9 +94,9 @@ def pop_interrupt(learn, train_conf): sys.exit(1) -def end_training(learn, train_conf): +def end_training(learn, train_conf, gan=False): # Save model - save_model(learn, Path(train_conf["model_path"])) + save_model(learn, Path(train_conf["model_path"]), gan=False) # Plot loss plot_loss(learn, Path(train_conf["model_path"])) diff --git a/radionets/evaluation/plotting.py b/radionets/evaluation/plotting.py index 250f2d3d..ea94808a 100644 --- a/radionets/evaluation/plotting.py +++ b/radionets/evaluation/plotting.py @@ -194,6 +194,7 @@ def visualize_with_fourier( out_path: str which contains the output path """ # reshaping and splitting in real and imaginary part if necessary + img_pred = img_pred.cpu() inp_real, inp_imag = img_input[0], img_input[1] real_pred, imag_pred = img_pred[0], img_pred[1] real_truth, imag_truth = img_truth[0], img_truth[1] diff --git a/radionets/evaluation/train_inspection.py b/radionets/evaluation/train_inspection.py index 3105e3cd..6db85ae7 100644 --- a/radionets/evaluation/train_inspection.py +++ b/radionets/evaluation/train_inspection.py @@ -329,7 +329,7 @@ def evaluate_viewing_angle(conf): pred = torch.cat((pred, pred_2), dim=1) ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"]) - ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"]) + ifft_pred = get_ifft(pred.cpu(), amp_phase=conf["amp_phase"]) m_truth, n_truth, alpha_truth = calc_jet_angle(torch.tensor(ifft_truth)) m_pred, n_pred, alpha_pred = calc_jet_angle(torch.tensor(ifft_pred)) @@ -377,7 +377,7 @@ def evaluate_dynamic_range(conf): pred = torch.cat((pred, pred_2), dim=1) ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"]) - ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"]) + ifft_pred = get_ifft(pred.cpu(), amp_phase=conf["amp_phase"]) dr_truth, dr_pred, _, _ = calc_dr(ifft_truth, ifft_pred) dr_truths = np.append(dr_truths, dr_truth) @@ -434,7 +434,7 @@ def evaluate_ms_ssim(conf): pred = torch.cat((pred, pred_2), dim=1) ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"]) - ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"]) + ifft_pred = get_ifft(pred.cpu(), amp_phase=conf["amp_phase"]) if img_size < 160: ifft_truth = pad_unsqueeze(torch.tensor(ifft_truth)) @@ -461,7 +461,7 @@ def evaluate_ms_ssim(conf): def evaluate_mean_diff(conf): # create DataLoader loader = create_databunch( - conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"] + conf["data_path"], conf["fourier"], conf["source_list"], conf["batch_size"], conf["rim"], ) model_path = conf["model_path"] out_path = Path(model_path).parent / "evaluation" @@ -478,15 +478,25 @@ def evaluate_mean_diff(conf): # iterate trough DataLoader for i, (img_test, img_true) in enumerate(tqdm(loader)): + + if conf["rim"]: + pred = eval_model(img_test, model)[9] + else: + pred = eval_model(img_test, model) - pred = eval_model(img_test, model) if conf["model_path_2"] != "none": pred_2 = eval_model(img_test, model_2) pred = torch.cat((pred, pred_2), dim=1) - ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"]) - ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"]) - + ifft_truth = np.fft.ifftshift(get_ifft(img_true, amp_phase=conf["amp_phase"])) + if conf["rim"]: + ifft_pred = abs(pred.cpu()) + else: + ifft_pred = get_ifft(pred.cpu(), amp_phase=conf["amp_phase"]) + # import matplotlib.pyplot as plt + # plt.imshow(ifft_truth[0]) + # plt.colorbar() + # plt.show() for pred, truth in zip(ifft_pred, ifft_truth): blobs_pred, blobs_truth = calc_blobs(pred, truth) flux_pred, flux_truth = crop_first_component( @@ -532,7 +542,7 @@ def evaluate_area(conf): pred = torch.cat((pred, pred_2), dim=1) ifft_truth = get_ifft(img_true, amp_phase=conf["amp_phase"]) - ifft_pred = get_ifft(pred, amp_phase=conf["amp_phase"]) + ifft_pred = get_ifft(pred.cpu(), amp_phase=conf["amp_phase"]) for pred, truth in zip(ifft_pred, ifft_truth): val = area_of_contour(pred, truth) diff --git a/radionets/evaluation/utils.py b/radionets/evaluation/utils.py index 784344c0..ce5c438c 100644 --- a/radionets/evaluation/utils.py +++ b/radionets/evaluation/utils.py @@ -30,13 +30,18 @@ def source_list_collate(batch): return torch.stack(x), torch.stack(y), z -def create_databunch(data_path, fourier, source_list, batch_size): +def create_databunch(data_path, fourier, source_list, batch_size, rim): # Load data sets + if rim: + mode = "valid" + else: + mode = "test" test_ds = load_data( data_path, - mode="test", + mode=mode, fourier=fourier, source_list=source_list, + physics_informed=rim ) # Create databunch with defined batchsize and check for source_list @@ -66,6 +71,7 @@ def read_config(config): eval_conf["source_list"] = config["general"]["source_list"] eval_conf["arch_name_2"] = config["general"]["arch_name_2"] eval_conf["diff"] = config["general"]["diff"] + eval_conf["rim"] = config["general"]["rim"] eval_conf["vis_pred"] = config["inspection"]["visualize_prediction"] eval_conf["vis_source"] = config["inspection"]["visualize_source_reconstruction"] @@ -237,7 +243,21 @@ def load_pretrained_model(arch_name, model_path, img_size=63): arch: architecture object architecture with pretrained weigths """ - if "filter_deep" in arch_name or "resnet" in arch_name: + if 'vgg19' in arch_name: + arch = getattr(architecture, arch_name)() + elif 'GANCS' in arch_name: + arch = getattr(architecture, arch_name)() + elif 'CLEANNN' in arch_name: + arch = getattr(architecture, arch_name)() + elif 'RIM' in arch_name: + arch = getattr(architecture, arch_name)() + elif 'RIM_SR' in arch_name: + arch = getattr(architecture, arch_name)() + elif 'putzky' in arch_name: + arch = getattr(architecture, arch_name)() + elif 'automap' in arch_name: + arch = getattr(architecture, arch_name)() + elif "filter_deep" in arch_name or "resnet" or "Res" in arch_name: arch = getattr(architecture, arch_name)(img_size) else: arch = getattr(architecture, arch_name)() @@ -293,17 +313,20 @@ def eval_model(img, model, test=False): pred: n 1d arrays predicted images """ - if len(img.shape) == (3): - img = img.unsqueeze(0) model.eval() - if not test: - model.cuda() - with torch.no_grad(): - if not test: + model.cuda() + if isinstance(img, tuple) or isinstance(img, list): + img = (img[0].float().cuda(), img[1].float().cuda(), img[2].float().cuda()) #img[0].unsqueeze(0) + with torch.no_grad(): + pred = model(img) + else: + if len(img.shape) == (3): + img = img.unsqueeze(0) + with torch.no_grad(): pred = model(img.float().cuda()) - else: - pred = model(img.float()) - return pred.cpu() + if isinstance(pred, tuple): + return pred + return pred def get_ifft(array, amp_phase=False): @@ -354,8 +377,8 @@ def fft_pred(pred, truth, amp_phase=True): a = pred[:, 0, :, :] b = pred[:, 1, :, :] - a_true = truth[0, :, :] - b_true = truth[1, :, :] + a_true = truth[:, 0, :, :] + b_true = truth[:, 1, :, :] if amp_phase: amp_pred_rescaled = (10 ** (10 * a) - 1) / 10 ** 10 @@ -370,10 +393,10 @@ def fft_pred(pred, truth, amp_phase=True): compl_pred = a + 1j * b compl_true = a_true + 1j * b_true - ifft_pred = np.fft.ifft2(compl_pred) - ifft_true = np.fft.ifft2(compl_true) - - return np.absolute(ifft_pred)[0], np.absolute(ifft_true) + ifft_pred = np.fft.ifft2(np.fft.fftshift(compl_pred)) + ifft_true = np.fft.ifft2(np.fft.fftshift(compl_true)) + # return ifft_pred.real, ifft_true.real + return np.absolute(ifft_pred), np.absolute(ifft_true) def save_pred(path, x, y, z, name_x="x", name_y="y", name_z="z"): diff --git a/radionets/simulations/layouts/eht.txt b/radionets/simulations/layouts/eht.txt new file mode 100644 index 00000000..ccf0da8a --- /dev/null +++ b/radionets/simulations/layouts/eht.txt @@ -0,0 +1,8 @@ +#station_name X Y Z dish_dia el_low el_high SEFD altitude +ALMA50 2225037.1851 -5441199.162 -2479303.4629 84.7 15 85 110 5030 +SMTO -1828796.2 -5054406.8 3427865.2 10 15 85 11900 3185 +LMT -768713.9637 -5988541.7982 2063275.9472 50 15 85 560 4640 +Hawaii8 -5464523.4 -2493147.08 2150611.75 20.8 15 85 4900 4205 +PV 5088967.9 -301681.6 3825015.8 30 15 85 2900 2850 +PdBI 4523998.4 468045.24 4460309.76 7 15 85 1600 2550 +GLT 1500692 -1191735 6066409 12 15 85 4744 3210 \ No newline at end of file diff --git a/radionets/simulations/layouts/eht_spt.txt b/radionets/simulations/layouts/eht_spt.txt new file mode 100644 index 00000000..b05373c8 --- /dev/null +++ b/radionets/simulations/layouts/eht_spt.txt @@ -0,0 +1,9 @@ +station_name X Y Z dish_dia el_low el_high SEFD altitude +ALMA50 2225037.1851 -5441199.162 -2479303.4629 84.7 15 85 110 5030 +SMTO -1828796.2 -5054406.8 3427865.2 10 15 85 11900 3185 +LMT -768713.9637 -5988541.7982 2063275.9472 50 15 85 560 4640 +Hawaii8 -5464523.4 -2493147.08 2150611.75 20.8 15 85 4900 4205 +PV 5088967.9 -301681.6 3825015.8 30 15 85 2900 2850 +PdBI 4523998.4 468045.24 4460309.76 7 15 85 1600 2550 +SPT 0 0 -6359587.3 12 15 85 7300 2800 +GLT 1500692 -1191735 6066409 12 15 85 4744 3210 \ No newline at end of file diff --git a/radionets/simulations/layouts/layouts.py b/radionets/simulations/layouts/layouts.py index e6a7e6a9..52f0b112 100644 --- a/radionets/simulations/layouts/layouts.py +++ b/radionets/simulations/layouts/layouts.py @@ -9,3 +9,9 @@ def vlba(): x, y, z, _, _ = np.genfromtxt(file_dir / "vlba.txt", unpack=True) ant_pos = np.array([x, y, z]) return ant_pos + +def eht(): + _, x, y, z, _, _, _, _, _ = np.genfromtxt(file_dir / "eht.txt", unpack=True) + print(x) + ant_pos = np.array([x, y, z]) + return ant_pos diff --git a/radionets/simulations/process_vlbi.py b/radionets/simulations/process_vlbi.py new file mode 100644 index 00000000..8c37ac01 --- /dev/null +++ b/radionets/simulations/process_vlbi.py @@ -0,0 +1,771 @@ +import os +from tqdm import tqdm +from numpy import savez_compressed +from radionets.simulations.utils import ( + get_fft_bundle_paths, + get_real_bundle_paths, + prepare_fft_images, + interpol, +) +from radionets.dl_framework.data import ( + open_fft_bundle, + save_fft_pair, + save_fft_pair_with_response, +) +from radionets.simulations.uv_simulations import sample_freqs +import h5py +import numpy as np +from astropy.io import fits +from PIL import Image +import cv2 +import radionets.dl_framework.data as dt +import re +from natsort import natsorted, ns +from PIL import Image +import os +import vipy.simulation.utils as ut +import vipy.layouts.layouts as layouts +import astropy.constants as const +from astropy import units as un +import vipy.simulation.scan as scan + +# set env flags to catch BLAS used for scipy/numpy +# to only use 1 cpu, n_cpus will be totally controlled by csky +# flags from mirco +os.environ['MKL_NUM_THREADS'] = "12" +os.environ['NUMEXPR_NUM_THREADS'] = "12" +os.environ['OMP_NUM_THREADS'] = "12" +os.environ['OPENBLAS_NUM_THREADS'] = "12" +os.environ['VECLIB_MAXIMUM_THREADS'] = "12" + + +def process_data( + data_path, + # amp_phase, + # real_imag, + # fourier, + # compressed, + # interpolation, + # specific_mask, + # antenna_config, + # lon=None, + # lat=None, + # steps=None, +): + + print(f"\n Loading VLBI data set.\n") + bundles = dt.get_bundles('/net/big-tank/POOL/users/sfroese/vipy/eht/m87/blackhole/') + freq = 227297*10**6 # hard code #eht 227297 + bundles_target = dt.get_bundles(bundles[1]) + bundles_input = dt.get_bundles(bundles[0]) + bundle_paths_target = natsorted(bundles_target) + bundle_paths_input = natsorted(bundles_input) + size = len(bundle_paths_target) + img = np.zeros((size,256,256)) + samps = np.zeros((size,4,21000)) # hard code + for i in tqdm(range(size)): + sampled = bundle_paths_input[i] + target = bundle_paths_target[i] + + img[i] = np.asarray(Image.open(str(target))) + # img[i] = img[i]/np.sum(img[i]) + + with fits.open(sampled) as hdul: + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x) + y = np.squeeze(y) + w = np.squeeze(w) + ap = np.sqrt(x**2+y**2) + ph = np.angle(x+1j*y) + samps[i] = [np.append(data['UU--']*freq,-data['UU--']*freq),np.append(data['VV--']*freq,-data['VV--']*freq),np.append(ap,ap),np.append(ph,-ph)] + + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0][0] + v_0 = samps[0][1] + N = 63 # hard code + mask = np.zeros((N,N,u_0.shape[0])) + fov = 0.00018382*np.pi/(3600*180) # hard code #default 0.00018382 + # delta_u = 1/(fov*N/256) # hard code + delta_u = 1/(fov) + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + mask = np.flip(mask, [0]) + points = np.sum(mask, 2) + points[points==0] = 1 + samp_img = np.zeros((size,2,N,N)) + img_resized = np.zeros((size,N,N)) + for i in tqdm(range(samps.shape[0])): + samp_img[i][0] = np.matmul(mask, samps[i][2].T)/points + samp_img[i][0] = (np.log10(samp_img[i][0] + 1e-10) / 10) + 1 + samp_img[i][1] = np.matmul(mask, samps[i][3].T)/points + img_resized[i] = cv2.resize(img[i], (N,N)) + img_resized[i] = img_resized[i]/np.sum(img_resized[i]) + + # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) + truth_fft = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(img_resized, axes=(1,2)), axes=(1,2)), axes=(1,2)) + fft_scaled_truth = prepare_fft_images(truth_fft, True, False) + + out = data_path + "/samp_train0.h5" + save_fft_pair(out, samp_img[:2300], fft_scaled_truth[:2300]) + out = data_path + "/samp_valid0.h5" + save_fft_pair(out, samp_img[2300:], fft_scaled_truth[2300:]) + + + +def process_data_dirty_model(data_path, freq, n_positions, fov_asec, layout): + + print(f"\n Loading VLBI data set.\n") + bundles = dt.get_bundles(data_path) + freq = freq*10**6 # mhz hard code #eht 227297 + uvfits = dt.get_bundles(bundles[2]) + imgs = dt.get_bundles(bundles[4]) + configs = dt.get_bundles(bundles[1]) + uv_srt = natsorted(uvfits, alg=ns.PATH) + img_srt = natsorted(imgs, alg=ns.PATH) + size = 1000 + for p in tqdm(range(n_positions)): + N = 64 # hard code + with fits.open(uv_srt[p*1000]) as hdul: + n_sampled = hdul[0].data.shape[0] #number of sampled points + baselines = hdul[0].data['Baseline'] + baselines = np.append(baselines,baselines) + unique_telescopes = hdul[3].data.shape[0] + unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 + + # response matrices + A = response(configs[p], N, unique_telescopes, unique_baselines, layout) + + img = np.zeros((size,128,128)) + samps = np.zeros((size,4,n_sampled*2)) + print(f"\n Load subset.\n") + for i in np.arange(p*1000, p*1000+1000): + # print(i) + sampled = uv_srt[i] + target = img_srt[i] # +1000 because I had to only grid images from 1000-1999 + + img[i-p*1000] = np.asarray(Image.open(str(target))) + + with fits.open(sampled) as hdul: + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x) + y = np.squeeze(y) + w = np.squeeze(w) + ap = np.sqrt(x**2+y**2) + ph = np.angle(x+1j*y) + samps[i-p*1000] = [np.append(data['UU--']*freq,-data['UU--']*freq),np.append(data['VV--']*freq,-data['VV--']*freq),np.append(ap,ap),np.append(ph,-ph)] + + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0][0] + v_0 = samps[0][1] + mask = np.zeros((N,N,u_0.shape[0])) + + base_mask = np.zeros((N,N,int(unique_baselines))) + + fov = fov_asec*np.pi/(3600*180) # hard code #default 0.00018382 + # delta_u = 1/(fov*N/256) # hard code + delta_u = 1/(fov) # with a set N this is the same as zooming in since N*delta_u can be smaller than u_max + # delta_u = (2*max(np.max(u_0),np.max(v_0))/N) # test gridding pixel size + # biggest_baselines = 8611*1e3 + # wave = const.c/(freq/un.second)/un.meter + # uv_max = biggest_baselines/wave + # delta_u = uv_max/N + # print(delta_u) + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + base = np.unique(baselines[mask[i,j].astype(bool)]) + base_mask[i,j,:base.shape[0]] = base + + mask = np.flip(mask, [0]) + points = np.sum(mask, 2) + points[points==0] = 1 + samp_img = np.zeros((size,2,N,N)) + img_resized = np.zeros((size,N,N)) + for i in range(samps.shape[0]): + samp_img[i][0] = np.matmul(mask, samps[i][2].T)/points + samp_img[i][0] = (np.log10(samp_img[i][0] + 1e-10) / 10) + 1 + samp_img[i][1] = np.matmul(mask, samps[i][3].T)/points + img_resized[i] = cv2.resize(img[i], (N,N)) + img_resized[i] = img_resized[i]/np.sum(img_resized[i]) + + # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) + truth_fft = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(img_resized, axes=(1,2)), axes=(1,2)), axes=(1,2)) + fft_scaled_truth = prepare_fft_images(truth_fft, True, False) + + out = data_path + "/h5/samp_train"+ str(p) +".h5" + save_fft_pair_with_response(out, samp_img[:800], fft_scaled_truth[:800], np.expand_dims(base_mask,0), np.expand_dims(A,0)) + out = data_path + "/h5/samp_valid"+ str(p) + ".h5" + save_fft_pair_with_response(out, samp_img[800:], fft_scaled_truth[800:], np.expand_dims(base_mask,0), np.expand_dims(A,0)) + + +def response(config, N, unique_telescopes, unique_baselines, layout='vlba'): + rc = ut.read_config(config) + array_layout = layouts.get_array_layout(layout) + src_crd = rc['src_coord'] + + wave = const.c/((float(rc['channel'].split(':')[0]))*10**6/un.second)/un.meter + rd = scan.rd_grid(rc['fov_size']*np.pi/(3600*180),N, src_crd) + E = scan.getE(rd, array_layout, wave, src_crd) + A = np.zeros((N,N,int(unique_baselines))) + counter = 0 + for i in range(int(unique_telescopes)): + for j in range(int(unique_telescopes)): + if i == j or j < i: + continue + A[:,:,counter] = E[:,:,i]*E[:,:,j] + counter += 1 + + return A + + +def process_measurement(data_path, file, config, fov_asec): + + print(f"\n Loading VLBI data set.\n") + configs = config + size = 1 + N=64 + with fits.open(file) as hdul: + n_sampled = hdul[0].data.shape[0] #number of sampled points + baselines = hdul[0].data['Baseline'] + + unique_telescopes = hdul[3].data.shape[0] + unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 + freq = hdul[0].header[37] + offset = hdul[2].data['IF FREQ'] + for o in offset[0][1:]: + break + baselines = np.append(baselines,hdul[0].data['Baseline']) + baselines = np.append(baselines,baselines) + # response matrices + A = response(configs, N, unique_telescopes, unique_baselines, 'vlba') + + samps = np.zeros((size,4,n_sampled*2)) + print(f"\n Load subset.\n") + + + with fits.open(file) as hdul: + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x)[:,0] + y = np.squeeze(y)[:,0] + w = np.squeeze(w)[:,0] + ap = np.sqrt(x**2+y**2) + ph = np.angle(x+1j*y) + u = np.array([]) + v = np.array([]) + for f in offset[0]: + u = np.append(u,data['UU--']*(freq+f)) + v = np.append(v,data['VV--']*(freq+f)) + break + samps = [np.append(u,-u),np.append(v,-v),np.append(ap,ap),np.append(ph,-ph)] + import matplotlib.pyplot as plt + plt.plot(samps[0], samps[1], 'x') + plt.show() + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0] + v_0 = samps[1] + mask = np.zeros((N,N,u_0.shape[0])) + + base_mask = np.zeros((N,N,int(unique_baselines))) + + fov = fov_asec*np.pi/(3600*180) # hard code #default 0.00018382 + delta_u = 1/(fov) # with a set N this is the same as zooming in since N*delta_u can be smaller than u_max + + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + base = np.unique(baselines[mask[i,j].astype(bool)]) + base_mask[i,j,:base.shape[0]] = base + + mask = np.flip(mask, [0]) + points = np.sum(mask, 2) + points[points==0] = 1 + img_resized = np.zeros((size,N,N)) + + samp_img = np.zeros((size,2,N,N)) + print(mask.shape) + print(samps[2].shape) + samp_img[0,0] = np.matmul(mask, samps[2].T)/points + samp_img[0,0] = (np.log10(samp_img[0,0] + 1e-10) / 10) + 1 + samp_img[0,1] = np.matmul(mask, samps[3].T)/points + + + # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) + + out = data_path + "/h5/samp_meas.h5" + save_fft_pair_with_response(out, samp_img, samp_img, np.expand_dims(base_mask,0), np.expand_dims(A,0)) + + +def process_eht(data_path, file, config, fov_asec): + + print(f"\n Loading VLBI data set.\n") + configs = config + size = 1 + N=64 + with fits.open(file) as hdul: + baselines = hdul[0].data['Baseline'] + + unique_telescopes = 8 + unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 + freq = 229071e6#hdul[0].header[37] + baselines = np.append(baselines,baselines) + # response matrices + A = response(configs, N, unique_telescopes, unique_baselines, 'eht') + + # samps = np.zeros((size,4,n_sampled*2)) + print(f"\n Load subset.\n") + + + with fits.open(file) as hdul: + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x) + y = np.squeeze(y) + w = np.squeeze(w) + ap = np.sqrt((x*w)**2+(y*w)**2) + ph = np.angle(x*w+1j*y*w) + u = np.array([]) + v = np.array([]) + u = np.append(u,data['UU---SIN']*(freq)) + v = np.append(v,data['VV---SIN']*(freq)) + samps = [np.append(u,-u),np.append(v,-v),np.append(ap,ap),np.append(ph,-ph)] + + # plt.plot(samps[0], samps[1], 'x') + # plt.show() + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0] + v_0 = samps[1] + mask = np.zeros((N,N,u_0.shape[0])) + + base_mask = np.zeros((N,N,int(unique_baselines))) + + fov = fov_asec*np.pi/(3600*180) # hard code #default 0.00018382 + delta_u = 1/(fov) # with a set N this is the same as zooming in since N*delta_u can be smaller than u_max + + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + base = np.unique(baselines[mask[i,j].astype(bool)]) + base_mask[i,j,:base.shape[0]] = base + + mask = np.flip(mask, [0]) + import matplotlib.pyplot as plt + # plt.imshow(np.sum(mask, 2)) + # plt.show() + points = np.sum(mask, 2) + points[points==0] = 1 + + samp_img = np.zeros((size,2,N,N)) + print(mask.shape) + print(samps[2].shape) + samp_img[0,0] = np.matmul(mask, samps[2].T)/points + samp_img[0,0] = (np.log10(samp_img[0,0] + 1e-10) / 10) + 1 + samp_img[0,1] = np.matmul(mask, samps[3].T)/points + plt.imshow(samp_img[0,0]) + plt.colorbar() + plt.show() + plt.imshow(samp_img[0,1]) + plt.colorbar() + plt.show() + + + # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) + + out = data_path + "/eht_hi_test_DPG_startmod.h5" + save_fft_pair_with_response(out, samp_img, samp_img, np.expand_dims(base_mask,0), np.expand_dims(A,0)) + +def process_eht_hist(data_path, file, config, fov_asec): + + print(f"\n Loading VLBI data set.\n") + configs = config + size = 1 + N=64 + with fits.open(file) as hdul: + baselines = hdul[0].data['Baseline'] + + unique_telescopes = 8 + unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 + freq = 229071e6#hdul[0].header[37] + baselines = np.append(baselines,baselines) + # response matrices + A = response(configs, N, unique_telescopes, unique_baselines, 'eht') + + # samps = np.zeros((size,4,n_sampled*2)) + print(f"\n Load subset.\n") + + + with fits.open(file) as hdul: + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x) + y = np.squeeze(y) + w = np.squeeze(w) + ap = np.sqrt(x**2+y**2) + ph = np.angle(x+1j*y) + u = np.array([]) + v = np.array([]) + u = np.append(u,data['UU---SIN']*(freq)) + v = np.append(v,data['VV---SIN']*(freq)) + samps = [np.append(u,-u),np.append(v,-v),np.append(ap,ap),np.append(ph,-ph)] + + # plt.plot(samps[0], samps[1], 'x') + # plt.show() + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0] + v_0 = samps[1] + mask = np.zeros((N,N,u_0.shape[0])) + + base_mask = np.zeros((N,N,int(unique_baselines))) + + fov = fov_asec*np.pi/(3600*180) # hard code #default 0.00018382 + delta_u = 1/(fov) # with a set N this is the same as zooming in since N*delta_u can be smaller than u_max + binpos = np.arange(N//2+1)*delta_u + binsneg = -np.flip(np.arange(N//2+1)*delta_u) + bins = np.append(binsneg,binpos) + bins = np.unique(bins) + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + base = np.unique(baselines[mask[i,j].astype(bool)]) + base_mask[i,j,:base.shape[0]] = base + + import matplotlib.pyplot as plt + # plt.imshow(np.sum(mask, 2)) + # plt.show() + amp_cal,_,_,_ = plt.hist2d(samps[1],samps[0],bins=bins, weights=np.append(ap,ap)) + phase_cal,_,_,_ = plt.hist2d(samps[1],samps[0],bins=bins, weights=np.append(ph,-ph)) + points_cal,_,_,_ = plt.hist2d(samps[1],samps[0],bins=bins) + points_cal[points_cal==0]=1 + amp_cal = amp_cal/points_cal + phase_cal = phase_cal/points_cal + + samp_img = np.zeros((size,2,N,N)) + print(mask.shape) + print(samps[2].shape) + samp_img[0,0] = amp_cal + samp_img[0,0] = (np.log10(samp_img[0,0] + 1e-10) / 10) + 1 + samp_img[0,1] = phase_cal + plt.imshow(points_cal) + plt.colorbar() + plt.show() + plt.imshow(samp_img[0,0]) + plt.colorbar() + plt.show() + plt.imshow(samp_img[0,1]) + plt.colorbar() + plt.show() + + + # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) + + out = data_path + "/eht_hi_test_DPG.h5" + save_fft_pair_with_response(out, samp_img, samp_img, np.expand_dims(base_mask,0), np.expand_dims(A,0)) + + + + +def process_data_dirty_model_noisy(data_path, freq, n_positions, fov_asec, layout): + + print(f"\n Loading VLBI data set.\n") + bundles = dt.get_bundles(data_path) + freq = freq*10**6 # mhz hard code #eht 227297 + uvfits = dt.get_bundles(bundles[3]) + imgs = dt.get_bundles(bundles[2]) + configs = dt.get_bundles(bundles[0]) + uv_srt = natsorted(uvfits, alg=ns.PATH)[50000:] + img_srt = natsorted(imgs, alg=ns.PATH)[50000:] + configs = natsorted(configs, alg=ns.PATH)[50:] + size = 1000 + for p in tqdm(range(n_positions)): + N = 64 # hard code + with fits.open(uv_srt[p*1000]) as hdul: + n_sampled = hdul[0].data.shape[0] #number of sampled points + baselines = hdul[0].data['Baseline'] + baselines = np.append(baselines,baselines) + unique_telescopes = hdul[3].data.shape[0] + unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 + + # response matrices + A = response(configs[p], N, unique_telescopes, unique_baselines, layout) + + img = np.zeros((size,256,256)) + samps = np.zeros((size,4,n_sampled*2)) + print(f"\n Load subset.\n") + for i in np.arange(p*1000, p*1000+1000): + # print(i) + sampled = uv_srt[i] + target = img_srt[i] # +1000 because I had to only grid images from 1000-1999 + + img[i-p*1000] = np.asarray(Image.open(str(target))) + + with fits.open(sampled) as hdul: + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x) + y = np.squeeze(y) + w = np.squeeze(w) + ap = np.sqrt(x**2+y**2) + ph = np.angle(x+1j*y) + samps[i-p*1000] = [np.append(data['UU--']*freq,-data['UU--']*freq),np.append(data['VV--']*freq,-data['VV--']*freq),np.append(ap,ap),np.append(ph,-ph)] + + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0][0] + v_0 = samps[0][1] + mask = np.zeros((N,N,u_0.shape[0])) + + base_mask = np.zeros((N,N,int(unique_baselines))) + + fov = fov_asec*np.pi/(3600*180) # hard code #default 0.00018382 + # delta_u = 1/(fov*N/256) # hard code + delta_u = 1/(fov) # with a set N this is the same as zooming in since N*delta_u can be smaller than u_max + # delta_u = (2*max(np.max(u_0),np.max(v_0))/N) # test gridding pixel size + # biggest_baselines = 8611*1e3 + # wave = const.c/(freq/un.second)/un.meter + # uv_max = biggest_baselines/wave + # delta_u = uv_max/N + # print(delta_u) + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + base = np.unique(baselines[mask[i,j].astype(bool)]) + base_mask[i,j,:base.shape[0]] = base + + mask = np.flip(mask, [0]) + points = np.sum(mask, 2) + points[points==0] = 1 + samp_img = np.zeros((size,2,N,N)) + img_resized = np.zeros((size,N,N)) + for i in range(samps.shape[0]): + samp_img[i][0] = np.matmul(mask, samps[i][2].T)/points + # samp_img[i][0] = (np.log10(samp_img[i][0] + 1e-10) / 10) + 1 + samp_img[i][1] = np.matmul(mask, samps[i][3].T)/points + img_resized[i] = cv2.resize(img[i], (N,N)) + img_resized[i] = img_resized[i]/np.sum(img_resized[i]) + + ### nooiiiiiseeeee + np.random.seed(42) + noise = np.random.normal(size=(size, N, N)) + m = np.zeros((1000,64,64)) + m[:] = np.sum(mask, 2) + m[m != 0] = 1 + ft_noise = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(noise))) + ft_noise[m == 0] = 0 + noise_dirty = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(ft_noise))) + + compl = samp_img[:,0]*np.exp(1j*samp_img[:,1]) + dirty_img = abs(np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(compl)))) + for idx, d in enumerate(dirty_img): + max = np.max(d) + std = np.std(noise_dirty[idx]) + snr = np.random.uniform(2,10) + alpha = max/(std*snr) + dirty_img[idx] = dirty_img[idx] + abs(noise_dirty[idx]*alpha) + + measured = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(dirty_img))) + mask = np.sum(mask, 2) + for idx, m in enumerate(measured): + samp_img[idx][0] = (np.log10(np.abs(m) + 1e-10) / 10) + 1 + samp_img[idx][0][mask == 0] = 0 + samp_img[idx][1] = np.angle(m) + samp_img[idx][1][mask == 0] = 0 + import matplotlib.pyplot as plt + plt.imshow(samp_img[0][0]) + plt.colorbar() + plt.show() + + + # truth_fft = np.array([np.fft.fft2(np.fft.fftshift(img)) for im in img_resized]) + truth_fft = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(img_resized, axes=(1,2)), axes=(1,2)), axes=(1,2)) + fft_scaled_truth = prepare_fft_images(truth_fft, True, False) + + out = data_path + "/h5/bh/samp_train"+ str(p) +".h5" + save_fft_pair_with_response(out, samp_img[:800], fft_scaled_truth[:800], np.expand_dims(base_mask,0), np.expand_dims(A,0)) + out = data_path + "/h5/bh/samp_valid"+ str(p) + ".h5" + save_fft_pair_with_response(out, samp_img[800:], fft_scaled_truth[800:], np.expand_dims(base_mask,0), np.expand_dims(A,0)) + +def process_data_dirty_model_noisy_pointSource(data_path, freq, n_positions, fov_asec, layout): + + print(f"\n Loading VLBI data set.\n") + bundles = dt.get_bundles(data_path) + print(bundles) + freq = freq*10**6 # mhz hard code #eht 227297 + uvfits = dt.get_bundles(bundles[2]) + imgs = dt.get_bundles(bundles[4]) + configs = dt.get_bundles(bundles[1]) + uv_srt = natsorted(uvfits, alg=ns.PATH) + img_srt = natsorted(imgs, alg=ns.PATH) + size = 1000 + for p in tqdm(range(n_positions)): + N = 64 # hard code + with fits.open(uv_srt[p*1000]) as hdul: + n_sampled = hdul[0].data.shape[0] #number of sampled points + baselines = hdul[0].data['Baseline'] + baselines = np.append(baselines,baselines) + unique_telescopes = hdul[3].data.shape[0] + unique_baselines = (unique_telescopes**2 - unique_telescopes)/2 + + # response matrices + A = response(configs[p], N, unique_telescopes, unique_baselines, layout) + + img = np.zeros((size,128,128)) + samps = np.zeros((size,4,n_sampled*2)) + print(f"\n Load subset.\n") + for i in np.arange(p*1000, p*1000+1000): + # print(i) + sampled = uv_srt[i] + target = img_srt[i] # +1000 because I had to only grid images from 1000-1999 + + img[i-p*1000] = np.asarray(Image.open(str(target))) + + with fits.open(sampled) as hdul: + data = hdul[0].data + cmplx = data['DATA'] + x = cmplx[...,0,0] + y = cmplx[...,0,1] + w = cmplx[...,0,2] + x = np.squeeze(x) + y = np.squeeze(y) + w = np.squeeze(w) + ap = np.sqrt(x**2+y**2) + ph = np.angle(x+1j*y) + samps[i-p*1000] = [np.append(data['UU--']*freq,-data['UU--']*freq),np.append(data['VV--']*freq,-data['VV--']*freq),np.append(ap,ap),np.append(ph,-ph)] + + print(f"\n Gridding VLBI data set.\n") + + # Generate Mask + u_0 = samps[0][0] + v_0 = samps[0][1] + mask = np.zeros((N,N,u_0.shape[0])) + + base_mask = np.zeros((N,N,int(unique_baselines))) + + fov = fov_asec*np.pi/(3600*180) # hard code #default 0.00018382 + # delta_u = 1/(fov*N/256) # hard code + delta_u = 1/(fov) # with a set N this is the same as zooming in since N*delta_u can be smaller than u_max + # delta_u = (2*max(np.max(u_0),np.max(v_0))/N) # test gridding pixel size + # biggest_baselines = 8611*1e3 + # wave = const.c/(freq/un.second)/un.meter + # uv_max = biggest_baselines/wave + # delta_u = uv_max/N + # print(delta_u) + for i in range(N): + for j in range(N): + u_cell = (j-N/2)*delta_u + v_cell = (i-N/2)*delta_u + mask[i,j] = ((u_cell <= u_0) & (u_0 <= u_cell+delta_u)) & ((v_cell <= v_0) & (v_0 <= v_cell+delta_u)) + + base = np.unique(baselines[mask[i,j].astype(bool)]) + base_mask[i,j,:base.shape[0]] = base + + mask = np.flip(mask, [0]) + points = np.sum(mask, 2) + points[points==0] = 1 + samp_img = np.zeros((size,2,N,N)) + img_resized = np.zeros((size,N,N)) + for i in range(samps.shape[0]): + samp_img[i][0] = np.matmul(mask, samps[i][2].T)/points + # samp_img[i][0] = (np.log10(samp_img[i][0] + 1e-10) / 10) + 1 + samp_img[i][1] = np.matmul(mask, samps[i][3].T)/points + img_resized[i] = cv2.resize(img[i], (N,N)) + img_resized[i] = img_resized[i]/np.sum(img_resized[i]) + + ### nooiiiiiseeeee + np.random.seed(42) + noise = np.random.normal(size=(size, N, N)) + m = np.zeros((1000,64,64)) + m[:] = np.sum(mask, 2) + m[m != 0] = 1 + ft_noise = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(noise))) + ft_noise[m == 0] = 0 + noise_dirty = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(ft_noise))) + + compl = samp_img[:,0]*np.exp(1j*samp_img[:,1]) + dirty_img = abs(np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(compl)))) + for idx, d in enumerate(dirty_img): + max = np.max(d) + std = np.std(noise_dirty[idx]) + snr = np.random.uniform(2,10) + alpha = max/(std*snr) + dirty_img[idx] = dirty_img[idx] + abs(noise_dirty[idx]*alpha) + + measured = np.fft.ifftshift(np.fft.fft2(np.fft.fftshift(dirty_img))) + mask = np.sum(mask, 2) + for idx, m in enumerate(measured): + samp_img[idx][0] = (np.log10(np.abs(m) + 1e-10) / 10) + 1 + samp_img[idx][0][mask == 0] = 0 + samp_img[idx][1] = np.angle(m) + samp_img[idx][1][mask == 0] = 0 + import matplotlib.pyplot as plt + plt.imshow(samp_img[0][0]) + plt.colorbar() + plt.show() + + + #point source label + position = np.zeros((size,N,N)) + result = np.array([np.unravel_index(np.argmax(r), r.shape) for r in img_resized]) + for i in range(2): + position[i,result[i][0],result[i][1]] = 1 + plt.imshow(position[0]) + plt.colorbar() + plt.show() + + + truth_fft = np.fft.fftshift(np.fft.fft2(np.fft.fftshift(img_resized, axes=(1,2)), axes=(1,2)), axes=(1,2)) + fft_scaled_truth = prepare_fft_images(truth_fft, True, False) + + out = data_path + "/h5/samp_train"+ str(p) +".h5" + save_fft_pair_with_response(out, samp_img[:800], fft_scaled_truth[:800], np.expand_dims(base_mask,0), np.expand_dims(A,0)) + out = data_path + "/h5/samp_valid"+ str(p) + ".h5" + save_fft_pair_with_response(out, samp_img[800:], fft_scaled_truth[800:], np.expand_dims(base_mask,0), np.expand_dims(A,0)) diff --git a/radionets/simulations/scripts/simulate_images.py b/radionets/simulations/scripts/simulate_images.py index ec622154..8cf9ea77 100644 --- a/radionets/simulations/scripts/simulate_images.py +++ b/radionets/simulations/scripts/simulate_images.py @@ -2,6 +2,8 @@ import toml from radionets.simulations.simulate import create_fft_images, sample_fft_images from radionets.simulations.utils import check_outpath, read_config, calc_norm +import os +os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" @click.command() diff --git a/radionets/simulations/utils.py b/radionets/simulations/utils.py index 252f689b..4351f11c 100644 --- a/radionets/simulations/utils.py +++ b/radionets/simulations/utils.py @@ -10,6 +10,8 @@ from tqdm import tqdm from scipy import interpolate from skimage.transform import resize +from scipy import interpolate +from natsort import natsorted from radionets.dl_framework.data import ( save_fft_pair, open_fft_pair, @@ -222,6 +224,22 @@ def get_fft_bundle_paths(data_path, ftype, mode): return bundle_paths +def get_real_bundle_paths(data_path): + bundles = get_bundles(data_path) + bundles_target = get_bundles(bundles[3]) + bundles_input = get_bundles(bundles[1]) + bundle_paths_target = [ + path for path in bundles_target if re.findall(f"[0-9].fits", path.name) + ] + bundle_paths_input = [ + path for path in bundles_input if re.findall(f"[0-9].oifits", path.name) + ] + bundle_paths_input = natsorted(bundle_paths_input) + bundle_paths_target = natsorted(bundle_paths_target) + + return [bundle_paths_input, bundle_paths_target] + + def prepare_fft_images(fft_images, amp_phase, real_imag): if amp_phase: amp, phase = split_amp_phase(fft_images) diff --git a/radionets/simulations/uv_plots.py b/radionets/simulations/uv_plots.py index 20990cfc..909a690d 100644 --- a/radionets/simulations/uv_plots.py +++ b/radionets/simulations/uv_plots.py @@ -100,7 +100,7 @@ class object with antenna positions and baselines between telescopes x_enu_ant, y_enu_ant, marker="o", - markersize=6, + markersize=15, color="#1f77b4", linestyle="none", label="Antenna positions", @@ -111,7 +111,7 @@ class object with antenna positions and baselines between telescopes marker="*", linestyle="none", color="#ff7f0e", - markersize=15, + markersize=20, transform=ccrs.Geodetic(), zorder=10, label="Projected source", @@ -120,8 +120,9 @@ class object with antenna positions and baselines between telescopes if baselines is True: plot_baselines(antenna) - plt.legend(fontsize=16, markerscale=1.5) + plt.legend(fontsize=16, markerscale=1.5,loc='lower center',bbox_to_anchor=(-0.2, 0)) plt.tight_layout() + return plt.gcf() def animate_baselines(source, antenna, filename, fps=5): diff --git a/tests/evaluate.toml b/tests/evaluate.toml index 16b95091..853c6e41 100644 --- a/tests/evaluate.toml +++ b/tests/evaluate.toml @@ -20,6 +20,7 @@ arch_name = "SRResNet_bigger_no_symmetry" arch_name_2 = "none" output_format = "png" diff = true +rim = false [inspection] visualize_prediction = true @@ -38,4 +39,4 @@ evaluate_dynamic_range = false evaluate_ms_ssim = false evaluate_mean_diff = false evaluate_area = false -evaluate_point = false \ No newline at end of file +evaluate_point = false