diff --git a/README.md b/README.md index 25e1e1e..7e959d8 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ pip install -r requirements.txt Once you have NHL recap screenshosts saved on your computer, you must first resize them before creating the labels. To do so, use the following script by specifying the path of the folder where your images are saved: ``` -python -m src.utils.resize_images --path your_path +python -m src.data_creation.resize_images --path your_path ``` It should save all the `*.png` files inside that folder in the correct dimensions. @@ -63,7 +63,7 @@ Once you launched the labeling tool, you are up and running to import you resize The last part before splitting the dataset and running the model is to extract the XML from the labeling tool and split it into separate XMLs (one for each file). To split the XML downloaded from the labeling tool: ``` -python -m src.parser.xml_splitter --file path_xml_file --dir dir_save_xmls +python -m src.data_creation.parser.xml_splitter --file path_xml_file --dir dir_save_xmls ``` The very last step is to add the resized images and the accompagning XML to the `data/raw/` directory and push it to the repo. 
diff --git a/data/raw/image_train.txt b/data/raw/image_train.txt deleted file mode 100644 index 03153b2..0000000 --- a/data/raw/image_train.txt +++ /dev/null @@ -1,8 +0,0 @@ -data/raw/image-1.png -data/raw/image-2.png -data/raw/image-3.png -data/raw/image-100.png -data/raw/image-102.png -data/raw/image-103.png -data/raw/image-105.png -data/raw/image-106.png \ No newline at end of file diff --git a/data/raw/image_val.txt b/data/raw/image_val.txt deleted file mode 100644 index f03b01f..0000000 --- a/data/raw/image_val.txt +++ /dev/null @@ -1,4 +0,0 @@ -data/raw/image-5.png -data/raw/image-6.png -data/raw/image-101.png -data/raw/image-104.png \ No newline at end of file diff --git a/data/raw/xml_train.txt b/data/raw/xml_train.txt deleted file mode 100644 index 510f8d4..0000000 --- a/data/raw/xml_train.txt +++ /dev/null @@ -1,8 +0,0 @@ -data/raw/image-1.xml -data/raw/image-2.xml -data/raw/image-3.xml -data/raw/image-100.xml -data/raw/image-102.xml -data/raw/image-103.xml -data/raw/image-105.xml -data/raw/image-106.xml \ No newline at end of file diff --git a/data/raw/xml_val.txt b/data/raw/xml_val.txt deleted file mode 100644 index 71044cd..0000000 --- a/data/raw/xml_val.txt +++ /dev/null @@ -1,4 +0,0 @@ -data/raw/image-5.xml -data/raw/image-6.xml -data/raw/image-101.xml -data/raw/image-104.xml \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 3a59a38..f82f33a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy -torch +torch==1.4.0 elementpath optparse-pretty python-resize-image diff --git a/src/semantic/vgg/__init__.py b/src/data_creation/__init__.py similarity index 100% rename from src/semantic/vgg/__init__.py rename to src/data_creation/__init__.py diff --git a/src/semantic/utils/utils.py b/src/data_creation/file_manager.py similarity index 99% rename from src/semantic/utils/utils.py rename to src/data_creation/file_manager.py index 15faec1..11f9000 100644 --- a/src/semantic/utils/utils.py +++ 
b/src/data_creation/file_manager.py @@ -9,4 +9,4 @@ def readfile(name): def savefile(classe, name): with open('{}.pkl'.format(name), 'wb') as fich: - fich.write(pickle.dumps(classe, pickle.HIGHEST_PROTOCOL)) + fich.write(pickle.dumps(classe, pickle.HIGHEST_PROTOCOL)) \ No newline at end of file diff --git a/src/semantic/utils/create_image_label.py b/src/data_creation/label_creation/create_image_label.py similarity index 77% rename from src/semantic/utils/create_image_label.py rename to src/data_creation/label_creation/create_image_label.py index 154195a..9c32b82 100644 --- a/src/semantic/utils/create_image_label.py +++ b/src/data_creation/label_creation/create_image_label.py @@ -1,10 +1,11 @@ import numpy as np import matplotlib.pyplot as plt import mahotas -from src.semantic.parser.xml_parser import parse_xml_data -from src.semantic.net_parameters import p_label_to_int +from src.data_creation.parser.xml_parser import parse_xml_data from PIL import Image +LABEL_TO_INT = {'ice': 1, 'board': 2, 'circlezone': 3, 'circlemid': 4, 'goal': 5, 'blue': 6, 'red': 7, 'fo': 8} + class CreateLabel: def __init__(self, path_xml, path_image): @@ -42,8 +43,8 @@ def get_label(self): poly = points[i] x, y = zip(*CreateLabel.render(poly)) for k in range(len(y)): - if p_label_to_int[labels[i]] > frame_image[x[k]][y[k]]: - frame_image[x[k]][y[k]] = p_label_to_int[labels[i]] + if LABEL_TO_INT[labels[i]] > frame_image[x[k]][y[k]]: + frame_image[x[k]][y[k]] = LABEL_TO_INT[labels[i]] self.frame_image = frame_image.transpose() return frame_image.transpose() @@ -54,7 +55,3 @@ def show_plot(self): plt.imshow(self.frame_image) plt.show() - -#Label2 = CreateLabel(path_xml='./data/xml/test2_polygon.xml', path_image='./data/image/test2_polygon.png') -#label2_array = Label2.get_label() -#Label2.show_plot() diff --git a/src/data_creation/parser/__init__.py b/src/data_creation/parser/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/semantic/parser/xml_parser.py 
b/src/data_creation/parser/xml_parser.py similarity index 100% rename from src/semantic/parser/xml_parser.py rename to src/data_creation/parser/xml_parser.py diff --git a/src/semantic/parser/xml_splitter.py b/src/data_creation/parser/xml_splitter.py similarity index 79% rename from src/semantic/parser/xml_splitter.py rename to src/data_creation/parser/xml_splitter.py index d257f2d..3c8b35d 100644 --- a/src/semantic/parser/xml_splitter.py +++ b/src/data_creation/parser/xml_splitter.py @@ -28,14 +28,14 @@ def split_xml(path_file, path_to): def get_args(): - parser = OptionParser() - parser.add_option('-f', '--file', type=str, dest='file', + parser = OptionParser() + parser.add_option('-f', '--file', type=str, dest='file', help='File Path (including filename) of the XML.') - parser.add_option('-d', '--dir', type=str, dest='dir', + parser.add_option('-d', '--dir', type=str, dest='dir', help='Directory to save the XMLs') - (options, args) = parser.parse_args() - return options + (options, args) = parser.parse_args() + return options if __name__ == '__main__': diff --git a/src/semantic/utils/resize_images.py b/src/data_creation/resize_images.py similarity index 95% rename from src/semantic/utils/resize_images.py rename to src/data_creation/resize_images.py index 6ba6772..2e7da5d 100644 --- a/src/semantic/utils/resize_images.py +++ b/src/data_creation/resize_images.py @@ -4,6 +4,8 @@ from PIL import Image from resizeimage import resizeimage +RESIZE_FORMAT = [512, 256] + def resize_images(path): """Function that resize png files inside a dir""" @@ -12,7 +14,7 @@ def resize_images(path): if file.endswith(".png"): with open(os.path.join(path, file), 'r+b') as f: with Image.open(f) as image: - cover = resizeimage.resize_thumbnail(image, [512, 256]) + cover = resizeimage.resize_thumbnail(image, RESIZE_FORMAT) new_name = os.path.join('resized_'+file) cover.save(os.path.join(path, new_name), image.format) print(file+' has been resized and saved.') diff --git 
a/src/semantic/__init__.py b/src/semantic/__init__.py index d429724..e69de29 100644 --- a/src/semantic/__init__.py +++ b/src/semantic/__init__.py @@ -1,2 +0,0 @@ -from src.semantic.training_function import train -import src.semantic.net_parameters diff --git a/src/semantic/create_data_training_setup.py b/src/semantic/create_data_training_setup.py new file mode 100644 index 0000000..0269c4c --- /dev/null +++ b/src/semantic/create_data_training_setup.py @@ -0,0 +1,38 @@ +import json +from pathlib import Path +from optparse import OptionParser + +from src.semantic.modeling_data_creation.split_modeling_data import create_labels_from_dir +from src.semantic.dataloader.flip_images import flip_images + + +def get_args(): + parser = OptionParser() + parser.add_option('-c', '--config', type=str, dest='config', default='src/semantic/training_config.json', + help='Config file to setup training') + + (options, args) = parser.parse_args() + return options + + +def data_creation(config_file): + with open(config_file, "r") as f: + config = json.load(f) + data_parameters = config["data_parameters"] + # Split train and test in 3 different folders (and save arrays instead of XMLs) + create_labels_from_dir( + path_data=data_parameters["raw_data_path"], + path_to=data_parameters["data_creation_folder_path"], + train_test_perc=data_parameters["train_test_perc"], + train_valid_perc=data_parameters["train_valid_perc"], + max=data_parameters["max_image"] + ) + + if data_parameters["data_augmentation"]: + train_data_path = Path(data_parameters["data_creation_folder_path"], "train") + flip_images(train_data_path) + + +if __name__ == "__main__": + args = get_args() + data_creation(args.config) diff --git a/src/semantic/dataloader/dataset.py b/src/semantic/dataloader/dataset.py index b1e4aa2..a2c8428 100644 --- a/src/semantic/dataloader/dataset.py +++ b/src/semantic/dataloader/dataset.py @@ -1,11 +1,9 @@ import numpy as np -import os from PIL import Image from torch.utils.data import Dataset 
-from src.semantic.utils.create_image_label import CreateLabel -from src.semantic.utils.utils import readfile +from src.data_creation.file_manager import readfile def load_image(file): @@ -16,8 +14,6 @@ class DataGenerator(Dataset): def __init__(self, imagepath, labelpath, transform): # make sure label match with image self.transform = transform - #assert os.path.exists(imagepath), "{} not exists !".format(imagepath) - #assert os.path.exists(labelpath), "{} not exists !".format(labelpath) self.image = imagepath self.label = labelpath diff --git a/src/semantic/dataloader/flip_images.py b/src/semantic/dataloader/flip_images.py index 60d391d..adb501a 100644 --- a/src/semantic/dataloader/flip_images.py +++ b/src/semantic/dataloader/flip_images.py @@ -1,15 +1,16 @@ import glob +from pathlib import Path import numpy as np from PIL import Image from torchvision import transforms -from src.semantic.utils.utils import readfile, savefile +from src.data_creation.file_manager import readfile, savefile def flip_images(path_data): - images = glob.glob(path_data + '*.png') - labels = glob.glob(path_data + '*.pkl') + images = glob.glob(str(Path(path_data, '*.png'))) + labels = glob.glob(str(Path(path_data, '*.pkl'))) preprocess_flip = transforms.Compose([ transforms.RandomHorizontalFlip(1) @@ -22,7 +23,6 @@ def flip_images(path_data): image_flip = preprocess_flip(image_flip) image_flip.save(image.replace('image', 'rimage')) - for label in labels: label_flip = readfile(label.replace('.pkl', '')) label_flip = Image.fromarray(label_flip) diff --git a/src/semantic/history.py b/src/semantic/history.py deleted file mode 100644 index 328e634..0000000 --- a/src/semantic/history.py +++ /dev/null @@ -1,60 +0,0 @@ -import matplotlib.pyplot as plt -from IPython.display import clear_output - - -class History: - - def __init__(self): - self.history = { - 'train_loss': [], - 'val_loss': [], - 'lr': [] - } - - def save(self, train_loss, val_loss, lr): - 
self.history['train_loss'].append(train_loss) - self.history['val_loss'].append(val_loss) - self.history['lr'].append(lr) - - def add_history(self, history_add): - self.history['train_loss'] += history_add.history['train_loss'] - self.history['val_loss'] += history_add.history['val_loss'] - self.history['lr'] += history_add.history['lr'] - - def display_loss(self): - epoch = len(self.history['train_loss']) - epochs = [x for x in range(1, epoch + 1)] - plt.title('Training loss') - plt.xlabel('Epochs') - plt.ylabel('Loss') - plt.plot(epochs, self.history['train_loss'], label='Train') - plt.plot(epochs, self.history['val_loss'], label='Validation') - plt.legend() - plt.show() - - def display_lr(self): - epoch = len(self.history['train_loss']) - epochs = [x for x in range(1, epoch + 1)] - plt.title('Learning rate') - plt.xlabel('Epochs') - plt.ylabel('Lr') - plt.plot(epochs, self.history['lr'], label='Lr') - plt.show() - - def display(self): - epoch = len(self.history['train_loss']) - epochs = [x for x in range(1, epoch + 1)] - - fig, axes = plt.subplots(2, 1) - plt.tight_layout() - - axes[1].set_xlabel('Epochs') - axes[1].set_ylabel('Loss') - axes[1].plot(epochs, self.history['train_loss'], label='Train') - axes[1].plot(epochs, self.history['val_loss'], label='Validation') - - axes[2].set_xlabel('Epochs') - axes[2].set_ylabel('Lr') - axes[2].plot(epochs, self.history['lr'], label='Lr') - - plt.show() diff --git a/src/semantic/model/__init__.py b/src/semantic/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/semantic/model/unet/__init__.py b/src/semantic/model/unet/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/semantic/unet/unet_model.py b/src/semantic/model/unet/unet_model.py similarity index 94% rename from src/semantic/unet/unet_model.py rename to src/semantic/model/unet/unet_model.py index 45591c3..bc981dc 100644 --- a/src/semantic/unet/unet_model.py +++ b/src/semantic/model/unet/unet_model.py @@ -1,5 +1,6 @@ 
import torch.nn.functional as F -from src.semantic.unet.unet_utils import * +from src.semantic.model.unet.unet_utils import * + class UNet(nn.Module): def __init__(self, n_channels, n_classes): diff --git a/src/semantic/unet/unet_utils.py b/src/semantic/model/unet/unet_utils.py similarity index 95% rename from src/semantic/unet/unet_utils.py rename to src/semantic/model/unet/unet_utils.py index 11dbe3c..a3382e0 100644 --- a/src/semantic/unet/unet_utils.py +++ b/src/semantic/model/unet/unet_utils.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from src.semantic.net_parameters import p_bilinear class double_conv(nn.Module): @@ -47,7 +46,7 @@ def forward(self, x): class up(nn.Module): - def __init__(self, in_ch, out_ch, bilinear=p_bilinear): + def __init__(self, in_ch, out_ch, bilinear=True): super(up, self).__init__() # would be a nice idea if the upsampling could be learned too, diff --git a/src/semantic/model/vgg/__init__.py b/src/semantic/model/vgg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/semantic/vgg/vggnet.py b/src/semantic/model/vgg/vggnet.py similarity index 100% rename from src/semantic/vgg/vggnet.py rename to src/semantic/model/vgg/vggnet.py diff --git a/src/semantic/vgg/weight_adapt.py b/src/semantic/model/vgg/weight_adapt.py similarity index 97% rename from src/semantic/vgg/weight_adapt.py rename to src/semantic/model/vgg/weight_adapt.py index 09538ca..28818d1 100644 --- a/src/semantic/vgg/weight_adapt.py +++ b/src/semantic/model/vgg/weight_adapt.py @@ -3,7 +3,8 @@ import torch import torch.nn as nn -from src.semantic.net_parameters import p_number_of_classes + +NUMBER_OF_CLASSES = 9 model_urls = { @@ -56,5 +57,5 @@ adapt_state_dict["features.19.bias"] = ori_state_dict["features.21.bias"] adapt_state_dict["features.19.running_mean"] = ori_state_dict["features.21.running_mean"] adapt_state_dict["features.19.running_var"] = ori_state_dict["features.21.running_var"] 
-adapt_state_dict["conv_out.weight"] = nn.init.normal(torch.zeros((p_number_of_classes, 256, 1, 1)), 0) +adapt_state_dict["conv_out.weight"] = nn.init.normal(torch.zeros((NUMBER_OF_CLASSES, 256, 1, 1)), 0) adapt_state_dict["conv_out.bias"] = torch.zeros(p_number_of_classes) diff --git a/src/semantic/modeling_data_creation/__init__.py b/src/semantic/modeling_data_creation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/semantic/unet/generate_masks.py b/src/semantic/modeling_data_creation/split_modeling_data.py similarity index 84% rename from src/semantic/unet/generate_masks.py rename to src/semantic/modeling_data_creation/split_modeling_data.py index 33487c3..b03dfdc 100644 --- a/src/semantic/unet/generate_masks.py +++ b/src/semantic/modeling_data_creation/split_modeling_data.py @@ -1,11 +1,12 @@ import glob import math import os +from pathlib import Path import numpy as np from shutil import copyfile -from src.semantic.utils.create_image_label import CreateLabel -from src.semantic.utils.utils import savefile +from src.data_creation.label_creation.create_image_label import CreateLabel +from src.data_creation.file_manager import savefile def create_labels_from_dir(path_data, path_to, train_test_perc=0.8, train_valid_perc=0.8, shuffle=True, max=None): @@ -21,8 +22,8 @@ def create_labels_from_dir(path_data, path_to, train_test_perc=0.8, train_valid_ The XML files are created using cvat tool (see labeling-tool/) """ - images = glob.glob(path_data + '*.png') - xml = glob.glob(path_data + '*.xml') + images = glob.glob(str(Path(path_data, '*.png'))) + xml = glob.glob(str(Path(path_data, '*.xml'))) images.sort() xml.sort() @@ -41,21 +42,19 @@ def create_labels_from_dir(path_data, path_to, train_test_perc=0.8, train_valid_ train_idx, test_idx = indices[:split], indices[split:] nb_images_train = len(train_idx) - indices_train = np.arange(nb_images_train) - - if shuffle: - np.random.shuffle(indices_train) split_train = math.floor(train_valid_perc * 
nb_images_train) - train_idx, valid_idx = indices_train[:split_train], indices_train[split_train:] + train_idx, valid_idx = train_idx[:split_train], train_idx[split_train:] if max is not None: train_idx = train_idx[:max] + valid_idx = valid_idx[:max] + test_idx = test_idx[:max] # Create new folders for train and test datasets - os.mkdir(path_to + 'train/') - os.mkdir(path_to + 'valid/') - os.mkdir(path_to + 'test/') + os.mkdir(str(Path(path_to, 'train'))) + os.mkdir(str(Path(path_to, 'valid'))) + os.mkdir(str(Path(path_to, 'test'))) for id in train_idx: filename_png = images[id].split('/')[-1] diff --git a/src/semantic/net_parameters.py b/src/semantic/net_parameters.py deleted file mode 100644 index f4388f8..0000000 --- a/src/semantic/net_parameters.py +++ /dev/null @@ -1,21 +0,0 @@ -p_weight_augmentation = None -p_bilinear = True -p_model_name_save = 'unet' -p_normalize = True -p_max_images= 2000 -p_number_of_classes = 9 -p_label_to_int = {'ice': 1, 'board': 2, 'circlezone': 3, 'circlemid': 4, 'goal': 5, 'blue': 6, 'red': 7, 'fo': 8} -p_classes_color = ['black', 'white', 'yellow', 'pink', 'coral', 'crimson', 'blue', 'red', 'magenta'] -p_history_save_name = None -p_save_name = 'history' - -net_dict = { - 'p_weight_augmentation': p_weight_augmentation, - 'p_bilinear': p_bilinear, - 'p_model_name_save': p_model_name_save, - 'p_normalize': p_normalize, - 'p_max_images': p_max_images, - 'p_number_of_classes': p_number_of_classes, - 'p_label_to_int': p_label_to_int, - 'p_classes_color': p_classes_color -} diff --git a/src/semantic/predict.py b/src/semantic/predict.py index 26461a1..6e2effa 100644 --- a/src/semantic/predict.py +++ b/src/semantic/predict.py @@ -4,11 +4,13 @@ import matplotlib.pyplot as plt from src.semantic.training_function import predict -from src.semantic.utils.utils import readfile, savefile +from src.data_creation.file_manager import readfile, savefile -from src.semantic.net_parameters import p_classes_color -cmap = 
mpl.colors.ListedColormap(p_classes_color) +CLASSES_COLOR = ['black', 'white', 'yellow', 'pink', 'coral', 'crimson', 'blue', 'red', 'magenta'] + + +cmap = mpl.colors.ListedColormap(CLASSES_COLOR) def get_args(): parser = OptionParser() diff --git a/src/semantic/training_config.json b/src/semantic/training_config.json new file mode 100644 index 0000000..a79b4ae --- /dev/null +++ b/src/semantic/training_config.json @@ -0,0 +1,45 @@ +{ + "model_load_name": null, + "model_type": "unet", + "model_parameters": { + "n_channels": 3, + "n_classes": 9 + }, + "model_save_name": "test_phil_unet", + + "data_parameters":{ + "raw_data_path": "data/raw/", + "data_creation_folder_path": "data/", + "data_augmentation": true, + "train_test_perc": 0.8, + "train_valid_perc": 0.8, + "max_image": 2 + }, + + "loss_type": "CrossEntropy", + "use_gpu": false, + + "transform_params": { + "normalize": true, + "crop": [450, 256] + }, + + "optimizer_type": "SGD", + "optimizer_params": { + "lr": 0.1, + "momentum": 0.9, + "weight_decay": 0.0005 + }, + + "schedular_type": "onecyclelr", + "schedular_params": { + "max_lr": 0.1 + }, + + "training_parameters": { + "n_epoch": 3, + "batch_size": 2, + "shuffle": true, + "weight_adaptation": null + } +} \ No newline at end of file diff --git a/src/semantic/training_function.py b/src/semantic/training_function.py index 6c71456..1955beb 100644 --- a/src/semantic/training_function.py +++ b/src/semantic/training_function.py @@ -4,25 +4,24 @@ import time import glob import os -import math -from torch.autograd import Variable from torch.utils.data import DataLoader from torchvision import transforms +from pathlib import Path from src.semantic.dataloader import DataGenerator -from src.semantic.net_parameters import p_number_of_classes from src.semantic.dataloader.dataset import load_image -from src.semantic.dataloader import NormalizeCropTransform -from src.semantic.history import History + + +NUMBER_OF_CLASSES = 9 def train_valid_loaders(train_path, valid_path, 
batch_size, transform, shuffle=True): # List the files in train and valid - train_images = glob.glob(train_path + '*.png') - train_labels = glob.glob(train_path + '*.pkl') - valid_images = glob.glob(valid_path + '*.png') - valid_labels = glob.glob(valid_path + '*.pkl') + train_images = glob.glob(str(Path(train_path, '*.png'))) + train_labels = glob.glob(str(Path(train_path, '*.pkl'))) + valid_images = glob.glob(str(Path(valid_path, '*.png'))) + valid_labels = glob.glob(str(Path(valid_path, '*.pkl'))) train_images.sort() train_labels.sort() @@ -38,9 +37,9 @@ def train_valid_loaders(train_path, valid_path, batch_size, transform, shuffle=T return loader_train, loader_val -def validate(model, val_loader, criterion, use_gpu=False): +@torch.no_grad() +def validate(model, val_loader, criterion, device): - model.train(False) val_loss = [] model.eval() @@ -48,76 +47,61 @@ def validate(model, val_loader, criterion, use_gpu=False): for j, batch in enumerate(val_loader): inputs, targets = batch - if use_gpu: - inputs = inputs.cuda() - targets = targets.cuda() + inputs = inputs.to(device) + targets = targets.to(device) - inputs = Variable(inputs, volatile=True) - targets = Variable(targets, volatile=True) output = model(inputs) - #predictions = output.max(dim=1)[1] - val_loss.append(criterion(output, targets[:, 0]).item()) - #true.extend(targets.data.cpu().numpy().tolist()) - #pred.extend(predictions.data.cpu().numpy().tolist()) - - model.train(True) - #return accuracy_score(true, pred) * 100, sum(val_loss) / len(val_loss) return sum(val_loss) / len(val_loss) -def train(model, optimizer, train_path, valid_path, n_epoch, batch_size, transform, criterion, use_gpu=False, +def train(model, optimizer, train_path, valid_path, n_epoch, batch_size, transform, criterion, device, scheduler=None, shuffle=True, weight_adaptation=None): - history = History() train_loader, val_loader = train_valid_loaders(train_path, valid_path, transform=transform, batch_size=batch_size, shuffle=shuffle) 
for i in range(n_epoch): start = time.time() - do_epoch(criterion, model, optimizer, scheduler, train_loader, use_gpu, weight_adaptation) + do_epoch(criterion, model, optimizer, scheduler, train_loader, device, weight_adaptation) - train_loss = validate(model, train_loader, criterion, use_gpu) + train_loss = validate(model, train_loader, criterion, device) - val_loss = validate(model, val_loader, criterion, use_gpu) + val_loss = validate(model, val_loader, criterion, device) end = time.time() - history.save(train_loss, val_loss, optimizer.param_groups[0]['lr']) print('Epoch {} - Train loss: {:.4f} - Val loss: {:.4f} Training time: {:.2f}s'.format(i, train_loss, val_loss, end - start)) - return history -def do_epoch(criterion, model, optimizer, scheduler, train_loader, use_gpu, weight_adaptation): +def do_epoch(criterion, model, optimizer, scheduler, train_loader, device, weight_adaptation): model.train() + if scheduler: scheduler.step() for batch in train_loader: inputs, targets = batch - if use_gpu: - inputs = inputs.cuda() - targets = targets.cuda() + inputs = inputs.to(device) + targets = targets.to(device) - inputs = Variable(inputs) - targets = Variable(targets) optimizer.zero_grad() output = model(inputs) if isinstance(criterion, torch.nn.modules.loss.CrossEntropyLoss): - weight_learn = torch.FloatTensor( - np.array([1/(np.log(1.1 + (np.array(targets.cpu() == i)).mean())) for i in range(p_number_of_classes)])) + weight_learn = np.array([1/(np.log(1.1 + (np.array(targets.to(torch.device("cpu")) == i)).mean())) + for i in range(NUMBER_OF_CLASSES)]) + weight_learn = torch.tensor(weight_learn, dtype=torch.float) if weight_adaptation is not None: pred_unique = output.max(dim=1)[1].unique() targets_unique = targets.unique() for target in targets_unique: if target not in pred_unique: weight_learn[target] = weight_learn[target] + weight_adaptation - if use_gpu: - weight_learn = weight_learn.cuda() + weight_learn = weight_learn.to(device) criterion = 
nn.CrossEntropyLoss(weight=weight_learn) loss = criterion(output, targets[:, 0]) @@ -125,7 +109,9 @@ def do_epoch(criterion, model, optimizer, scheduler, train_loader, use_gpu, weig optimizer.step() +@torch.no_grad() def predict(model, image_path, folder): + model.eval() # A sortir de la fonction eventuellement pour le normalize ... def crop_center(img, cropx, cropy): diff --git a/src/semantic/training_script.py b/src/semantic/training_script.py index 4557b8b..816a325 100644 --- a/src/semantic/training_script.py +++ b/src/semantic/training_script.py @@ -1,93 +1,123 @@ import torch.nn as nn import os import sys +import json +from pathlib import Path from optparse import OptionParser +import torch from torch import optim -from src.semantic.unet.unet_model import UNet - -from src.semantic import train +from src.semantic.model.unet.unet_model import UNet +from src.semantic.training_function import train from src.semantic.dataloader import NormalizeCropTransform from src.semantic.loss import DiceCoeff -from src.semantic.unet.generate_masks import create_labels_from_dir -from src.semantic.dataloader.flip_images import flip_images + from src.semantic.utils.show_images_sample import see_image_output -from src.semantic.utils.utils import readfile, savefile -from src.semantic.net_parameters import (p_weight_augmentation, p_normalize, p_model_name_save, p_max_images, - p_number_of_classes, net_dict, p_history_save_name, p_save_name) -from src.semantic.vgg.vggnet import vgg16_bn +from src.data_creation.file_manager import readfile, savefile -def train_unet(net, path_train, path_valid, n_epoch, batch_size, lr, criterion, use_gpu): +def create_model(model_type, model_params): + if model_type.lower() == 'vgg16': + raise NotImplementedError('Need to import the net from model and adapt the script. 
Old Stuff there') +    if model_type.lower() == 'unet': +        return UNet(**model_params) +    else: +        raise NotImplementedError('Need to specify a valid Neural Network model') + -    optimizer = optim.SGD(net.parameters(), -                          lr=lr, -                          momentum=0.9, -                          weight_decay=0.0005) +def create_optimizer(optimizer_type, model, optimizer_params): +    trainable_parameters = [p for p in model.parameters() if p.requires_grad] +    if optimizer_type.lower() == "sgd": +        return optim.SGD(trainable_parameters, **optimizer_params) +    else: +        raise NotImplementedError -    transform = NormalizeCropTransform(normalize=p_normalize, crop=(450, 256)) +def create_loss(loss_type): +    if loss_type.lower() == 'crossentropy': +        return nn.CrossEntropyLoss() +    if loss_type.lower() == 'dice': +        return DiceCoeff() -    history = train(model=net, optimizer=optimizer, train_path=path_train, valid_path=path_valid, n_epoch=n_epoch, -                    batch_size=batch_size, criterion=criterion, transform=transform, use_gpu=use_gpu, -                    weight_adaptation=p_weight_augmentation) -    net.cpu() -    savefile(net, p_model_name_save) -    return history +def create_scheduler(): +    return None -def save_parameter_and_history(history, net_dict, args_parser, save_name, previous_parameter_file=None): -    net_list = [] -    args_list = [] -    if previous_parameter_file is not None: -        old_pickle = readfile(previous_parameter_file) -        old_pickle['history'].add_history(history) -        end_history = old_pickle['history'] -        args_list = old_pickle['args_list'] -        net_list = old_pickle['net_list'] +def create_device(use_gpu): +    if use_gpu: +        return torch.device("cuda") else: -        end_history = history - -    args_parser_dict = { -        'model': args_parser.model, -        'epochs': args_parser.epochs, -        'batch_size': args_parser.batchsize, -        'criterion': args_parser.criterion, -        'augmentation': args_parser.augmentation -    } -    args_list.append(args_parser_dict) -    net_list.append(net_dict) -    all_dict = { -        'history': end_history, -        'args_list': args_list, -        'net_list': net_list
+def training(config_file): + with open(config_file, "r") as f: + config = json.load(f) + + net = create_model( + model_type=config["model_type"], + model_params=config["model_parameters"] + ) + optimizer = create_optimizer( + optimizer_type=config["optimizer_type"], + model=net, + optimizer_params=config["optimizer_params"] + ) + loss_criterion = create_loss(config["loss_type"]) + + transform = NormalizeCropTransform(**config["transform_params"]) + + scheduler = create_scheduler() + + device = create_device(config["use_gpu"]) + + net.to(device) + + data_creation_folder_path = config["data_parameters"]["data_creation_folder_path"] + training_path = Path(data_creation_folder_path, "train") + validation_path = Path(data_creation_folder_path, "valid") + testing_path = Path(data_creation_folder_path, "test") + + training_dict = { + "model": net, + "optimizer": optimizer, + "train_path": training_path, + "valid_path": validation_path, + "transform": transform, + "criterion": loss_criterion, + "device": device, + "scheduler": scheduler, + **config["training_parameters"] } - savefile(all_dict, save_name) + + try: + train(**training_dict) + + see_image_output( + net, + path_train=training_path, + path_test=testing_path, + path_save=data_creation_folder_path + ) + + except KeyboardInterrupt: + savefile(net, config["model_save_name"]) + + print('Saved interrupt') + try: + sys.exit(0) + except SystemExit: + os._exit(0) + + net.to(torch.device("cpu")) + savefile(net, config["model_save_name"]) def get_args(): parser = OptionParser() - parser.add_option('-p', '--path', type=str, dest='path', default='data/raw/', - help='Path raw data (.png and .xml)') - parser.add_option('-m', '--model', dest='model', default='unet', type='string', - help='Type of Neural Nets') - parser.add_option('-e', '--epochs', dest='epochs', default=3, type='int', - help='number of epochs') - parser.add_option('-b', '--batch-size', dest='batchsize', default=2, - type='int', help='batch size') - 
parser.add_option('-l', '--learning-rate', dest='lr', default=0.1, - type='float', help='learning rate') - parser.add_option('-c', '--criterion', type=str, dest='criterion', default='CrossEntropy', - help='Choices: CrossEntropy or Dice') - parser.add_option('-g', '--gpu', action='store_true', dest='gpu', - default=False, help='use gpu') - parser.add_option('-n', '--model_load_name', type=str, dest='model_name', default='', - help='Model to load (path to the pickle)') - parser.add_option('-s', '--setup', dest='setup', action='store_true', - default=False, help='Setup the datasets otpion.') - parser.add_option('-a', '--augmentation', dest='augmentation', action='store_true', - default=False, help='data augmentation option. Need to have set up to true.') + parser.add_option('-c', '--config', type=str, dest='config', default='src/semantic/training_config.json', + help='Config file to setup training') (options, args) = parser.parse_args() return options @@ -95,62 +125,4 @@ def get_args(): if __name__ == '__main__': args = get_args() - - if args.criterion == 'CrossEntropy': - criterion = nn.CrossEntropyLoss() - elif args.criterion == 'Dice': - criterion = DiceCoeff() - else: - sys.exit(0) - - if args.model_name != '': - net = readfile(args.model_name) - print('Model loaded from this pickle : {}'.format(args.model_name)) - else: - if args.model == 'vgg16': - net = vgg16_bn(pretrained=True, nb_classes=p_number_of_classes) - elif args.model == 'unet': - net = UNet(3, p_number_of_classes) - else: - raise ValueError('Need to specify a Neural Network model') - if args.gpu: - net.cuda() - - # We assume the path to save is the path parent to the raw/ data - path_to = os.path.normpath(args.path + os.sep + os.pardir) + '/' - if args.setup: - # Split train and test in 2 different folders (and save arrays instead of XMLs) - create_labels_from_dir(path_data=args.path, path_to=path_to, train_test_perc=0.8, train_valid_perc=0.8, - max=p_max_images) - if args.augmentation: - 
flip_images(path_to+'train/') - - try: - history = train_unet(net=net, - path_train=path_to+'train/', - path_valid=path_to+'valid/', - n_epoch=args.epochs, - batch_size=args.batchsize, - lr=args.lr, - use_gpu=args.gpu, - criterion=criterion) - - see_image_output(net, path_train=path_to+'train/', path_test=path_to+'test/', path_save=path_to) - save_parameter_and_history(history=history, - net_dict=net_dict, - args_parser=args, - save_name=p_save_name, - previous_parameter_file=p_history_save_name) - except KeyboardInterrupt: - savefile(net, p_model_name_save) - save_parameter_and_history(history=history, - net_dict=net_dict, - args_parser=args, - save_name=p_save_name, - previous_parameter_file=p_history_save_name) - print('Saved interrupt') - try: - sys.exit(0) - except SystemExit: - os._exit(0) - + training(args.config) diff --git a/src/semantic/unet/__init__.py b/src/semantic/unet/__init__.py deleted file mode 100644 index dca2f07..0000000 --- a/src/semantic/unet/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from src.semantic.unet.generate_masks import create_labels_from_dir diff --git a/src/semantic/utils/show_images_sample.py b/src/semantic/utils/show_images_sample.py index 0f63d2f..acf98fa 100644 --- a/src/semantic/utils/show_images_sample.py +++ b/src/semantic/utils/show_images_sample.py @@ -2,21 +2,23 @@ import matplotlib.pyplot as plt import glob import numpy as np -from pathlib import PurePath +from pathlib import PurePath, Path from src.semantic.dataloader.transform import NormalizeCropTransform from src.semantic.dataloader.dataset import DataGenerator -from src.semantic.net_parameters import p_classes_color + + +CLASSES_COLOR = ['black', 'white', 'yellow', 'pink', 'coral', 'crimson', 'blue', 'red', 'magenta'] def see_image_output(net, path_train, path_test, path_save): - cmap = mpl.colors.ListedColormap(p_classes_color) + cmap = mpl.colors.ListedColormap(CLASSES_COLOR) net.cpu() net.eval() transform = NormalizeCropTransform(normalize=True, crop=(450, 256)) # 
Sample 2 images from train - train_images = glob.glob(path_train + '*.png') + train_images = glob.glob(str(Path(path_train, '*.png'))) nb_images = len(train_images) indices = np.arange(nb_images) np.random.shuffle(indices) @@ -26,7 +28,7 @@ def see_image_output(net, path_train, path_test, path_save): labels_train.sort() # Sample 2 images from test - test_images = glob.glob(path_test + '*.png') + test_images = glob.glob(str(Path(path_test, '*.png'))) nb_images = len(test_images) indices = np.arange(nb_images) np.random.shuffle(indices) @@ -59,7 +61,6 @@ def see_image_output(net, path_train, path_test, path_save): fig.suptitle("Sample predicted from train dataset", fontsize=16, y=1.002, x=0.4) plt.savefig(path_save+'train-sample.png') - plt.show() i = 0 while i < len(data_test): @@ -79,4 +80,3 @@ def see_image_output(net, path_train, path_test, path_save): fig.suptitle("Sample predicted from test dataset", fontsize=16, y=1.002, x=0.4) plt.savefig(path_save+'test-sample.png') - plt.show()