-
Notifications
You must be signed in to change notification settings - Fork 36
[WIP] Split PrepPipeline into separate methods, make final interpolation optional #99
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6112c32
8ba133b
8ec209f
5191d85
f98bb9f
11987e2
fce1724
4e0e039
c4bb806
46256b3
6ff2755
46b164a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,10 +2,8 @@ | |||||||||
| import mne | ||||||||||
| from mne.utils import check_random_state | ||||||||||
|
|
||||||||||
| from pyprep.find_noisy_channels import NoisyChannels | ||||||||||
| from pyprep.reference import Reference | ||||||||||
| from pyprep.removeTrend import removeTrend | ||||||||||
| from pyprep.utils import _set_diff, _union # noqa: F401 | ||||||||||
|
|
||||||||||
|
|
||||||||||
| class PrepPipeline: | ||||||||||
|
|
@@ -166,75 +164,193 @@ def __init__( | |||||||||
| self.filter_kwargs = filter_kwargs | ||||||||||
| self.matlab_strict = matlab_strict | ||||||||||
|
|
||||||||||
| # Initialize attributes to be filled in later | ||||||||||
| self.EEG_raw = self.raw_eeg.get_data() | ||||||||||
| self.EEG_filtered = None | ||||||||||
| self.EEG_post_reference = None | ||||||||||
|
|
||||||||||
| # NOTE: 'original' refers to before initial average reference, not first | ||||||||||
| # pass afterwards. Not necessarily comparable to later values? | ||||||||||
| self.noisy_info = { | ||||||||||
| "original": None, "post-reference": None, "post-interpolation": None | ||||||||||
| } | ||||||||||
| self.bad_channels = { | ||||||||||
| "original": None, "post-reference": None, "post-interpolation": None | ||||||||||
| } | ||||||||||
| self.interpolated_channels = None | ||||||||||
| self.robust_reference_signal = None | ||||||||||
| self._interpolated_reference_signal = None | ||||||||||
|
|
||||||||||
| @property | ||||||||||
| def raw(self): | ||||||||||
| """Return a version of self.raw_eeg that includes the non-eeg channels.""" | ||||||||||
| full_raw = self.raw_eeg.copy() | ||||||||||
| if self.raw_non_eeg is None: | ||||||||||
| return full_raw | ||||||||||
| else: | ||||||||||
| return full_raw.add_channels([self.raw_non_eeg], force_update_info=True) | ||||||||||
| if self.raw_non_eeg is not None: | ||||||||||
| full_raw.add_channels([self.raw_non_eeg], force_update_info=True) | ||||||||||
| return full_raw | ||||||||||
|
|
||||||||||
| def fit(self): | ||||||||||
| """Run the whole PREP pipeline.""" | ||||||||||
| noisy_detector = NoisyChannels(self.raw_eeg, random_state=self.random_state) | ||||||||||
| noisy_detector.find_bad_by_nan_flat() | ||||||||||
| # unusable_channels = _union( | ||||||||||
| # noisy_detector.bad_by_nan, noisy_detector.bad_by_flat | ||||||||||
| # ) | ||||||||||
| # reference_channels = _set_diff(self.prep_params["ref_chs"], unusable_channels) | ||||||||||
| # Step 1: 1Hz high pass filtering | ||||||||||
| if len(self.prep_params["line_freqs"]) != 0: | ||||||||||
| self.EEG_new = removeTrend( | ||||||||||
| self.EEG_raw, self.sfreq, matlab_strict=self.matlab_strict | ||||||||||
| @property | ||||||||||
| def current_noisy_info(self): | ||||||||||
| post_ref = self.noisy_info["post-reference"] | ||||||||||
| post_interp = self.noisy_info["post-interpolation"] | ||||||||||
| return post_interp if post_interp else post_ref | ||||||||||
|
|
||||||||||
| @property | ||||||||||
| def remaining_bad_channels(self): | ||||||||||
| post_ref = self.bad_channels["post-reference"] | ||||||||||
| post_interp = self.bad_channels["post-interpolation"] | ||||||||||
| return post_interp if post_interp else post_ref | ||||||||||
|
|
||||||||||
| @property | ||||||||||
| def current_reference_signal(self): | ||||||||||
| post_ref = self.robust_reference_signal | ||||||||||
| post_interp = self._interpolated_reference_signal | ||||||||||
| return post_interp if post_interp else post_ref | ||||||||||
|
|
||||||||||
| def get_raw(self, stage=None): | ||||||||||
| """Retrieve the full recording data at a given stage of the pipeline. | ||||||||||
|
|
||||||||||
| Valid pipeline stages include 'unprocessed' (the raw data prior to running | ||||||||||
| the pipeline), 'filtered' (the data following adaptive line noise | ||||||||||
| removal), 'post-reference' (the data after robust referencing, prior to any | ||||||||||
| bad channel interpolation), and 'post-interpolation' (the data after robust | ||||||||||
| referencing and bad channel interpolation). | ||||||||||
|
|
||||||||||
| Parameters | ||||||||||
| ---------- | ||||||||||
| stage : str, optional | ||||||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I'd do something like this:
Suggested change
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That causes the line to go beyond 88 characters, is line wrap for argument types something that's supported by Numpy docstyle? |
||||||||||
| The stage of the pipeline for which the full data will be retrieved. If | ||||||||||
| not specified, the current state of the data will be retrieved. | ||||||||||
|
|
||||||||||
| Returns | ||||||||||
| ------- | ||||||||||
| full_raw: mne.io.Raw | ||||||||||
| An MNE Raw object containing the EEG data for the given stage of the | ||||||||||
| pipeline, along with any non-EEG channels that were present in the | ||||||||||
| original input data. | ||||||||||
|
|
||||||||||
| """ | ||||||||||
| interpolated = self.interpolated_channels is not None | ||||||||||
| stages = { | ||||||||||
| "unprocessed": self.EEG_raw, | ||||||||||
| "filtered": self.EEG_filtered, | ||||||||||
| "post-reference": self.EEG_post_reference, | ||||||||||
| "post-interpolation": self.raw_eeg._data if interpolated else None, | ||||||||||
| } | ||||||||||
| if stage is not None and stage.lower() not in stages.keys(): | ||||||||||
| raise ValueError( | ||||||||||
| "'{stage}' is not a valid pipeline stage. Valid stages are " | ||||||||||
| "'unprocessed', 'filtered', 'post-reference', and 'post-interpolation'." | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| # Step 2: Removing line noise | ||||||||||
| linenoise = self.prep_params["line_freqs"] | ||||||||||
| if self.filter_kwargs is None: | ||||||||||
| self.EEG_clean = mne.filter.notch_filter( | ||||||||||
| self.EEG_new, | ||||||||||
| Fs=self.sfreq, | ||||||||||
| freqs=linenoise, | ||||||||||
| method="spectrum_fit", | ||||||||||
| mt_bandwidth=2, | ||||||||||
| p_value=0.01, | ||||||||||
| filter_length="10s", | ||||||||||
| ) | ||||||||||
| else: | ||||||||||
| self.EEG_clean = mne.filter.notch_filter( | ||||||||||
| self.EEG_new, | ||||||||||
| Fs=self.sfreq, | ||||||||||
| freqs=linenoise, | ||||||||||
| **self.filter_kwargs, | ||||||||||
| eeg_data = self.raw_eeg._data # Default to most recent stage of pipeline | ||||||||||
| if stage: | ||||||||||
| eeg_data = stages[stage.lower()] | ||||||||||
| if not eeg_data: | ||||||||||
| raise ValueError( | ||||||||||
| "Could not retrieve {stage} data, as that stage of the pipeline " | ||||||||||
| "has not yet been performed." | ||||||||||
|
Comment on lines
+251
to
+252
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whoops, nice catch! |
||||||||||
| ) | ||||||||||
| full_raw = self.raw_eeg.copy() | ||||||||||
| full_raw._data = eeg_data | ||||||||||
| if self.raw_non_eeg is not None: | ||||||||||
| full_raw.add_channels([self.raw_non_eeg]) | ||||||||||
|
|
||||||||||
| # Add Trend back | ||||||||||
| self.EEG = self.EEG_raw - self.EEG_new + self.EEG_clean | ||||||||||
| self.raw_eeg._data = self.EEG | ||||||||||
| return full_raw | ||||||||||
|
|
||||||||||
| # Step 3: Referencing | ||||||||||
| reference = Reference( | ||||||||||
| def remove_line_noise(self, line_freqs): | ||||||||||
| """Remove line noise from all EEG channels using multi-taper decomposition. | ||||||||||
|
|
||||||||||
| This filtering method attempts to isolate and remove line noise from the | ||||||||||
| signal while preserving unrelated background signal in the same frequency | ||||||||||
| ranges. This is done to minimize distortions in the power-spectral density | ||||||||||
| curves due to line noise removal. | ||||||||||
|
|
||||||||||
| Parameters | ||||||||||
| ---------- | ||||||||||
| line_freqs: {np.ndarray, list} | ||||||||||
| A list of the frequencies (in Hz) at which line noise should be removed | ||||||||||
| (e.g., ``np.arange(60, sfreq / 2, 60)`` for a recording with a powerline | ||||||||||
| noise of 60 Hz). | ||||||||||
|
|
||||||||||
| """ | ||||||||||
| # Define default settings for filter and apply any kwargs overrides | ||||||||||
| settings = {"mt_bandwidth": 2, "p_value": 0.01, "filter_length": "10s"} | ||||||||||
| if isinstance(self.filter_kwargs, dict): | ||||||||||
| settings.update(self.filter_kwargs) | ||||||||||
|
|
||||||||||
| # Remove slow drifts from the recording prior to filtering | ||||||||||
| eeg_detrended = removeTrend( | ||||||||||
| self.EEG_raw, self.sfreq, matlab_strict=self.matlab_strict | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| # Remove line noise and add the removed slow drifts back | ||||||||||
| eeg_cleaned = mne.filter.notch_filter( | ||||||||||
| eeg_detrended, | ||||||||||
| Fs=self.sfreq, | ||||||||||
| freqs=line_freqs, | ||||||||||
| method="spectrum_fit", | ||||||||||
| **settings, | ||||||||||
| # Add support for parallel jobs if joblib installed? | ||||||||||
| ) | ||||||||||
| self.EEG_filtered = (self.EEG_raw - eeg_detrended) + eeg_cleaned | ||||||||||
| self.raw_eeg._data = self.EEG_filtered | ||||||||||
|
|
||||||||||
| def robust_reference(self, max_iterations=4, interpolate_bads=True): | ||||||||||
| """Perform robust referencing on the EEG signal and detect bad channels. | ||||||||||
|
|
||||||||||
| This method uses an iterative approach to estimate a robust average | ||||||||||
| reference signal free of contamination from bad channels, as detected | ||||||||||
| automatically using the methods of :class:`~pyprep.NoisyChannels`. Once | ||||||||||
| estimated, the robust average reference is applied to the data and bad | ||||||||||
| channel detection is re-run to flag any noisy or unusable channels | ||||||||||
| post-reference. | ||||||||||
|
|
||||||||||
| By default, this method will also interpolate the signals of any channels | ||||||||||
| detected as bad following robust referencing, re-reference the data | ||||||||||
| accordingly, and re-detect any remaining bad channels. | ||||||||||
|
|
||||||||||
| Parameters | ||||||||||
| ---------- | ||||||||||
| max_iterations : int, optional | ||||||||||
| The maximum number of iterations of noisy channel removal to perform | ||||||||||
| during robust referencing. Defaults to ``4``. | ||||||||||
| interpolate_bads : bool, optional | ||||||||||
| Whether or not any remaining bad channels following robust referencing | ||||||||||
| should be interpolated. Defaults to ``True``. | ||||||||||
|
|
||||||||||
| """ | ||||||||||
| # Perform robust referencing on the signal | ||||||||||
| ref = Reference( | ||||||||||
| self.raw_eeg, | ||||||||||
| self.prep_params, | ||||||||||
| random_state=self.random_state, | ||||||||||
| matlab_strict=self.matlab_strict, | ||||||||||
| **self.ransac_settings, | ||||||||||
| ) | ||||||||||
| reference.perform_reference(self.prep_params["max_iterations"]) | ||||||||||
| self.raw_eeg = reference.raw | ||||||||||
| self.noisy_channels_original = reference.noisy_channels_original | ||||||||||
| self.noisy_channels_before_interpolation = ( | ||||||||||
| reference.noisy_channels_before_interpolation | ||||||||||
| ) | ||||||||||
| self.noisy_channels_after_interpolation = ( | ||||||||||
| reference.noisy_channels_after_interpolation | ||||||||||
| ) | ||||||||||
| self.bad_before_interpolation = reference.bad_before_interpolation | ||||||||||
| self.EEG_before_interpolation = reference.EEG_before_interpolation | ||||||||||
| self.reference_before_interpolation = reference.reference_signal | ||||||||||
| self.reference_after_interpolation = reference.reference_signal_new | ||||||||||
| self.interpolated_channels = reference.interpolated_channels | ||||||||||
| self.still_noisy_channels = reference.still_noisy_channels | ||||||||||
| ref.perform_reference(max_iterations, interpolate_bads) | ||||||||||
|
|
||||||||||
| self.raw_eeg = ref.raw | ||||||||||
| self.EEG_post_reference = ref.EEG_before_interpolation | ||||||||||
| self.robust_reference_signal = ref.reference_signal | ||||||||||
| self._interpolated_reference_signal = ref.reference_signal_new | ||||||||||
|
|
||||||||||
| self.noisy_info["original"] = ref.noisy_channels_original | ||||||||||
| self.noisy_info["post-reference"] = ref.noisy_channels_before_interpolation | ||||||||||
| self.noisy_info["post-interpolation"] = ref.noisy_channels_after_interpolation | ||||||||||
|
|
||||||||||
| self.bad_channels["original"] = ref.noisy_channels_original["bad_all"] | ||||||||||
| self.bad_channels["post-reference"] = ref.bad_before_interpolation | ||||||||||
| self.bad_channels["post-interpolation"] = ref.still_noisy_channels | ||||||||||
| self.interpolated_channels = ref.interpolated_channels | ||||||||||
|
|
||||||||||
| def fit(self): | ||||||||||
| """Run the whole PREP pipeline.""" | ||||||||||
| # Step 1: Adaptive line noise removal | ||||||||||
| if len(self.prep_params["line_freqs"]) != 0: | ||||||||||
| self.remove_line_noise(self.prep_params["line_freqs"]) | ||||||||||
|
|
||||||||||
| # Step 2: Robust Referencing | ||||||||||
| self.robust_reference(self.prep_params["max_iterations"]) | ||||||||||
|
|
||||||||||
| return self | ||||||||||
Uh oh!
There was an error while loading. Please reload this page.