From 6112c322b52a122e59e9b2e924ad15a7069b82bd Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Tue, 29 Jun 2021 15:42:49 -0300 Subject: [PATCH 01/11] Remove some unused code --- pyprep/prep_pipeline.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pyprep/prep_pipeline.py b/pyprep/prep_pipeline.py index dbbeb89e..3a093221 100644 --- a/pyprep/prep_pipeline.py +++ b/pyprep/prep_pipeline.py @@ -2,7 +2,6 @@ import mne from mne.utils import check_random_state -from pyprep.find_noisy_channels import NoisyChannels from pyprep.reference import Reference from pyprep.removeTrend import removeTrend from pyprep.utils import _set_diff, _union # noqa: F401 @@ -176,12 +175,6 @@ def raw(self): def fit(self): """Run the whole PREP pipeline.""" - noisy_detector = NoisyChannels(self.raw_eeg, random_state=self.random_state) - noisy_detector.find_bad_by_nan_flat() - # unusable_channels = _union( - # noisy_detector.bad_by_nan, noisy_detector.bad_by_flat - # ) - # reference_channels = _set_diff(self.prep_params["ref_chs"], unusable_channels) # Step 1: 1Hz high pass filtering if len(self.prep_params["line_freqs"]) != 0: self.EEG_new = removeTrend( From 8ba133b7727daad1acdf075c13d5d246208a03c9 Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Tue, 29 Jun 2021 18:44:37 -0300 Subject: [PATCH 02/11] Make interpolation a separate Reference method --- pyprep/reference.py | 75 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 15 deletions(-) diff --git a/pyprep/reference.py b/pyprep/reference.py index 80cb5bdc..fb1a8d58 100644 --- a/pyprep/reference.py +++ b/pyprep/reference.py @@ -90,29 +90,41 @@ def __init__( "max_chunk_size": max_chunk_size, } self.random_state = check_random_state(random_state) - self._extra_info = {} self.matlab_strict = matlab_strict - def perform_reference(self, max_iterations=4): + # Initialize attributes that get filled in during referencing + self.bad_before_interpolation = None + self.EEG_before_interpolation = None + self.noisy_channels_before_interpolation = None + self.reference_signal_new = None + self.interpolated_channels = None + self.still_noisy_channels = None + self.noisy_channels_after_interpolation = None + self._extra_info = { + "initial_bad": None, "interpolated": None, "remaining_bad": None + } + + def perform_reference(self, max_iterations=4, interpolate_bads=True): """Estimate the true signal mean and interpolate bad channels. + This function implements the functionality of the `performReference` function + as part of the PREP pipeline on mne raw object. + Parameters ---------- max_iterations : int, optional The maximum number of iterations of noisy channel removal to perform during robust referencing. Defaults to ``4``. - - This function implements the functionality of the `performReference` function - as part of the PREP pipeline on mne raw object. + interpolate_bads : bool, optional + Whether or not any remaining bad channels following robust referencing + should be interpolated or left as-is. Defaults to ``True``. Notes ----- This function calls ``robust_reference`` first. - Currently this function only implements the functionality of default - settings, i.e., ``doRobustPost``. """ - # Phase 1: Estimate the true signal mean with robust referencing + # Estimate the true signal mean with robust referencing self.robust_reference(max_iterations) # If we interpolate the raw here we would be interpolating # more than what we later actually account for (in interpolated channels). @@ -126,6 +138,8 @@ def perform_reference(self, max_iterations=4): dummy.get_data(picks=self.reference_channels), axis=0 ) del dummy + + # Re-reference the data using the calculated robust average reference rereferenced_index = [ self.ch_names_eeg.index(ch) for ch in self.rereferenced_channels ] @@ -133,38 +147,69 @@ def perform_reference(self, max_iterations=4): self.EEG, self.reference_signal, rereferenced_index ) - # Phase 2: Find the bad channels and interpolate + # Detect which channels are still bad following robust referencing self.raw._data = self.EEG noisy_detector = NoisyChannels( self.raw, random_state=self.random_state, matlab_strict=self.matlab_strict ) noisy_detector.find_all_bads(**self.ransac_settings) - - # Record Noisy channels and EEG before interpolation self.bad_before_interpolation = noisy_detector.get_bads(verbose=True) self.EEG_before_interpolation = self.EEG.copy() self.noisy_channels_before_interpolation = noisy_detector.get_bads(as_dict=True) self._extra_info["interpolated"] = noisy_detector._extra_info + # Update bad channels in MNE raw object bad_channels = _union(self.bad_before_interpolation, self.unusable_channels) self.raw.info["bads"] = bad_channels + + # If enabled, interpolate all bad channels and detect any remaining bads + if interpolate_bads: + self.interpolate_bads() + + return self + + def interpolate_bads(self): + """Interpolate any remaining bad channels following robust referencing. + + This method can only be called if :meth:`~.perform_reference` has already + been run with the ``interpolate_bads`` parameter set to ``False``. It cannot + be run more than once per instance of :class:`~pyprep.Reference`. + + """ + if not self.bad_before_interpolation: + raise RuntimeError( + "Robust referencing must be performed before remaining bad channels " + "can be interpolated." + ) + elif self.interpolated_channels: + raise RuntimeError( + "Bad channel interpolation cannot be performed more than once - " + "interpolating signals using other interpolated signals is likely " + "to have poor results." + ) + + # Interpolate any channels flagged as bad following robust referencing + bad_channels = self.raw.info["bads"] if self.matlab_strict: _eeglab_interpolate_bads(self.raw) else: self.raw.interpolate_bads() + + # Calculate and remove the new average reference following interpolation reference_correct = np.nanmean( self.raw.get_data(picks=self.reference_channels), axis=0 ) + rereferenced_index = [ + self.ch_names_eeg.index(ch) for ch in self.rereferenced_channels + ] self.EEG = self.raw.get_data() self.EEG = self.remove_reference( self.EEG, reference_correct, rereferenced_index ) - # reference signal after interpolation self.reference_signal_new = self.reference_signal + reference_correct - # MNE Raw object after interpolation - self.raw._data = self.EEG + self.raw._data = self.EEG # Update the MNE Raw object - # Still noisy channels after interpolation + # Detect any remaining noisy channels following interpolation self.interpolated_channels = bad_channels noisy_detector = NoisyChannels( self.raw, random_state=self.random_state, matlab_strict=self.matlab_strict From 8ec209fdacafc88a5ae066505af2051bc53b67a4 Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Wed, 30 Jun 2021 21:44:11 -0300 Subject: [PATCH 03/11] Remove unused MATLAB comparison test code --- tests/test_prep_pipeline.py | 181 ------------------------------------ 1 file changed, 181 deletions(-) diff --git a/tests/test_prep_pipeline.py b/tests/test_prep_pipeline.py index 48ebb54a..b0635315 100644 --- a/tests/test_prep_pipeline.py +++ b/tests/test_prep_pipeline.py @@ -26,187 +26,6 @@ def test_prep_pipeline(raw, montage): prep = PrepPipeline(raw_copy, prep_params, montage, random_state=42) prep.fit() - EEG_raw = raw_copy.get_data(picks="eeg") * 1e6 - EEG_raw_max = np.max(abs(EEG_raw), axis=None) - EEG_raw_matlab = sio.loadmat("./examples/matlab_results/EEG_raw.mat") - EEG_raw_matlab = EEG_raw_matlab["save_data"] - EEG_raw_diff = EEG_raw - EEG_raw_matlab - # EEG_raw_mse = (EEG_raw_diff / EEG_raw_max ** 2).mean(axis=None) - - fig, axs = plt.subplots(5, 3, sharex="all") - plt.setp(fig, facecolor=[1, 1, 1]) - fig.suptitle("Python versus Matlab PREP results", fontsize=16) - - im = axs[0, 0].imshow( - EEG_raw / EEG_raw_max, - aspect="auto", - extent=[0, (EEG_raw.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[0, 0].set_title("Python", fontsize=14) - axs[0, 1].imshow( - EEG_raw_matlab / EEG_raw_max, - aspect="auto", - extent=[0, (EEG_raw_matlab.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[0, 1].set_title("Matlab", fontsize=14) - axs[0, 2].imshow( - EEG_raw_diff / EEG_raw_max, - aspect="auto", - extent=[0, (EEG_raw_diff.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[0, 2].set_title("Difference", fontsize=14) - # axs[0, 0].set_title('Original EEG', loc='left', fontsize=14) - # axs[0, 0].set_ylabel('Channel Number', fontsize=14) - cb = fig.colorbar(im, ax=axs, fraction=0.05, pad=0.04) - cb.set_label("\u03BCVolt", fontsize=14) - - EEG_new_matlab = sio.loadmat("./examples/matlab_results/EEGNew.mat") - EEG_new_matlab = EEG_new_matlab["save_data"] - EEG_new = prep.EEG_new - EEG_new_max = np.max(abs(EEG_new), axis=None) - EEG_new_diff = EEG_new - EEG_new_matlab - # EEG_new_mse = ((EEG_new_diff / EEG_new_max) ** 2).mean(axis=None) - axs[1, 0].imshow( - EEG_new / EEG_new_max, - aspect="auto", - extent=[0, (EEG_new.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[1, 1].imshow( - EEG_new_matlab / EEG_new_max, - aspect="auto", - extent=[0, (EEG_new_matlab.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[1, 2].imshow( - EEG_new_diff / EEG_new_max, - aspect="auto", - extent=[0, (EEG_new_diff.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - # axs[1, 0].set_title('High pass filter', loc='left', fontsize=14) - # axs[1, 0].set_ylabel('Channel Number', fontsize=14) - - EEG_clean_matlab = sio.loadmat("./examples/matlab_results/EEG.mat") - EEG_clean_matlab = EEG_clean_matlab["save_data"] - EEG_clean = prep.EEG - EEG_max = np.max(abs(EEG_clean), axis=None) - EEG_diff = EEG_clean - EEG_clean_matlab - # EEG_mse = ((EEG_diff / EEG_max) ** 2).mean(axis=None) - axs[2, 0].imshow( - EEG_clean / EEG_max, - aspect="auto", - extent=[0, (EEG_clean.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[2, 1].imshow( - EEG_clean_matlab / EEG_max, - aspect="auto", - extent=[0, (EEG_clean_matlab.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[2, 2].imshow( - EEG_diff / EEG_max, - aspect="auto", - extent=[0, (EEG_diff.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - # axs[2, 0].set_title('Line-noise removal', loc='left', fontsize=14) - axs[2, 0].set_ylabel("Channel Number", fontsize=14) - - EEG = prep.EEG_before_interpolation - EEG_max = np.max(abs(EEG), axis=None) - EEG_ref_mat = sio.loadmat("./examples/matlab_results/EEGref.mat") - EEG_ref_matlab = EEG_ref_mat["save_EEG"] - # reference_matlab = EEG_ref_mat["save_reference"] - EEG_ref_diff = EEG - EEG_ref_matlab - # EEG_ref_mse = ((EEG_ref_diff / EEG_max) ** 2).mean(axis=None) - # reference_signal = prep.reference_before_interpolation - # reference_max = np.max(abs(reference_signal), axis=None) - # reference_diff = reference_signal - reference_matlab - # reference_mse = ((reference_diff / reference_max) ** 2).mean(axis=None) - axs[3, 0].imshow( - EEG / EEG_max, - aspect="auto", - extent=[0, (EEG.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[3, 1].imshow( - EEG_ref_matlab / EEG_max, - aspect="auto", - extent=[0, (EEG_ref_matlab.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[3, 2].imshow( - EEG_ref_diff / EEG_max, - aspect="auto", - extent=[0, (EEG_ref_diff.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - # axs[3, 0].set_title('Referencing', loc='left', fontsize=14) - # axs[3, 0].set_ylabel('Channel Number', fontsize=14) - - EEG_final = prep.raw.get_data() * 1e6 - EEG_final_max = np.max(abs(EEG_final), axis=None) - EEG_final_matlab = sio.loadmat("./examples/matlab_results/EEGinterp.mat") - EEG_final_matlab = EEG_final_matlab["save_data"] - EEG_final_diff = EEG_final - EEG_final_matlab - # EEG_final_mse = ((EEG_final_diff / EEG_final_max) ** 2).mean(axis=None) - axs[4, 0].imshow( - EEG_final / EEG_final_max, - aspect="auto", - extent=[0, (EEG_final.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[4, 1].imshow( - EEG_final_matlab / EEG_final_max, - aspect="auto", - extent=[0, (EEG_final_matlab.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - axs[4, 2].imshow( - EEG_final_diff / EEG_final_max, - aspect="auto", - extent=[0, (EEG_final_diff.shape[1] / sample_rate), 63, 0], - vmin=-1, - vmax=1, - cmap=plt.get_cmap("RdBu"), - ) - # axs[4, 0].set_title('Interpolation', loc='left', fontsize=14) - # axs[4, 0].set_ylabel('Channel Number', fontsize=14) - axs[4, 1].set_xlabel("Time(s)", fontsize=14) - @pytest.mark.usefixtures("raw", "montage") def test_prep_pipeline_non_eeg(raw, montage): From 5191d858eb370c488a279f200d983f7e161b8990 Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Wed, 30 Jun 2021 21:49:29 -0300 Subject: [PATCH 04/11] Add new separate methods for prep stages --- pyprep/prep_pipeline.py | 123 ++++++++++++++++++++++++++++------------ 1 file changed, 88 insertions(+), 35 deletions(-) diff --git a/pyprep/prep_pipeline.py b/pyprep/prep_pipeline.py index 3a093221..04810e94 100644 --- a/pyprep/prep_pipeline.py +++ b/pyprep/prep_pipeline.py @@ -4,7 +4,6 @@ from pyprep.reference import Reference from pyprep.removeTrend import removeTrend -from pyprep.utils import _set_diff, _union # noqa: F401 class PrepPipeline: @@ -164,6 +163,17 @@ def __init__( self.filter_kwargs = filter_kwargs self.matlab_strict = matlab_strict + # Initialize attributes to be filled in later + self.noisy_channels_original = None + self.noisy_channels_before_interpolation = None + self.noisy_channels_after_interpolation = None + self.bad_before_interpolation = None + self.EEG_before_interpolation = None + self.reference_before_interpolation = None + self.reference_after_interpolation = None + self.interpolated_channels = None + self.still_noisy_channels = None + @property def raw(self): """Return a version of self.raw_eeg that includes the non-eeg channels.""" @@ -173,39 +183,72 @@ def raw(self): else: return full_raw.add_channels([self.raw_non_eeg]) - def fit(self): - """Run the whole PREP pipeline.""" - # Step 1: 1Hz high pass filtering - if len(self.prep_params["line_freqs"]) != 0: - self.EEG_new = removeTrend( - self.EEG_raw, self.sfreq, matlab_strict=self.matlab_strict - ) - - # Step 2: Removing line noise - linenoise = self.prep_params["line_freqs"] - if self.filter_kwargs is None: - self.EEG_clean = mne.filter.notch_filter( - self.EEG_new, - Fs=self.sfreq, - freqs=linenoise, - method="spectrum_fit", - mt_bandwidth=2, - p_value=0.01, - filter_length="10s", - ) - else: - self.EEG_clean = mne.filter.notch_filter( - self.EEG_new, - Fs=self.sfreq, - freqs=linenoise, - **self.filter_kwargs, - ) - - # Add Trend back - self.EEG = self.EEG_raw - self.EEG_new + self.EEG_clean - self.raw_eeg._data = self.EEG - - # Step 3: Referencing + def remove_line_noise(self, line_freqs): + """Remove line noise from all EEG channels using multi-taper decomposition. + + This filtering method attempts to isolate and remove line noise from the + signal while preserving unrelated background signal in the same frequency + ranges. This is done to minimize distortions in the power-spectral density + curves due to line noise removal. + + Parameters + ---------- + line_freqs: {np.ndarray, list} + A list of the frequencies (in Hz) at which line noise should be removed + (e.g., ``np.arange(60, sfreq / 2, 60)`` for a recording with a powerline + noise of 60 Hz). + + """ + # Define default settings for filter and apply any kwargs overrides + settings = { + "method": "spectrum_fit", + "mt_bandwidth": 2, + "p_value": 0.01, + "filter_length": "10s" + } + if isinstance(self.filter_kwargs, dict): + settings.update(self.filter_kwargs) + + # Remove slow drifts from the recording prior to filtering + eeg_detrended = removeTrend( + self.EEG_raw, self.sfreq, matlab_strict=self.matlab_strict + ) + + # Remove line noise and add the removed slow drifts back + eeg_cleaned = mne.filter.notch_filter( + eeg_detrended, + Fs=self.sfreq, + freqs=line_freqs, + **settings, + ) + self.EEG_filtered = (self.EEG_raw - eeg_detrended) + eeg_cleaned + self.raw_eeg._data = self.EEG_filtered + + def robust_reference(self, max_iterations=4, interpolate_bads=True): + """Perform robust referencing on the EEG signal and detect bad channels. + + This method uses an iterative approach to estimate a robust average + reference signal free of contamination from bad channels, as detected + automatically using the methods of :class:`~pyprep.NoisyChannels`. Once + estimated, the robust average reference is applied to the data and bad + channel detection is re-run to flag any noisy or unusable channels + post-reference. + + By default, this method will also interpolate the signals of any channels + detected as bad following robust referencing, re-reference the data + accordingly, and re-detect any remaining bad channels. + + Parameters + ---------- + max_iterations : int, optional + The maximum number of iterations of noisy channel removal to perform + during robust referencing. Defaults to ``4``. + interpolate_bads : bool, optional + Whether or not any remaining bad channels following robust referencing + should be interpolated. Defaults to ``True``. + + """ + # Perform robust referencing on the signal reference = Reference( self.raw_eeg, self.prep_params, @@ -213,7 +256,8 @@ def fit(self): matlab_strict=self.matlab_strict, **self.ransac_settings, ) - reference.perform_reference(self.prep_params["max_iterations"]) + reference.perform_reference(max_iterations, interpolate_bads) + self.raw_eeg = reference.raw self.noisy_channels_original = reference.noisy_channels_original self.noisy_channels_before_interpolation = ( @@ -229,4 +273,13 @@ def fit(self): self.interpolated_channels = reference.interpolated_channels self.still_noisy_channels = reference.still_noisy_channels + def fit(self): + """Run the whole PREP pipeline.""" + # Step 1: Adaptive line noise removal + if len(self.prep_params["line_freqs"]) != 0: + self.remove_line_noise(self.prep_params["line_freqs"]) + + # Step 2: Robust Referencing + self.robust_reference(self.prep_params["max_iterations"]) + return self From f98bb9f84bde7e0cea8b9aecec40c5620d061ebb Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Wed, 30 Jun 2021 22:01:00 -0300 Subject: [PATCH 05/11] Make remove_line_noisy only use spectrum_fit --- pyprep/prep_pipeline.py | 2 +- tests/test_prep_pipeline.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pyprep/prep_pipeline.py b/pyprep/prep_pipeline.py index 04810e94..3be97326 100644 --- a/pyprep/prep_pipeline.py +++ b/pyprep/prep_pipeline.py @@ -201,7 +201,6 @@ def remove_line_noise(self, line_freqs): """ # Define default settings for filter and apply any kwargs overrides settings = { - "method": "spectrum_fit", "mt_bandwidth": 2, "p_value": 0.01, "filter_length": "10s" @@ -219,6 +218,7 @@ def remove_line_noise(self, line_freqs): eeg_detrended, Fs=self.sfreq, freqs=line_freqs, + method="spectrum_fit", **settings, ) self.EEG_filtered = (self.EEG_raw - eeg_detrended) + eeg_cleaned diff --git a/tests/test_prep_pipeline.py b/tests/test_prep_pipeline.py index b0635315..ec6a00f1 100644 --- a/tests/test_prep_pipeline.py +++ b/tests/test_prep_pipeline.py @@ -1,9 +1,7 @@ """Test the full PREP pipeline.""" -import matplotlib.pyplot as plt import mne import numpy as np import pytest -import scipy.io as sio from pyprep.prep_pipeline import PrepPipeline @@ -84,11 +82,11 @@ def test_prep_pipeline_filter_kwargs(raw, montage): "line_freqs": np.arange(60, sample_rate / 2, 60), } filter_kwargs = { - "method": "fir", - "phase": "zero-double", + "mt_bandwidth": 3, + "p_value": 0.05, } prep = PrepPipeline( raw_copy, prep_params, montage, random_state=42, filter_kwargs=filter_kwargs ) - prep.fit() + prep.remove_line_noise(prep_params["line_freqs"]) From 11987e285d9186f06510487424569414e819ae74 Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Sat, 3 Jul 2021 01:38:19 -0300 Subject: [PATCH 06/11] Fix black's dict complaints --- pyprep/prep_pipeline.py | 6 +----- pyprep/reference.py | 4 +++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pyprep/prep_pipeline.py b/pyprep/prep_pipeline.py index 3be97326..8ca6f99f 100644 --- a/pyprep/prep_pipeline.py +++ b/pyprep/prep_pipeline.py @@ -200,11 +200,7 @@ def remove_line_noise(self, line_freqs): """ # Define default settings for filter and apply any kwargs overrides - settings = { - "mt_bandwidth": 2, - "p_value": 0.01, - "filter_length": "10s" - } + settings = {"mt_bandwidth": 2, "p_value": 0.01, "filter_length": "10s"} if isinstance(self.filter_kwargs, dict): settings.update(self.filter_kwargs) diff --git a/pyprep/reference.py b/pyprep/reference.py index fb1a8d58..de3b09c4 100644 --- a/pyprep/reference.py +++ b/pyprep/reference.py @@ -101,7 +101,9 @@ def __init__( self.still_noisy_channels = None self.noisy_channels_after_interpolation = None self._extra_info = { - "initial_bad": None, "interpolated": None, "remaining_bad": None + "initial_bad": None, + "interpolated": None, + "remaining_bad": None, } def perform_reference(self, max_iterations=4, interpolate_bads=True): From fce17243932afdd47c24d531594d3fce0cda4593 Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Sat, 3 Jul 2021 01:48:51 -0300 Subject: [PATCH 07/11] Try fixing PREP example --- examples/run_full_prep.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/run_full_prep.py b/examples/run_full_prep.py index d0e63da0..4ce75047 100644 --- a/examples/run_full_prep.py +++ b/examples/run_full_prep.py @@ -37,6 +37,7 @@ import matplotlib.pyplot as plt from pyprep.prep_pipeline import PrepPipeline +from pyprep.removeTrend import removeTrend ############################################################################### # Let's download some data for testing. Picking the 1st run of subject 4 here. @@ -168,7 +169,7 @@ EEG_new_matlab = sio.loadmat(fname_mat2) EEG_new_matlab = EEG_new_matlab["save_data"] -EEG_new = prep.EEG_new +EEG_new = removeTrend(prep.EEG_raw, sample_rate=prep.sfreq) EEG_new_max = np.max(abs(EEG_new), axis=None) EEG_new_diff = EEG_new - EEG_new_matlab EEG_new_mse = ((EEG_new_diff / EEG_new_max) ** 2).mean(axis=None) @@ -201,7 +202,7 @@ EEG_clean_matlab = sio.loadmat(fname_mat3) EEG_clean_matlab = EEG_clean_matlab["save_data"] -EEG_clean = prep.EEG +EEG_clean = prep.EEG_filtered EEG_max = np.max(abs(EEG_clean), axis=None) EEG_diff = EEG_clean - EEG_clean_matlab EEG_mse = ((EEG_diff / EEG_max) ** 2).mean(axis=None) From 4e0e039d4d4f8e92c7ca470dc8edd380bd295fe1 Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Wed, 4 Aug 2021 18:35:42 -0300 Subject: [PATCH 08/11] Move bad chan info to dicts, rename attributes --- pyprep/prep_pipeline.py | 79 ++++++++++++++++++++++++++--------------- 1 file changed, 51 insertions(+), 28 deletions(-) diff --git a/pyprep/prep_pipeline.py b/pyprep/prep_pipeline.py index 8ca6f99f..f18e0aa8 100644 --- a/pyprep/prep_pipeline.py +++ b/pyprep/prep_pipeline.py @@ -164,24 +164,47 @@ def __init__( self.matlab_strict = matlab_strict # Initialize attributes to be filled in later - self.noisy_channels_original = None - self.noisy_channels_before_interpolation = None - self.noisy_channels_after_interpolation = None - self.bad_before_interpolation = None - self.EEG_before_interpolation = None - self.reference_before_interpolation = None - self.reference_after_interpolation = None + self.EEG_raw = self.raw_eeg.get_data() + self.EEG_filtered = None + self.EEG_post_reference = None + + # NOTE: 'original' refers to before initial average reference, not first + # pass afterwards. Not necessarily comparable to later values? + self.noisy_info = { + "original": None, "post-reference": None, "post-interpolation": None + } + self.bad_channels = { + "original": None, "post-reference": None, "post-interpolation": None + } self.interpolated_channels = None - self.still_noisy_channels = None + self.robust_reference_signal = None + self._interpolated_reference_signal = None @property def raw(self): """Return a version of self.raw_eeg that includes the non-eeg channels.""" full_raw = self.raw_eeg.copy() - if self.raw_non_eeg is None: - return full_raw - else: - return full_raw.add_channels([self.raw_non_eeg]) + if self.raw_non_eeg is not None: + full_raw.add_channels([self.raw_non_eeg]) + return full_raw + + @property + def current_noisy_info(self): + post_ref = self.noisy_info["post-reference"] + post_interp = self.noisy_info["post-interpolation"] + return post_interp if post_interp else post_ref + + @property + def remaining_bad_channels(self): + post_ref = self.bad_channels["post-reference"] + post_interp = self.bad_channels["post-interpolation"] + return post_interp if post_interp else post_ref + + @property + def current_reference_signal(self): + post_ref = self.robust_reference_signal + post_interp = self._interpolated_reference_signal + return post_interp if post_interp else post_ref def remove_line_noise(self, line_freqs): """Remove line noise from all EEG channels using multi-taper decomposition. @@ -216,6 +239,7 @@ def remove_line_noise(self, line_freqs): freqs=line_freqs, method="spectrum_fit", **settings, + # Add support for parallel jobs if joblib installed? ) self.EEG_filtered = (self.EEG_raw - eeg_detrended) + eeg_cleaned self.raw_eeg._data = self.EEG_filtered @@ -245,29 +269,28 @@ def robust_reference(self, max_iterations=4, interpolate_bads=True): """ # Perform robust referencing on the signal - reference = Reference( + ref = Reference( self.raw_eeg, self.prep_params, random_state=self.random_state, matlab_strict=self.matlab_strict, **self.ransac_settings, ) - reference.perform_reference(max_iterations, interpolate_bads) + ref.perform_reference(max_iterations, interpolate_bads) - self.raw_eeg = reference.raw - self.noisy_channels_original = reference.noisy_channels_original - self.noisy_channels_before_interpolation = ( - reference.noisy_channels_before_interpolation - ) - self.noisy_channels_after_interpolation = ( - reference.noisy_channels_after_interpolation - ) - self.bad_before_interpolation = reference.bad_before_interpolation - self.EEG_before_interpolation = reference.EEG_before_interpolation - self.reference_before_interpolation = reference.reference_signal - self.reference_after_interpolation = reference.reference_signal_new - self.interpolated_channels = reference.interpolated_channels - self.still_noisy_channels = reference.still_noisy_channels + self.raw_eeg = ref.raw + self.EEG_post_reference = ref.EEG_before_interpolation + self.robust_reference_signal = ref.reference_signal + self._interpolated_reference_signal = ref.reference_signal_new + + self.noisy_info["original"] = ref.noisy_channels_original + self.noisy_info["post-reference"] = ref.noisy_channels_before_interpolation + self.noisy_info["post-interpolation"] = ref.noisy_channels_after_interpolation + + self.bad_channels["original"] = ref.noisy_channels_original["bad_all"] + self.bad_channels["post-reference"] = ref.bad_before_interpolation + self.bad_channels["post-interpolation"] = ref.still_noisy_channels + self.interpolated_channels = ref.interpolated_channels def fit(self): """Run the whole PREP pipeline.""" From c4bb8063924c1fec0a205f6151cbb4a96b92b9ee Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Wed, 4 Aug 2021 18:36:57 -0300 Subject: [PATCH 09/11] Add get_raw method for easier data access --- pyprep/prep_pipeline.py | 51 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/pyprep/prep_pipeline.py b/pyprep/prep_pipeline.py index f18e0aa8..d56f5edb 100644 --- a/pyprep/prep_pipeline.py +++ b/pyprep/prep_pipeline.py @@ -206,6 +206,57 @@ def current_reference_signal(self): post_interp = self._interpolated_reference_signal return post_interp if post_interp else post_ref + def get_raw(self, stage=None): + """Retrieve the full recording data at a given stage of the pipeline. + + Valid pipeline stages include 'unprocessed' (the raw data prior to running + the pipeline), 'filtered' (the data following adaptive line noise + removal), 'post-reference' (the data after robust referencing, prior to any + bad channel interpolation), and 'post-interpolation' (the data after robust + referencing and bad channel interpolation). + + Parameters + ---------- + stage : str, optional + The stage of the pipeline for which the full data will be retrieved. If + not specified, the current state of the data will be retrieved. + + Returns + ------- + full_raw: mne.io.Raw + An MNE Raw object containing the EEG data for the given stage of the + pipeline, along with any non-EEG channels that were present in the + original input data. + + """ + interpolated = self.interpolated_channels is not None + stages = { + "unprocessed": self.EEG_raw, + "filtered": self.EEG_filtered, + "post-reference": self.EEG_post_reference, + "post-interpolation": self.raw_eeg._data if interpolated else None, + } + if stage is not None and stage.lower() not in stages.keys(): + raise ValueError( + "'{stage}' is not a valid pipeline stage. Valid stages are " + "'unprocessed', 'filtered', 'post-reference', and 'post-interpolation'." + ) + + eeg_data = self.raw_eeg._data # Default to most recent stage of pipeline + if stage: + eeg_data = stages[stage.lower()] + if not eeg_data: + raise ValueError( + "Could not retrieve {stage} data, as that stage of the pipeline " + "has not yet been performed." + ) + full_raw = self.raw_eeg.copy() + full_raw._data = eeg_data + if self.raw_non_eeg is not None: + full_raw.add_channels([self.raw_non_eeg]) + + return full_raw + def remove_line_noise(self, line_freqs): """Remove line noise from all EEG channels using multi-taper decomposition. From 46256b39ad8e1e5b673a1d30a31c206e1397a0ab Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Wed, 4 Aug 2021 19:30:48 -0300 Subject: [PATCH 10/11] Fix full PREP example --- examples/run_full_prep.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/run_full_prep.py b/examples/run_full_prep.py index 4ce75047..8a0efeff 100644 --- a/examples/run_full_prep.py +++ b/examples/run_full_prep.py @@ -104,9 +104,12 @@ # # You can check the detected bad channels in each step of PREP. +original_bads = prep.bad_channels["original"] +post_interp_bads = prep.bad_channels["post-interpolation"] + print("Bad channels: {}".format(prep.interpolated_channels)) -print("Bad channels original: {}".format(prep.noisy_channels_original["bad_all"])) -print("Bad channels after interpolation: {}".format(prep.still_noisy_channels)) +print("Bad channels original: {}".format(original_bads)) +print("Bad channels after interpolation: {}".format(post_interp_bads)) # Matlab's results # ---------------- From 6ff27552ba9321598940c05c256f9623fd3251fc Mon Sep 17 00:00:00 2001 From: Austin Hurst Date: Wed, 4 Aug 2021 19:37:00 -0300 Subject: [PATCH 11/11] Update names and scalings for PREP example --- examples/run_full_prep.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/run_full_prep.py b/examples/run_full_prep.py index 8a0efeff..2f019d63 100644 --- a/examples/run_full_prep.py +++ b/examples/run_full_prep.py @@ -172,7 +172,7 @@ EEG_new_matlab = sio.loadmat(fname_mat2) EEG_new_matlab = EEG_new_matlab["save_data"] -EEG_new = removeTrend(prep.EEG_raw, sample_rate=prep.sfreq) +EEG_new = removeTrend(prep.EEG_raw, sample_rate=prep.sfreq) * 1e6 EEG_new_max = np.max(abs(EEG_new), axis=None) EEG_new_diff = EEG_new - EEG_new_matlab EEG_new_mse = ((EEG_new_diff / EEG_new_max) ** 2).mean(axis=None) @@ -205,7 +205,7 @@ EEG_clean_matlab = sio.loadmat(fname_mat3) EEG_clean_matlab = EEG_clean_matlab["save_data"] -EEG_clean = prep.EEG_filtered +EEG_clean = prep.EEG_filtered * 1e6 EEG_max = np.max(abs(EEG_clean), axis=None) EEG_diff = EEG_clean - EEG_clean_matlab EEG_mse = ((EEG_diff / EEG_max) ** 2).mean(axis=None) @@ -236,14 +236,14 @@ axs[2, 1].set_title("Line-noise removed EEG", fontsize=14) axs[2, 0].set_ylabel("Channel Number", fontsize=14) -EEG = prep.EEG_before_interpolation +EEG = prep.EEG_post_reference * 1e6 EEG_max = np.max(abs(EEG), axis=None) EEG_ref_mat = sio.loadmat(fname_mat4) EEG_ref_matlab = EEG_ref_mat["save_EEG"] reference_matlab = EEG_ref_mat["save_reference"] EEG_ref_diff = EEG - EEG_ref_matlab EEG_ref_mse = ((EEG_ref_diff / EEG_max) ** 2).mean(axis=None) -reference_signal = prep.reference_before_interpolation +reference_signal = prep.robust_reference_signal * 1e6 reference_max = np.max(abs(reference_signal), axis=None) reference_diff = reference_signal - reference_matlab reference_mse = ((reference_diff / reference_max) ** 2).mean(axis=None)