18 changes: 12 additions & 6 deletions README.md
@@ -42,7 +42,7 @@ mlops run -h
The first thing to do after cloning this template is to rename the appropriate files and folders to make the directory project-specific.
The `project` directory should be renamed to make it clear that it contains your project files.

### There are 5 main components that need to be completed after cloning the template:
### There are 6 main components that need to be completed after cloning the template:

### 1. `config/config.cfg` and `config/local_config.cfg`
The config file contains all the information that is used for configuring the project, experiment, and tracking server. This includes training parameters and XNAT configurations.
@@ -53,7 +53,13 @@ As there will be differences between local development and running on DGX (for e

Note: The values present in the template config files are examples; you can remove any except those in `[server]` and `[project]`, which are necessary for MLOps. Outside of these sections, you are encouraged to add to and modify the config files as relevant to your project.
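
As a rough illustration of how such a config file might be read, the sketch below uses Python's standard `configparser`. The section names `[server]` and `[project]` come from the note above, and the `[params]` keys mirror what `scripts/train.py` reads; the key names inside `[server]` and `[project]`, and all values, are hypothetical placeholders rather than the template's actual settings.

```python
import configparser

# Hypothetical config contents; keys and values are placeholders.
EXAMPLE_CONFIG = """
[server]
tracking_uri = http://localhost:5000

[project]
name = my_project

[params]
batch_size = 8
visualise_training_data = True
"""


def load_config(text: str) -> configparser.ConfigParser:
    """Parse config text and check that the sections MLOps needs are present."""
    config = configparser.ConfigParser()
    config.read_string(text)
    missing = [s for s in ("server", "project") if not config.has_section(s)]
    if missing:
        raise ValueError(f"Missing required config sections: {missing}")
    return config


config = load_config(EXAMPLE_CONFIG)
# configparser returns strings, so numeric and boolean values need converting.
batch_size = int(config["params"]["batch_size"])
visualise = config["params"].getboolean("visualise_training_data")
```

If `config` in `scripts/train.py` is a `ConfigParser` (an assumption, as the diff does not show how it is loaded), `config['params']['visualise_training_data']` would be a string, and a non-empty string such as `'False'` is truthy, so `getboolean` may be the safer read.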

### 2. `project/Network.py`
### 2. `project/XNATDataImport.py`
Contributor
I think this should remain part of the DataModule component.

Collaborator Author
As it is its own component, I think it should have its own explanation.

Will create an issue to make a code structure/architecture image.

This file defines and pulls the required data from XNAT. It uses `DataBuilderXNAT` to do so, as shown in the example; if you require additional or different data from XNAT, further actions can be added.

If your data is not stored in XNAT, this file can be replaced by any method that accesses your data.
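
For projects whose data lives outside XNAT, a replacement only needs to return the raw data that `DataModule` expects. The sketch below reads records from a CSV source using the standard library. The class name `CSVDataImport`, the method `import_csv_data`, the column names, and the returned record layout are all hypothetical, so adapt them to whatever your `DataModule` consumes.

```python
import csv
import io
from typing import Dict, List


class CSVDataImport:
    """Hypothetical stand-in for XNATDataImport that reads rows from CSV."""

    def __init__(self, csv_source):
        # csv_source is any open text file-like object.
        self.csv_source = csv_source

    def import_csv_data(self) -> List[Dict]:
        reader = csv.DictReader(self.csv_source)
        # Assumed columns: 'image_path' and an integer 'label'.
        return [
            {"image": row["image_path"], "label": int(row["label"])}
            for row in reader
        ]


example_csv = io.StringIO("image_path,label\nscans/a.dcm,0\nscans/b.dcm,1\n")
raw_data = CSVDataImport(example_csv).import_csv_data()
```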


### 3. `project/Network.py`
This file is used to define the PyTorch `LightningModule` class.

This is where you set the Network architecture and flow that you will use for training, validation, and testing.
@@ -62,22 +68,22 @@ Here you can set up which metrics are calculated and at which stage in the flow

The example has numerous metrics and steps that are not always necessary; feel free to delete or add them as relevant to your project.

### 3. `project/DataModule.py`
### 4. `project/DataModule.py`
This file is used to define the PyTorch `LightningDataModule` class.

This is where you define the data that is used for training, validation, and testing.

The example involves retrieving data from XNAT (more on this below) which may not be necessary for your project. There are additional data validation steps that might not be relevant, feel free to delete or add as relevant to your project.
The example involves additional data validation steps that might not be relevant; feel free to delete or add steps as relevant to your project.
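
To give a sense of what such a validation step can look like, here is a minimal sketch that drops records lacking either an image or a label. The record layout (a `data` list of entries tagged with `data_label`) is an assumption made for illustration, loosely based on the `data_label` key used in `get_labels`; the template's actual structure may differ.

```python
from typing import Dict, List

REQUIRED = {"image", "label"}


def validate_records(records: List[Dict]) -> List[Dict]:
    """Keep only records that contain both an image and a label entry."""
    valid = []
    for record in records:
        present = {entry["data_label"] for entry in record.get("data", [])}
        if REQUIRED <= present:
            valid.append(record)
    return valid


records = [
    {"subject_id": "s1", "data": [{"data_label": "image"}, {"data_label": "label"}]},
    {"subject_id": "s2", "data": [{"data_label": "image"}]},  # no label entry
]
valid_records = validate_records(records)
```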


### 4. `scripts/train.py`
### 5. `scripts/train.py`
This file is used to define the training run.

This is where the `DataModule` and `Network` are pulled together.

The example includes callbacks to retrieve the best model parameters; feel free to delete or add callbacks as relevant to your project.

### 5. `Dockerfile`
### 6. `Dockerfile`
This Dockerfile sets up the Docker image that the MLOps run will utilise.

In the example this is just a simple environment running Python 3.10.
64 changes: 7 additions & 57 deletions project/DataModule.py
@@ -12,8 +12,6 @@
from torchvision.utils import make_grid
from sklearn.model_selection import train_test_split

from utils.tools import DataBuilderXNAT
from xnat.mixin import ImageScanData, SubjectData
import matplotlib.pyplot as plt

from project.transforms import (
@@ -34,16 +32,20 @@ class DataModule(pytorch_lightning.LightningDataModule):

    def __init__(
        self,
        raw_data,
        xnat_configuration: dict = None,
        batch_size: int = 1,
        num_workers: int = 4,
        visualise_training_data=True,
        image_series_option: str = 'error',
    ):
        super().__init__()
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.raw_data = raw_data
        self.xnat_configuration = xnat_configuration
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.visualise_training_data = visualise_training_data
        self.image_series_option = image_series_option
Contributor
Is this only relevant for image data from XNAT? If so, could that be reflected in a comment?

Collaborator Author
Leave for now, can always delete or refactor at a later stage.

Contributor
Add comments in the code to prompt future users to delete as appropriate.


        self.train_transforms = Compose(
            load_xnat(self.xnat_configuration, self.image_series_option)
@@ -58,22 +60,6 @@ def __init__(
            + output_transforms()
        )

    def get_data(self) -> None:
        """
        Fetches raw XNAT data and stores in raw_data attribute
        """
        actions = [
            (self.fetch_xr, "image"),
            (self.fetch_label, "label"),
        ]

        data_builder = DataBuilderXNAT(
            self.xnat_configuration, actions=actions, num_workers=self.num_workers
        )

        data_builder.fetch_data()
        self.raw_data = data_builder.dataset

    def validate_data(self) -> None:
        """
        Remove samples that do not have both an image and a label
@@ -125,9 +111,8 @@ def get_labels(self, data) -> list:
            if sample["data_label"] == "label"
        ]

    def prepare_data(self, *args, **kwargs):
    def setup(self, *args, **kwargs):
Contributor
I would like to have an in-depth discussion around setup() versus prepare_data(), potentially at the next TT.

Collaborator Author
Keep as is for now, but explore in the future.

Contributor
Add a comment in the code with a brief summary of the decision to change this, and point to resources that list the relevant considerations for future use.
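
For context on this thread: PyTorch Lightning documents `prepare_data()` as being called on a single process only (per node by default), so state assigned there is not visible to other processes in distributed training, whereas `setup()` is called on every process. The pure-Python sketch below only simulates that call pattern; `FakeDataModule` and `simulate_ddp` are hypothetical names, and no Lightning code is involved.

```python
class FakeDataModule:
    """Stand-in for a LightningDataModule; only mimics the two hooks."""

    def __init__(self):
        self.prepared = False
        self.dataset = None

    def prepare_data(self):
        # Suited to one-off work such as downloading files to disk.
        self.prepared = True  # state set here stays on this process only

    def setup(self):
        # Suited to building per-process state such as datasets and splits.
        self.dataset = [1, 2, 3]


def simulate_ddp(num_processes: int):
    """Each simulated process owns its own module copy, as in distributed training."""
    modules = [FakeDataModule() for _ in range(num_processes)]
    modules[0].prepare_data()  # called on one process only
    for module in modules:
        module.setup()  # called on every process
    return modules


modules = simulate_ddp(2)
```

In this simulation only the first copy sees `prepared` set, while every copy has a `dataset`, which is one consideration when deciding where validation and splitting should live.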


        self.get_data()
        logging.info("Validating data")
        self.validate_data()

@@ -254,41 +239,6 @@ def test_dataloader(self):
            pin_memory=is_available(),
        )

    @staticmethod
    def fetch_xr(subject_data: SubjectData = None) -> List[ImageScanData]:
        """
        Function that identifies and returns the required xnat ImageData object from a xnat SubjectData object
        along with the 'key' that it will be used to access it.
        """

        scan_objects = []

        for exp in subject_data.experiments:
            if (
                "CR" in subject_data.experiments[exp].modality
                or "DX" in subject_data.experiments[exp].modality
            ):
                for scan in subject_data.experiments[exp].scans:
                    scan_objects.append(subject_data.experiments[exp].scans[scan])
        return scan_objects

    @staticmethod
    def fetch_label(subject_data: SubjectData = None):
        """
        Function that identifies and returns the required label from a XNAT SubjectData object.
        """
        label = None
        for exp in subject_data.experiments:
            if (
                "CR" in subject_data.experiments[exp].modality
                or "DX" in subject_data.experiments[exp].modality
            ):
                temp_label = subject_data.experiments[exp].label
                x = temp_label.split("_")
                label = int(x[1])

        return label

    def dataset_stats(self, dataset: List[Dict], fields=["label"]) -> dict:
        """Calculate dataset statistics

64 changes: 64 additions & 0 deletions project/XNATDataImport.py
@@ -0,0 +1,64 @@
import os
import logging
import pandas as pd
import numpy as np

from utils.tools import DataBuilderXNAT
Contributor
Could XNATDataImport be merged/refactored with DataBuilderXNAT? What's the benefit of having them separate?

A merged version is easier to get rid of or ignore for non-image-based projects.

Collaborator Author
They have different functionality, so keep them separate.

Contributor
Make plans to move both to the CSC-XNAT package.

from xnat.mixin import ImageScanData, SubjectData
from typing import List, Dict

logger = logging.getLogger(__name__)

class XNATDataImport:

    def __init__(self, xnat_configuration: dict = None, num_workers: int = 4):
        self.xnat_configuration = xnat_configuration
        self.num_workers = num_workers

    def import_xnat_data(self):
        # Each action pairs a fetch function with the key its result is stored under.
        actions = [
            (self.fetch_xr, "image"),
            (self.fetch_label, "label"),
        ]

        data_builder = DataBuilderXNAT(
            self.xnat_configuration, actions=actions, num_workers=self.num_workers
        )

        data_builder.fetch_data()
        return data_builder.dataset

    @staticmethod
    def fetch_xr(subject_data: SubjectData = None) -> List[ImageScanData]:
Contributor
"CR" and "DX" are, I assume, specific to X-rays. I recommend we add various fetch functions with good documentation and prompt the developer of each project to choose and adapt the available functions as applicable.

Collaborator Author
Create as separate issue!

        """
        Identifies and returns the scans of every CR or DX experiment in an XNAT SubjectData object.
        The 'key' used to access them ("image") is paired with this function in `actions`.
        """

        scan_objects = []

        for exp in subject_data.experiments:
            if (
                "CR" in subject_data.experiments[exp].modality
                or "DX" in subject_data.experiments[exp].modality
            ):
                for scan in subject_data.experiments[exp].scans:
                    scan_objects.append(subject_data.experiments[exp].scans[scan])
        return scan_objects

    @staticmethod
    def fetch_label(subject_data: SubjectData = None):
        """
        Identifies and returns the required label from an XNAT SubjectData object.
        """
        label = None
        for exp in subject_data.experiments:
            if (
                "CR" in subject_data.experiments[exp].modality
                or "DX" in subject_data.experiments[exp].modality
            ):
                temp_label = subject_data.experiments[exp].label
                x = temp_label.split("_")
                label = int(x[1])

        return label
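
The label parsing above can be tried out without an XNAT server. The sketch below rebuilds the same logic against stand-in objects made with `types.SimpleNamespace`. The `fake_subject` structure and the experiment labels such as `patient01_1` are invented for illustration and only mimic the attributes that `fetch_label` reads; a real `SubjectData` object may behave differently.

```python
from types import SimpleNamespace
from typing import Optional


def fetch_label(subject_data) -> Optional[int]:
    """Return the integer after the first underscore in a CR/DX experiment label."""
    label = None
    for experiment in subject_data.experiments.values():
        if "CR" in experiment.modality or "DX" in experiment.modality:
            # e.g. "patient01_1" -> ["patient01", "1"] -> 1
            label = int(experiment.label.split("_")[1])
    return label


# Hypothetical stand-in for an XNAT subject with two experiments.
fake_subject = SimpleNamespace(
    experiments={
        "exp1": SimpleNamespace(modality="DX", label="patient01_1"),
        "exp2": SimpleNamespace(modality="MR", label="patient01_7"),
    }
)
parsed_label = fetch_label(fake_subject)
```

Note that when a subject has several CR or DX experiments, the last one iterated determines the label, and a subject with none yields `None`.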
12 changes: 10 additions & 2 deletions scripts/train.py
@@ -15,6 +15,7 @@
from project.DataModule import DataModule
from project.Network import Network
from project.DataModule import label_dict
from project.XNATDataImport import XNATDataImport

logger = logging.getLogger(__name__)

@@ -50,16 +51,23 @@ def train(config):

    mlflow.pytorch.autolog(log_models=False)

    # Import raw data
    raw_data = XNATDataImport(
        xnat_configuration = xnat_configuration,
        num_workers = num_workers,
    ).import_xnat_data()

    # Set up datamodule
    dm = DataModule(
        raw_data = raw_data,
        xnat_configuration = xnat_configuration,
        num_workers = num_workers,
        batch_size = int(config['params']['batch_size']),
        visualise_training_data = config['params']['visualise_training_data'],
    )

    dm.prepare_data()

    dm.setup()
    n_classes = len(set([x for x in label_dict.values() if x is not None]))
    mlflow.log_param('n_classes', n_classes)
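
The `n_classes` expression counts the distinct non-`None` values in `label_dict`. A small illustration with a made-up `label_dict` follows; the real mapping lives in `project/DataModule.py` and is not shown in this diff.

```python
# Hypothetical mapping from raw labels to class indices; None marks ignored labels.
label_dict = {"normal": 0, "abnormal": 1, "borderline": 1, "unknown": None}

# A set comprehension gives the same count without building an intermediate list.
n_classes = len({x for x in label_dict.values() if x is not None})
```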
