diff --git a/python/art/NST.py b/python/art/NST.py
new file mode 100644
index 0000000..8b97da8
--- /dev/null
+++ b/python/art/NST.py
@@ -0,0 +1,199 @@
+import cv2 as cv
+import numpy as np
+import torch
+from torchvision import transforms
+from torch.optim import LBFGS
+import os
+from models.definitions.vgg19 import Vgg19
+
+# ImageNet channel means in the [0, 255] range; the std is left neutral so
+# inputs match the statistics the pretrained VGG19 expects.
+IMAGENET_MEAN_255 = [123.675, 116.28, 103.53]
+IMAGENET_STD_NEUTRAL = [1, 1, 1]
+
+def load_image(img_path, target_shape=None):
+    '''
+    Load an image and optionally resize it. An int target_shape fixes the
+    height and keeps the aspect ratio; a (height, width) tuple sets both.
+    '''
+    if not os.path.exists(img_path):
+        raise FileNotFoundError(f'Path not found: {img_path}')
+    img = cv.imread(img_path)[:, :, ::-1]  # convert BGR to RGB when reading
+    if target_shape is not None:
+        if isinstance(target_shape, int) and target_shape != -1:
+            current_height, current_width = img.shape[:2]
+            new_height = target_shape
+            new_width = int(current_width * (new_height / current_height))
+            img = cv.resize(img, (new_width, new_height), interpolation=cv.INTER_CUBIC)
+        else:
+            img = cv.resize(img, (target_shape[1], target_shape[0]), interpolation=cv.INTER_CUBIC)
+    img = img.astype(np.float32)
+    img /= 255.0
+    return img
+
+def prepare_img(img_path, target_shape, device):
+    '''
+    Load an image and normalize it for VGG19.
+    '''
+    img = load_image(img_path, target_shape=target_shape)
+    # ToTensor does not rescale float32 input, so mul(255) brings the image
+    # back to the [0, 255] range that IMAGENET_MEAN_255 expects.
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Lambda(lambda x: x.mul(255)),
+        transforms.Normalize(mean=IMAGENET_MEAN_255, std=IMAGENET_STD_NEUTRAL)])
+    img = transform(img).to(device).unsqueeze(0)
+    return img
+
+def save_image(img, img_path):
+    if len(img.shape) == 2:  # grayscale: replicate into three channels
+        img = np.stack((img,) * 3, axis=-1)
+    cv.imwrite(img_path, img[:, :, ::-1])  # convert RGB to BGR when writing
+
+def generate_out_img_name(config):
+    '''
+    Generate a name for the output image.
+    Example: 'c1_s1.jpg',
+    where c1 is the content image name and s1 the style image name.
+    '''
+    prefix = os.path.basename(config['content_img_name']).split('.')[0] + '_' + os.path.basename(config['style_img_name']).split('.')[0]
+    suffix = f'{config["img_format"][1]}'
+    return prefix + suffix
+
+def save_and_maybe_display(optimizing_img, dump_path, config, img_id, num_of_iterations):
+    '''
+    Save the generated image. saving_freq is hard-coded to -1 here, so only
+    the final image is saved, under a name derived from the content and style
+    image names; any other value would save zero-padded intermediate frames.
+    '''
+    saving_freq = -1
+    out_img = optimizing_img.squeeze(axis=0).to('cpu').detach().numpy()
+    out_img = np.moveaxis(out_img, 0, 2)  # CHW -> HWC
+
+    if img_id == num_of_iterations - 1:
+        img_format = config['img_format']
+        out_img_name = str(img_id).zfill(img_format[0]) + img_format[1] if saving_freq != -1 else generate_out_img_name(config)
+        dump_img = np.copy(out_img)
+        dump_img += np.array(IMAGENET_MEAN_255).reshape((1, 1, 3))  # undo the mean subtraction
+        dump_img = np.clip(dump_img, 0, 255).astype('uint8')
+        cv.imwrite(os.path.join(dump_path, out_img_name), dump_img[:, :, ::-1])
+
+def prepare_model(device):
+    '''
+    Load the pretrained VGG19 model and return it in eval mode, together with
+    the indices and names of the layers used for the content and style
+    representations.
+    '''
+    model = Vgg19(requires_grad=False, show_progress=True)
+    content_feature_maps_index = model.content_feature_maps_index
+    style_feature_maps_indices = model.style_feature_maps_indices
+    layer_names = model.layer_names
+    content_fms_index_name = (content_feature_maps_index, layer_names[content_feature_maps_index])
+    style_fms_indices_names = (style_feature_maps_indices, layer_names)
+    return model.to(device).eval(), content_fms_index_name, style_fms_indices_names
+
+def gram_matrix(x, should_normalize=True):
+    '''
+    Compute the Gram matrix of a batch of feature maps; this second-order
+    statistic serves as the style representation.
+    '''
+    (b, ch, h, w) = x.size()
+    features = x.view(b, ch, w * h)
+    features_t = features.transpose(1, 2)
+    gram = features.bmm(features_t)
+    if should_normalize:
+        gram /= ch * h * w
+    return gram
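+
+def _gram_matrix_demo():
+    # Hypothetical sanity check, not part of the original script: a batch of
+    # 64 feature maps of size 128x128 yields one 64x64 Gram matrix per image.
+    fmap = torch.rand(1, 64, 128, 128)
+    assert gram_matrix(fmap).shape == (1, 64, 64)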
+
+def total_variation(y):
+    '''
+    Calculate the total variation: the sum of absolute differences between
+    horizontally and vertically neighbouring pixels, which penalizes noise.
+    '''
+    return torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))
+
+def build_loss(neural_net, optimizing_img, target_representations, content_feature_maps_index, style_feature_maps_indices, config):
+    '''
+    Calculate content_loss, style_loss, and total_variation_loss, and combine
+    them into the weighted total loss.
+    '''
+    target_content_representation = target_representations[0]
+    target_style_representation = target_representations[1]
+    current_set_of_feature_maps = neural_net(optimizing_img)
+    current_content_representation = current_set_of_feature_maps[content_feature_maps_index].squeeze(axis=0)
+    content_loss = torch.nn.MSELoss(reduction='mean')(target_content_representation, current_content_representation)
+    style_loss = 0.0
+    current_style_representation = [gram_matrix(x) for cnt, x in enumerate(current_set_of_feature_maps) if cnt in style_feature_maps_indices]
+    for gram_gt, gram_hat in zip(target_style_representation, current_style_representation):
+        style_loss += torch.nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0])
+    style_loss /= len(target_style_representation)
+    tv_loss = total_variation(optimizing_img)
+    total_loss = config['content_weight'] * content_loss + config['style_weight'] * style_loss + config['tv_weight'] * tv_loss
+    return total_loss, content_loss, style_loss, tv_loss
+
+def make_tuning_step(neural_net, optimizer, target_representations, content_feature_maps_index, style_feature_maps_indices, config):
+    '''
+    Create a function that performs a single optimization step on the image
+    pixels. (We are tuning only the pixels, not the network weights.)
+    '''
+    def tuning_step(optimizing_img):
+        total_loss, content_loss, style_loss, tv_loss = build_loss(neural_net, optimizing_img, target_representations, content_feature_maps_index, style_feature_maps_indices, config)
+        total_loss.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+        return total_loss, content_loss, style_loss, tv_loss
+    return tuning_step
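+
+def _adam_tuning_demo(neural_net, optimizing_img, target_representations,
+                      content_idx, style_indices, config, num_steps=100):
+    # Hypothetical sketch, not called by this script: make_tuning_step can
+    # also drive a plain first-order loop, e.g. with Adam (the learning rate
+    # is an illustrative guess, not a tuned value).
+    from torch.optim import Adam
+    optimizer = Adam((optimizing_img,), lr=1e1)
+    tuning_step = make_tuning_step(neural_net, optimizer, target_representations,
+                                   content_idx, style_indices, config)
+    for step in range(num_steps):
+        total_loss, _, _, _ = tuning_step(optimizing_img)
+        print(f'Adam | step: {step:03}, total loss={total_loss.item():12.4f}')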
+
+def neural_style_transfer(config):
+    '''
+    The main neural style transfer routine: optimize the pixels of a copy of
+    the content image so they match the content representation of the content
+    image and the style representation of the style image.
+    '''
+    content_img_path = os.path.join(config['content_images_dir'], config['content_img_name'])
+    style_img_path = os.path.join(config['style_images_dir'], config['style_img_name'])
+    out_dir_name = 'combined_' + os.path.split(content_img_path)[1].split('.')[0] + '_' + os.path.split(style_img_path)[1].split('.')[0]
+    dump_path = os.path.join(config['output_img_dir'], out_dir_name)
+    os.makedirs(dump_path, exist_ok=True)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    content_img = prepare_img(content_img_path, config['height'], device)
+    style_img = prepare_img(style_img_path, config['height'], device)
+
+    # Initialize from the content image; optimizing a copy keeps the original
+    # content tensor untouched while its pixels are updated.
+    init_img = content_img
+    optimizing_img = init_img.clone().requires_grad_(True)
+
+    neural_net, content_feature_maps_index_name, style_feature_maps_indices_names = prepare_model(device)
+    print('Using VGG19 in the optimization procedure.')
+    content_img_set_of_feature_maps = neural_net(content_img)
+    style_img_set_of_feature_maps = neural_net(style_img)
+    target_content_representation = content_img_set_of_feature_maps[content_feature_maps_index_name[0]].squeeze(axis=0)
+    target_style_representation = [gram_matrix(x) for cnt, x in enumerate(style_img_set_of_feature_maps) if cnt in style_feature_maps_indices_names[0]]
+    target_representations = [target_content_representation, target_style_representation]
+    num_of_iterations = 1000
+
+    # L-BFGS runs all iterations inside a single optimizer.step(closure) call.
+    optimizer = LBFGS((optimizing_img,), max_iter=num_of_iterations, line_search_fn='strong_wolfe')
+    cnt = 0
+
+    def closure():
+        nonlocal cnt
+        if torch.is_grad_enabled():
+            optimizer.zero_grad()
+        total_loss, content_loss, style_loss, tv_loss = build_loss(neural_net, optimizing_img, target_representations, content_feature_maps_index_name[0], style_feature_maps_indices_names[0], config)
+        if total_loss.requires_grad:
+            total_loss.backward()
+        with torch.no_grad():
+            print(f'L-BFGS | iteration: {cnt:03}, total loss={total_loss.item():12.4f}, content_loss={config["content_weight"] * content_loss.item():12.4f}, style loss={config["style_weight"] * style_loss.item():12.4f}, tv loss={config["tv_weight"] * tv_loss.item():12.4f}')
+            save_and_maybe_display(optimizing_img, dump_path, config, cnt, num_of_iterations)
+        cnt += 1
+        return total_loss
+    optimizer.step(closure)
+    return dump_path
+
+PATH = './'
+CONTENT_IMAGE = 'c1.jpg'
+STYLE_IMAGE = 's1.jpg'
+
+default_resource_dir = os.path.join(PATH, 'data')
+content_images_dir = os.path.join(default_resource_dir, 'content-images')
+style_images_dir = os.path.join(default_resource_dir, 'style-images')
+output_img_dir = os.path.join(default_resource_dir, 'output-images')
+img_format = (4, '.jpg')  # (zero-padding width for frame ids, file extension)
+
+optimization_config = {'content_img_name': CONTENT_IMAGE, 'style_img_name': STYLE_IMAGE, 'height': 400, 'content_weight': 100000.0, 'style_weight': 30000.0, 'tv_weight': 1.0}
+optimization_config['content_images_dir'] = content_images_dir
+optimization_config['style_images_dir'] = style_images_dir
+optimization_config['output_img_dir'] = output_img_dir
+optimization_config['img_format'] = img_format
+
+results_path = neural_style_transfer(optimization_config)
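+
+# For the default config above, the stylized image is written to
+# data/output-images/combined_c1_s1/c1_s1.jpg; this print is an added
+# convenience, not part of the original script.
+print(f'Stylized image written under: {results_path}')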
diff --git a/python/art/data/content-images/c1.jpg b/python/art/data/content-images/c1.jpg
new file mode 100644
index 0000000..455eb37
Binary files /dev/null and b/python/art/data/content-images/c1.jpg differ
diff --git a/python/art/data/output-images/combined_c1_s1/c1_s1.jpg b/python/art/data/output-images/combined_c1_s1/c1_s1.jpg
new file mode 100644
index 0000000..fbe4824
Binary files /dev/null and b/python/art/data/output-images/combined_c1_s1/c1_s1.jpg differ
diff --git a/python/art/data/style-images/s1.jpg b/python/art/data/style-images/s1.jpg
new file mode 100644
index 0000000..db596a0
Binary files /dev/null and b/python/art/data/style-images/s1.jpg differ
diff --git a/python/art/models/definitions/__pycache__/vgg19.cpython-310.pyc b/python/art/models/definitions/__pycache__/vgg19.cpython-310.pyc
new file mode 100644
index 0000000..bb64baf
Binary files /dev/null and b/python/art/models/definitions/__pycache__/vgg19.cpython-310.pyc differ
diff --git a/python/art/models/definitions/vgg19.py b/python/art/models/definitions/vgg19.py
new file mode 100644
index 0000000..001b757
--- /dev/null
+++ b/python/art/models/definitions/vgg19.py
@@ -0,0 +1,60 @@
+from collections import namedtuple
+import torch
+from torchvision import models
+
+class Vgg19(torch.nn.Module):
+    """
+    VGG19 has 19 learnable layers. Of its convolutional activations,
+    'conv4_2' is used for the content representation, and 'conv1_1',
+    'conv2_1', 'conv3_1', 'conv4_1', 'conv5_1' (or their ReLU outputs, when
+    use_relu=True) are used for the style representation.
+    """
+    def __init__(self, requires_grad=False, show_progress=False, use_relu=True):
+        super().__init__()
+        vgg_pretrained_features = models.vgg19(pretrained=True, progress=show_progress).features
+
+        # use_relu selects the ReLU outputs of the style layers; the offset
+        # shifts each slice boundary by one so it ends on the ReLU rather
+        # than the preceding conv.
+        if use_relu:
+            self.layer_names = ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'conv4_2', 'relu5_1']
+            self.offset = 1
+        else:
+            self.layer_names = ['conv1_1', 'conv2_1', 'conv3_1', 'conv4_1', 'conv4_2', 'conv5_1']
+            self.offset = 0
+        self.content_feature_maps_index = 4  # conv4_2
+        self.style_feature_maps_indices = list(range(len(self.layer_names)))
+        self.style_feature_maps_indices.remove(4)  # every layer except conv4_2
+
+        self.slice1 = torch.nn.Sequential()
+        self.slice2 = torch.nn.Sequential()
+        self.slice3 = torch.nn.Sequential()
+        self.slice4 = torch.nn.Sequential()
+        self.slice5 = torch.nn.Sequential()
+        self.slice6 = torch.nn.Sequential()
+
+        for x in range(1 + self.offset):
+            self.slice1.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(1 + self.offset, 6 + self.offset):
+            self.slice2.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(6 + self.offset, 11 + self.offset):
+            self.slice3.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(11 + self.offset, 20 + self.offset):
+            self.slice4.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(20 + self.offset, 22):
+            self.slice5.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(22, 29 + self.offset):
+            self.slice6.add_module(str(x), vgg_pretrained_features[x])
+
+        if not requires_grad:
+            for param in self.parameters():
+                param.requires_grad = False
+
+    def forward(self, x):
+        x = self.slice1(x)
+        layer1_1 = x
+        x = self.slice2(x)
+        layer2_1 = x
+        x = self.slice3(x)
+        layer3_1 = x
+        x = self.slice4(x)
+        layer4_1 = x
+        x = self.slice5(x)
+        conv4_2 = x
+        x = self.slice6(x)
+        layer5_1 = x
+
+        vgg_outputs = namedtuple("VggOutputs", self.layer_names)
+        out = vgg_outputs(layer1_1, layer2_1, layer3_1, layer4_1, conv4_2, layer5_1)
+
+        return out
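+
+def _vgg19_demo():
+    # Illustrative usage sketch, not part of the original module: push a
+    # random 400x400 image through the network and inspect the content layer.
+    model = Vgg19(requires_grad=False)
+    feats = model(torch.rand(1, 3, 400, 400))
+    print(feats.conv4_2.shape)  # expected: torch.Size([1, 512, 50, 50])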