diff --git a/README.md b/README.md index 44fc064..672f96f 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ mlops run -h The first thing to do after cloning this template is to rename the appropriate files and folders to make the directory project specific. The `project` directory should be renamed to make it clear that it contains your project files. -### There are 5 main components that need to be completed after cloning the template: +### There are 6 main components that need to be completed after cloning the template: ### 1. `config/config.cfg` and `config/local_config.cfg` The config file contains all the information that is used for configuring the project, experiment, and tracking server. This includes training parameters and XNAT configurations. @@ -53,7 +53,13 @@ As there will be differences between local development and running on DGX (for e Note: The values present in the template config files are examples, you can remove any except those in `[server]` and `[project]` which are necessary for MLOps. Outside of these you are encouraged to add and modify the config files as relevant to your project. -### 2. `project/Network.py` +### 2. `project/XNATDataImport.py` +This file is used to define and pull the required data from XNAT. It utilises DataBuilderXNAT to do so as shown in the example, if you require additional or different data from XNAT additional actions can be added. + +If your data is not stored in XNAT this can be replaced by any method that accesses your data. + + +### 3. `project/Network.py` This file is used to define the PyTorch `LightningModule` class. This is where you set the Network architecture and flow that you will use for training, validation, and testing. @@ -62,22 +68,22 @@ Here you can set up which metrics are calculated and at which stage in the flow The example has numerous metrics and steps that are not always necessary, feel free to delete or add as relevant to your project. -### 3. `project/DataModule.py` +### 4. `project/DataModule.py` This file is used to define the PyTorch `LightningDataModule` class. This is where you define the data that is used for training, validation, and testing. -The example involves retrieving data from XNAT (more on this below) which may not be necessary for your project. There are additional data validation steps that might not be relevant, feel free to delete or add as relevant to your project. +The example involves additional data validation steps that might not be relevant, feel free to delete or add as relevant to your project. -### 4. `scripts/train.py` +### 5. `scripts/train.py` This file is used to define the training run. This is where the `Datamodule` and `Network` are pulled together. The example includes callbacks to retrieve the best model parameters, feel free to delete or add as relevant to your project. -### 5. `Dockerfile` +### 6. `Dockerfile` This dockerfile sets up the Docker image that the MLOps run will utilise. In the example this is just a simple environment running python version 3.10. diff --git a/project/DataModule.py b/project/DataModule.py index 760e6d0..454515f 100644 --- a/project/DataModule.py +++ b/project/DataModule.py @@ -12,8 +12,6 @@ from torchvision.utils import make_grid from sklearn.model_selection import train_test_split -from utils.tools import DataBuilderXNAT -from xnat.mixin import ImageScanData, SubjectData import matplotlib.pyplot as plt from project.transforms import ( @@ -34,16 +32,20 @@ class DataModule(pytorch_lightning.LightningDataModule): def __init__( self, + raw_data, xnat_configuration: dict = None, batch_size: int = 1, num_workers: int = 4, visualise_training_data=True, + image_series_option: str = 'error', ): super().__init__() - self.num_workers = num_workers - self.batch_size = batch_size + self.raw_data = raw_data self.xnat_configuration = xnat_configuration + self.batch_size = batch_size + self.num_workers = num_workers self.visualise_training_data = visualise_training_data + self.image_series_option = image_series_option self.train_transforms = Compose( load_xnat(self.xnat_configuration, self.image_series_option) @@ -58,22 +60,6 @@ def __init__( + output_transforms() ) - def get_data(self) -> None: - """ - Fetches raw XNAT data and stores in raw_data attribute - """ - actions = [ - (self.fetch_xr, "image"), - (self.fetch_label, "label"), - ] - - data_builder = DataBuilderXNAT( - self.xnat_configuration, actions=actions, num_workers=self.num_workers - ) - - data_builder.fetch_data() - self.raw_data = data_builder.dataset - def validate_data(self) -> None: """ Remove samples that do not have both an image and a label @@ -125,9 +111,8 @@ def get_labels(self, data) -> list: if sample["data_label"] == "label" ] - def prepare_data(self, *args, **kwargs): + def setup(self, *args, **kwargs): - self.get_data() logging.info("Validating data") self.validate_data() @@ -254,41 +239,6 @@ def test_dataloader(self): pin_memory=is_available(), ) - @staticmethod - def fetch_xr(subject_data: SubjectData = None) -> List[ImageScanData]: - """ - Function that identifies and returns the required xnat ImageData object from a xnat SubjectData object - along with the 'key' that it will be used to access it. - """ - - scan_objects = [] - - for exp in subject_data.experiments: - if ( - "CR" in subject_data.experiments[exp].modality - or "DX" in subject_data.experiments[exp].modality - ): - for scan in subject_data.experiments[exp].scans: - scan_objects.append(subject_data.experiments[exp].scans[scan]) - return scan_objects - - @staticmethod - def fetch_label(subject_data: SubjectData = None): - """ - Function that identifies and returns the required label from a XNAT SubjectData object. - """ - label = None - for exp in subject_data.experiments: - if ( - "CR" in subject_data.experiments[exp].modality - or "DX" in subject_data.experiments[exp].modality - ): - temp_label = subject_data.experiments[exp].label - x = temp_label.split("_") - label = int(x[1]) - - return label - def dataset_stats(self, dataset: List[Dict], fields=["label"]) -> dict: """Calculate dataset statistics diff --git a/project/XNATDataImport.py b/project/XNATDataImport.py new file mode 100644 index 0000000..c7b2ea7 --- /dev/null +++ b/project/XNATDataImport.py @@ -0,0 +1,64 @@ +import os +import logging +import pandas as pd +import numpy as np + +from utils.tools import DataBuilderXNAT +from xnat.mixin import ImageScanData, SubjectData +from typing import List, Dict + +logger = logging.getLogger(__name__) + +class XNATDataImport(): + + def __init__(self, xnat_configuration: dict = None, num_workers: int = 4): + self.xnat_configuration = xnat_configuration + self.num_workers = num_workers + + def import_xnat_data(self): + actions = [ + (self.fetch_xr, "image"), + (self.fetch_label, "label"), + ] + + data_builder = DataBuilderXNAT( + self.xnat_configuration, actions=actions, num_workers=self.num_workers + ) + + data_builder.fetch_data() + return(data_builder.dataset) + + @staticmethod + def fetch_xr(subject_data: SubjectData = None) -> List[ImageScanData]: + """ + Function that identifies and returns the required xnat ImageData object from a xnat SubjectData object + along with the 'key' that it will be used to access it. + """ + + scan_objects = [] + + for exp in subject_data.experiments: + if ( + "CR" in subject_data.experiments[exp].modality + or "DX" in subject_data.experiments[exp].modality + ): + for scan in subject_data.experiments[exp].scans: + scan_objects.append(subject_data.experiments[exp].scans[scan]) + return scan_objects + + @staticmethod + def fetch_label(subject_data: SubjectData = None): + """ + Function that identifies and returns the required label from a XNAT SubjectData object. + """ + label = None + for exp in subject_data.experiments: + if ( + "CR" in subject_data.experiments[exp].modality + or "DX" in subject_data.experiments[exp].modality + ): + temp_label = subject_data.experiments[exp].label + x = temp_label.split("_") + label = int(x[1]) + + return label diff --git a/scripts/train.py b/scripts/train.py index 11e6409..5348242 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -15,6 +15,7 @@ from project.DataModule import DataModule from project.Network import Network from project.DataModule import label_dict +from project.XNATDataImport import XNATDataImport logger = logging.getLogger(__name__) @@ -50,16 +51,23 @@ def train(config): mlflow.pytorch.autolog(log_models=False) + # Import raw data + raw_data = XNATDataImport( + xnat_configuration = xnat_configuration, + num_workers = num_workers, + ).import_xnat_data() + # Set up datamodule dm = DataModule( + raw_data = raw_data, xnat_configuration = xnat_configuration, num_workers = num_workers, batch_size = int(config['params']['batch_size']), visualise_training_data = config['params']['visualise_training_data'], ) - dm.prepare_data() - + dm.setup() + n_classes = len(set([x for x in label_dict.values() if x is not None])) mlflow.log_param('n_classes', n_classes)