-
Notifications
You must be signed in to change notification settings - Fork 2
69 datamodule and data import are combined #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
acc54ff
2e31754
047d4ef
c18fd28
4919e25
98414c0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,8 +12,6 @@ | |
| from torchvision.utils import make_grid | ||
| from sklearn.model_selection import train_test_split | ||
|
|
||
| from utils.tools import DataBuilderXNAT | ||
| from xnat.mixin import ImageScanData, SubjectData | ||
| import matplotlib.pyplot as plt | ||
|
|
||
| from project.transforms import ( | ||
|
|
@@ -34,16 +32,20 @@ class DataModule(pytorch_lightning.LightningDataModule): | |
|
|
||
| def __init__( | ||
| self, | ||
| raw_data, | ||
| xnat_configuration: dict = None, | ||
| batch_size: int = 1, | ||
| num_workers: int = 4, | ||
| visualise_training_data=True, | ||
| image_series_option: str = 'error', | ||
| ): | ||
| super().__init__() | ||
| self.num_workers = num_workers | ||
| self.batch_size = batch_size | ||
| self.raw_data = raw_data | ||
| self.xnat_configuration = xnat_configuration | ||
| self.batch_size = batch_size | ||
| self.num_workers = num_workers | ||
| self.visualise_training_data = visualise_training_data | ||
| self.image_series_option = image_series_option | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this only relevant for image data from XNAT? if so, could that be reflected in a comment?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Leave for now, can always delete or refactor at a later stage.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add comments in the code to prompt future users to delete as appropriate |
||
|
|
||
| self.train_transforms = Compose( | ||
| load_xnat(self.xnat_configuration, self.image_series_option) | ||
|
|
@@ -58,22 +60,6 @@ def __init__( | |
| + output_transforms() | ||
| ) | ||
|
|
||
| def get_data(self) -> None: | ||
| """ | ||
| Fetches raw XNAT data and stores in raw_data attribute | ||
| """ | ||
| actions = [ | ||
| (self.fetch_xr, "image"), | ||
| (self.fetch_label, "label"), | ||
| ] | ||
|
|
||
| data_builder = DataBuilderXNAT( | ||
| self.xnat_configuration, actions=actions, num_workers=self.num_workers | ||
| ) | ||
|
|
||
| data_builder.fetch_data() | ||
| self.raw_data = data_builder.dataset | ||
|
|
||
| def validate_data(self) -> None: | ||
| """ | ||
| Remove samples that do not have both an image and a label | ||
|
|
@@ -125,9 +111,8 @@ def get_labels(self, data) -> list: | |
| if sample["data_label"] == "label" | ||
| ] | ||
|
|
||
| def prepare_data(self, *args, **kwargs): | ||
| def setup(self, *args, **kwargs): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would like to have an in-depth discussion around setup() versus prepare_data(), potentially at the next TT.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Keep as is for now, but explore in the future.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add comment in the code with a brief summary of the decision to change this and point to resources where relevant considerations are listed for future use |
||
|
|
||
| self.get_data() | ||
| logging.info("Validating data") | ||
| self.validate_data() | ||
|
|
||
|
|
@@ -254,41 +239,6 @@ def test_dataloader(self): | |
| pin_memory=is_available(), | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def fetch_xr(subject_data: SubjectData = None) -> List[ImageScanData]: | ||
| """ | ||
| Function that identifies and returns the required xnat ImageData object from a xnat SubjectData object | ||
| along with the 'key' that it will be used to access it. | ||
| """ | ||
|
|
||
| scan_objects = [] | ||
|
|
||
| for exp in subject_data.experiments: | ||
| if ( | ||
| "CR" in subject_data.experiments[exp].modality | ||
| or "DX" in subject_data.experiments[exp].modality | ||
| ): | ||
| for scan in subject_data.experiments[exp].scans: | ||
| scan_objects.append(subject_data.experiments[exp].scans[scan]) | ||
| return scan_objects | ||
|
|
||
| @staticmethod | ||
| def fetch_label(subject_data: SubjectData = None): | ||
| """ | ||
| Function that identifies and returns the required label from a XNAT SubjectData object. | ||
| """ | ||
| label = None | ||
| for exp in subject_data.experiments: | ||
| if ( | ||
| "CR" in subject_data.experiments[exp].modality | ||
| or "DX" in subject_data.experiments[exp].modality | ||
| ): | ||
| temp_label = subject_data.experiments[exp].label | ||
| x = temp_label.split("_") | ||
| label = int(x[1]) | ||
|
|
||
| return label | ||
|
|
||
| def dataset_stats(self, dataset: List[Dict], fields=["label"]) -> dict: | ||
| """Calculate dataset statistics | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| import os | ||
| import logging | ||
| import pandas as pd | ||
| import numpy as np | ||
|
|
||
| from utils.tools import DataBuilderXNAT | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could XNATDataImport be merged/refactored with DataBuilderXNAT? What's the benefit of having them separate? A merged version is easier to get rid of/ignore for not image-based projects.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Different functionalities so keep separate.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. make plans to move both to CSC-XNAT package |
||
| from xnat.mixin import ImageScanData, SubjectData | ||
| from typing import List, Dict | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| class XNATDataImport(): | ||
|
|
||
| def __init__(self, xnat_configuration: dict = None, num_workers: int = 4): | ||
| self.xnat_configuration = xnat_configuration | ||
| self.num_workers = num_workers | ||
|
|
||
| def import_xnat_data(self): | ||
| actions = [ | ||
| (self.fetch_xr, "image"), | ||
| (self.fetch_label, "label"), | ||
| ] | ||
|
|
||
| data_builder = DataBuilderXNAT( | ||
| self.xnat_configuration, actions=actions, num_workers=self.num_workers | ||
| ) | ||
|
|
||
| data_builder.fetch_data() | ||
| return(data_builder.dataset) | ||
|
|
||
| @staticmethod | ||
| def fetch_xr(subject_data: SubjectData = None) -> List[ImageScanData]: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "CR" and "DX" are I assume specific for X-rays. I recommend we add various
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Create as separate issue! |
||
| """ | ||
| Function that identifies and returns the required xnat ImageData object from a xnat SubjectData object | ||
| along with the 'key' that it will be used to access it. | ||
| """ | ||
|
|
||
| scan_objects = [] | ||
|
|
||
| for exp in subject_data.experiments: | ||
| if ( | ||
| "CR" in subject_data.experiments[exp].modality | ||
| or "DX" in subject_data.experiments[exp].modality | ||
| ): | ||
| for scan in subject_data.experiments[exp].scans: | ||
| scan_objects.append(subject_data.experiments[exp].scans[scan]) | ||
| return scan_objects | ||
|
|
||
| @staticmethod | ||
| def fetch_label(subject_data: SubjectData = None): | ||
| """ | ||
| Function that identifies and returns the required label from a XNAT SubjectData object. | ||
| """ | ||
| label = None | ||
| for exp in subject_data.experiments: | ||
| if ( | ||
| "CR" in subject_data.experiments[exp].modality | ||
| or "DX" in subject_data.experiments[exp].modality | ||
| ): | ||
| temp_label = subject_data.experiments[exp].label | ||
| x = temp_label.split("_") | ||
| label = int(x[1]) | ||
|
|
||
| return label | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should stay to be part of the DataModule component.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As it is its own component, I think it should have its own explanation.
Will create an issue to make a code structure/architecture image.