diff --git a/CITATION.cff b/CITATION.cff index 1da3c61..a90ecc3 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -55,6 +55,10 @@ authors: family-names: Veillette affiliation: 'Department of Psychology, University of Chicago, Chicago, IL, USA' orcid: 'https://orcid.org/0000-0002-0332-4372' + - given-names: Roy Eric + family-names: Wieske + affiliation: 'Biopsychology and Neuroergonomics, Technische Universität Berlin, Berlin, Germany' + orcid: 'https://orcid.org/0009-0006-2018-1074' type: software repository-code: 'https://github.com/sappelhoff/pyprep' license: MIT diff --git a/docs/authors.rst b/docs/authors.rst index 0ddb8f6..bccf1ed 100644 --- a/docs/authors.rst +++ b/docs/authors.rst @@ -11,3 +11,4 @@ .. _Victor Xiang: https://github.com/Nick3151 .. _Yorguin Mantilla: https://github.com/yjmantilla .. _John Veillette: https://github.com/john-veillette +.. _Roy Eric Wieske: https://github.com/Randomidous diff --git a/docs/changelog.rst b/docs/changelog.rst index 0d7a00f..ffec374 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -26,6 +26,8 @@ Version 0.6.0 (unreleased) Changelog ~~~~~~~~~ +- Added :meth:`~pyprep.NoisyChannels.find_bad_by_PSD` method for detecting channels with abnormally high or low power spectral density. 
This is a PyPREP-only feature not present in MATLAB PREP, by `Roy Eric Wieske`_ (:gh:`145`) +- Added ``reject_by_annotation`` parameter to :class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, and :class:`~pyprep.NoisyChannels` to exclude BAD-annotated time segments from channel quality assessment, by `Roy Eric Wieske`_ (:gh:`180`) - Users can now determine whether or not to use ``correlation`` as a method for finding bad channels in :meth:`~pyprep.NoisyChannels.find_all_bads` (defaults to True), by `Stefan Appelhoff`_ (:gh:`169`) - Manually marked bad channels are ignored for finding further bads (just like NaN and flat channels) in :meth:`~pyprep.NoisyChannels.find_all_bads`, by `Stefan Appelhoff`_ (:gh:`168`) diff --git a/docs/matlab_differences.rst b/docs/matlab_differences.rst index 9e3b668..8d5327f 100644 --- a/docs/matlab_differences.rst +++ b/docs/matlab_differences.rst @@ -233,6 +233,72 @@ MATLAB PREP, PyPREP will use a Python reimplementation of ``eeg_interp`` instead when the ``matlab_strict`` parameter is set to ``True``. +Annotation-Based Segment Rejection +---------------------------------- + +PyPREP supports the ``reject_by_annotation`` parameter in +:class:`~pyprep.PrepPipeline`, :class:`~pyprep.Reference`, and +:class:`~pyprep.NoisyChannels`, which allows excluding BAD-annotated time +segments from channel quality assessment. BAD segments are any MNE annotations +with descriptions starting with "BAD" or "bad" (see +:ref:`mne:tut-reject-data-spans` for details). This is useful when recordings +contain breaks, movement artifacts, or other periods that shouldn't influence +channel rejection decisions. + +MATLAB PREP does not have this feature. In fact, MATLAB PREP explicitly warns +against using PREP on data with discontinuities (such as boundary markers from +paused/resumed recordings). 
However, the ``reject_by_annotation`` feature in +PyPREP is designed for a different use case: temporarily excluding known-bad +segments (e.g., participant movement during breaks) from *statistical analysis* +while preserving the original continuous data structure in the output. + +When ``reject_by_annotation`` is set to ``'omit'``, MNE's +:meth:`~mne.io.Raw.get_data` is used to concatenate non-BAD segments for +computing channel quality metrics. The final processed output retains the +original continuous structure with all time points intact. + +.. note:: + + This feature is intended for excluding a small number of longer segments + (e.g., recording breaks). Using it with many short BAD segments (e.g., from + automated muscle artifact detection via + :func:`mne.preprocessing.annotate_muscle_zscore`) may introduce edge effects + at concatenation boundaries, particularly for methods that apply filtering + to the concatenated data. PyPREP will emit a warning if many small BAD + segments are detected. + +This parameter has no equivalent in MATLAB PREP. When ``matlab_strict`` is set +to ``True``, ``reject_by_annotation`` is automatically set to ``None``. + + +PyPREP-Only Features +-------------------- + +The following features are available in PyPREP but are not present in the +original MATLAB PREP implementation. + + +Bad channel detection by PSD +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :meth:`~pyprep.NoisyChannels.find_bad_by_PSD` method detects channels with +abnormally high or low power spectral density (PSD) compared to other channels. +This method is not part of the original MATLAB PREP pipeline, but can be +considered a refinement of the ``bad_by_hfnoise`` detection in MATLAB PREP, +which flags channels based on the ratio of high-frequency power (>50 Hz) to +total power. 
+ +A channel is considered "bad-by-PSD" if its total PSD (computed using Welch's +method over a configurable frequency range, defaulting to 1-45 Hz to exclude +line noise) deviates considerably from the median channel PSD. The deviation +is calculated using robust Z-scoring based on the median absolute deviation +(MAD). + +This method is called by :meth:`~pyprep.NoisyChannels.find_all_bads` by default, +but is skipped when ``matlab_strict=True`` to maintain equivalence with the +original MATLAB PREP pipeline. + + References ---------- diff --git a/matprep_artifacts b/matprep_artifacts index c7e99e5..6c272f3 160000 --- a/matprep_artifacts +++ b/matprep_artifacts @@ -1 +1 @@ -Subproject commit c7e99e5329afc505ba071f856242f47373fe87a7 +Subproject commit 6c272f3da47eb7dc4ef029ef5a87cd3b1dc05157 diff --git a/pyprep/find_noisy_channels.py b/pyprep/find_noisy_channels.py index 92af43a..79253e0 100644 --- a/pyprep/find_noisy_channels.py +++ b/pyprep/find_noisy_channels.py @@ -57,6 +57,14 @@ class NoisyChannels: List of channels that are bad. These channels will be excluded when trying to find additional bad channels. Note that the union of these channels and those declared in ``raw.info["bads"]`` will be used. Defaults to ``None``. + reject_by_annotation : {None, 'omit'} | None + How to handle BAD-annotated time segments (annotations starting with + "BAD" or "bad") during channel quality assessment. If ``'omit'``, + annotated segments are excluded from analysis (clean segments are + concatenated). If ``None`` (default), annotations are ignored and the + full recording is used. This is useful when recordings contain breaks + or movement artifacts that shouldn't influence channel rejection + decisions. 
References ---------- @@ -76,6 +84,7 @@ def __init__( ransac=True, correlation=True, bad_by_manual=None, + reject_by_annotation=None, ): # Make sure that we got an MNE object assert isinstance(raw, mne.io.BaseRaw) @@ -100,12 +109,51 @@ def __init__( assert isinstance(correlation, bool), msg self.correlation = correlation + # Validate reject_by_annotation parameter + if reject_by_annotation is not None and reject_by_annotation != "omit": + raise ValueError( + f"reject_by_annotation must be None or 'omit', " + f"got: {reject_by_annotation}" + ) + # reject_by_annotation is not available in MATLAB PREP + if matlab_strict and reject_by_annotation is not None: + logger.warning( + "reject_by_annotation is not available in MATLAB PREP. " + f"Setting reject_by_annotation to None (was '{reject_by_annotation}')." + ) + reject_by_annotation = None + self.reject_by_annotation = reject_by_annotation + + # Warn if many small BAD segments are present (potential edge effects) + if reject_by_annotation is not None: + bad_annots = [ + a + for a in raw.annotations + if a["description"].startswith(("BAD", "bad")) + ] + n_bad_segments = len(bad_annots) + if n_bad_segments > 0: + total_bad_time = sum(a["duration"] for a in bad_annots) + recording_length = raw.times[-1] + bad_percentage = (total_bad_time / recording_length) * 100 + mean_duration = total_bad_time / n_bad_segments + if bad_percentage > 15 and mean_duration < 5.0: + logger.warning( + f"Found {n_bad_segments} BAD segments covering " + f"{bad_percentage:.1f}% of the recording with mean duration " + f"{mean_duration:.1f}s. Using reject_by_annotation with many " + "short segments may introduce edge effects from concatenation. " + "This feature is intended for excluding a small number of " + "longer segments (e.g., recording breaks)." 
+ ) + # Extra data for debugging self._extra_info = { "bad_by_deviation": {}, "bad_by_hf_noise": {}, "bad_by_correlation": {}, "bad_by_dropout": {}, + "bad_by_psd": {}, "bad_by_ransac": {}, } @@ -120,13 +168,14 @@ def __init__( self.bad_by_correlation = [] self.bad_by_SNR = [] self.bad_by_dropout = [] + self.bad_by_psd = [] self.bad_by_ransac = [] # Get original EEG channel names, channel count & samples ch_names = np.asarray(self.raw_mne.info["ch_names"]) self.ch_names_original = ch_names self.n_chans_original = len(ch_names) - self.n_samples = raw.get_data().shape[1] + self.n_samples_original = raw.get_data().shape[1] # Before anything else, flag bad-by-NaNs and bad-by-flats self.find_bad_by_nan_flat() @@ -137,7 +186,11 @@ def __init__( # Make a subset of the data containing only usable EEG channels self.usable_idx = np.isin(ch_names, bads_unusable, invert=True) - self.EEGData = self.raw_mne.get_data(picks=ch_names[self.usable_idx]) + self.EEGData = self.raw_mne.get_data( + picks=ch_names[self.usable_idx], + reject_by_annotation=self.reject_by_annotation, + ) + self.n_samples = self.EEGData.shape[1] self.EEGFiltered = None # Get usable EEG channel names & channel counts @@ -173,10 +226,10 @@ def get_bads(self, verbose=False, as_dict=False): Parameters ---------- - verbose : bool, optional + verbose : bool | None If ``True``, a summary of the channels currently flagged as by bad per category is printed. Defaults to ``False``. - as_dict: bool, optional + as_dict: bool | None If ``True``, this method will return a dict of the channels currently flagged as bad by each individual bad channel type. If ``False``, this method will return a list of all unique bad channels detected so far. 
@@ -197,6 +250,7 @@ def get_bads(self, verbose=False, as_dict=False): "bad_by_correlation": self.bad_by_correlation, "bad_by_SNR": self.bad_by_SNR, "bad_by_dropout": self.bad_by_dropout, + "bad_by_psd": self.bad_by_psd, "bad_by_ransac": self.bad_by_ransac, "bad_by_manual": self.bad_by_manual, } @@ -205,7 +259,12 @@ def get_bads(self, verbose=False, as_dict=False): for bad_chs in bads.values(): all_bads.update(bad_chs) - name_map = {"nan": "NaN", "hf_noise": "HF noise", "ransac": "RANSAC"} + name_map = { + "nan": "NaN", + "hf_noise": "HF noise", + "psd": "PSD", + "ransac": "RANSAC", + } if verbose: out = f"Found {len(all_bads)} uniquely bad channels:\n" for bad_type, bad_chs in bads.items(): @@ -223,7 +282,13 @@ def get_bads(self, verbose=False, as_dict=False): return bads def find_all_bads( - self, *, ransac=None, channel_wise=False, max_chunk_size=None, correlation=None + self, + *, + ransac=None, + channel_wise=False, + max_chunk_size=None, + correlation=None, + reject_by_annotation=None, ): """Call all the functions to detect bad channels. @@ -238,7 +303,7 @@ def find_all_bads( detection considerably. If ``None`` (default), then the value at instantiation of the ``NoisyChannels`` class is taken (defaults to ``True``), else the instantiation value is overwritten. - channel_wise : bool, optional + channel_wise : bool | None Whether RANSAC should predict signals for chunks of channels over the entire signal length ("channel-wise RANSAC", see `max_chunk_size` parameter). If ``False``, RANSAC will instead predict signals for all @@ -248,7 +313,7 @@ def find_all_bads( (especially if `max_chunk_size` is ``None``), but can be faster on systems with lots of RAM to spare. Has no effect if not using RANSAC. Defaults to ``False``. - max_chunk_size : {int, None}, optional + max_chunk_size : {int, None} | None The maximum number of channels to predict at once during channel-wise RANSAC. 
If ``None``, RANSAC will use the largest chunk size that will fit into the available RAM, which may slow down @@ -260,8 +325,17 @@ def find_all_bads( to the other methods. If ``None`` (default), then the value at instantiation of the ``NoisyChannels`` class is taken (defaults to ``True``), else the instantiation value is overwritten. + reject_by_annotation : {None, 'omit'} | None + This parameter is accepted for compatibility but is ignored here. + Annotation rejection is applied during ``NoisyChannels`` initialization, + not during ``find_all_bads``. To use annotation rejection, pass + ``reject_by_annotation`` to the ``NoisyChannels`` constructor. """ + # Note: reject_by_annotation is accepted but ignored here - it's applied + # during __init__ when data is extracted. This parameter exists only for + # compatibility with ransac_settings dict unpacking. + del reject_by_annotation # unused, applied in __init__ if ransac is not None and ransac != self.ransac: msg = f"ransac must be boolean, got: {ransac}" assert isinstance(ransac, bool), msg @@ -288,6 +362,8 @@ def find_all_bads( if self.correlation: self.find_bad_by_correlation() self.find_bad_by_SNR() + if not self.matlab_strict: + self.find_bad_by_PSD() if self.ransac: self.find_bad_by_ransac( channel_wise=channel_wise, max_chunk_size=max_chunk_size @@ -306,7 +382,7 @@ def find_bad_by_nan_flat(self, flat_threshold=1e-15): Parameters ---------- - flat_threshold : float, optional + flat_threshold : float | None The lowest standard deviation or MAD value for a channel to be considered bad-by-flat. Defaults to ``1e-15`` volts (corresponds to 10e-10 µV in MATLAB PREP). @@ -343,7 +419,7 @@ def find_bad_by_deviation(self, deviation_threshold=5.0): Parameters ---------- - deviation_threshold : float, optional + deviation_threshold : float | None The minimum absolute z-score of a channel for it to be considered bad-by-deviation. Defaults to ``5.0``. 
@@ -389,7 +465,7 @@ def find_bad_by_hfnoise(self, HF_zscore_threshold=5.0): Parameters ---------- - HF_zscore_threshold : float, optional + HF_zscore_threshold : float | None The minimum noisiness z-score of a channel for it to be considered bad-by-high-frequency-noise. Defaults to ``5.0``. @@ -454,12 +530,12 @@ def find_bad_by_correlation( Parameters ---------- - correlation_secs : float, optional + correlation_secs : float | None The length (in seconds) of each correlation window. Defaults to ``1.0``. - correlation_threshold : float, optional + correlation_threshold : float | None The lowest maximum inter-channel correlation for a channel to be considered "bad" within a given window. Defaults to ``0.4``. - frac_bad : float, optional + frac_bad : float | None The minimum proportion of bad windows for a channel to be considered "bad-by-correlation" or "bad-by-dropout". Defaults to ``0.01`` (1% of all windows). @@ -556,6 +632,150 @@ def find_bad_by_SNR(self): # Flag channels bad by both HF noise and low correlation as bad by low SNR self.bad_by_SNR = list(bad_by_corr.intersection(bad_by_hf)) + def find_bad_by_PSD(self, zscore_threshold=3.0, fmin=1.0, fmax=45.0): + """Detect channels with abnormally high or low power spectral density. + + This is a PyPREP-only method not present in the original MATLAB PREP. + + A channel is considered "bad-by-psd" if: + 1. Its power in any frequency band (low: 1-15 Hz, mid: 15-30 Hz, + high: 30-45 Hz) is abnormally HIGH compared to other channels, OR + 2. Its high-frequency band has more power than its low-frequency band + (violating the typical 1/f spectral profile of EEG). + + Note: Only excess power (positive z-scores) is flagged, as abnormally + low power could reflect normal topographic variation. + + PSD is computed using Welch's method over the specified frequency range. + The default range (1-45 Hz) excludes line noise frequencies (50/60 Hz). 
+ + Parameters + ---------- + zscore_threshold : float, optional + The minimum absolute z-score of a channel for it to be considered + bad-by-psd. Defaults to ``3.0``. + fmin : float, optional + The lower frequency bound (in Hz) for PSD computation. + Defaults to ``1.0``. + fmax : float, optional + The upper frequency bound (in Hz) for PSD computation. The default + of ``45.0`` excludes 50/60 Hz line noise from the analysis. + + """ + MAD_TO_SD = 1.4826 # Scales units of MAD to units of SD, assuming normality + # Reference: https://stat.ethz.ch/R-manual/R-devel/library/stats/html/mad.html + + # Define frequency bands (in Hz) + BAND_LOW = (fmin, 15.0) # ~ delta, theta, alpha + BAND_MID = (15.0, 30.0) # ~ beta + BAND_HIGH = (30.0, fmax) # ~ gamma + + if self.EEGFiltered is None: + self.EEGFiltered = self._get_filtered_data() + + # Create a temporary Raw object from filtered data for PSD computation + info = mne.create_info( + ch_names=self.ch_names_new.tolist(), + sfreq=self.sample_rate, + ch_types="eeg", + ) + raw_filtered = mne.io.RawArray(self.EEGFiltered, info, verbose=False) + + # Compute PSD using Welch method and convert to log scale (dB) + psd = raw_filtered.compute_psd( + method="welch", fmin=fmin, fmax=fmax, verbose=False + ) + psd_data = psd.get_data() + freqs = psd.freqs + log_psd = 10 * np.log10(psd_data) + + # Get frequency indices for each band + idx_low = (freqs >= BAND_LOW[0]) & (freqs < BAND_LOW[1]) + idx_mid = (freqs >= BAND_MID[0]) & (freqs < BAND_MID[1]) + idx_high = (freqs >= BAND_HIGH[0]) & (freqs <= BAND_HIGH[1]) + + # Compute band power (sum of log PSD within each band) for each channel + band_power_low = np.sum(log_psd[:, idx_low], axis=1) + band_power_mid = np.sum(log_psd[:, idx_mid], axis=1) + band_power_high = np.sum(log_psd[:, idx_high], axis=1) + + def robust_zscore(values): + """Compute robust z-scores using MAD.""" + median = np.median(values) + mad = np.median(np.abs(values - median)) + sd = mad * MAD_TO_SD + if sd > 0: + return 
(values - median) / sd + return np.zeros_like(values) + + # Criterion 1: Outlier with abnormally HIGH power in any band + # Note: Only positive z-scores (excess power) are flagged, as low power + # could reflect normal topographic variation rather than a bad channel + zscore_low = robust_zscore(band_power_low) + zscore_mid = robust_zscore(band_power_mid) + zscore_high = robust_zscore(band_power_high) + + bad_by_band = ( + (zscore_low > zscore_threshold) + | (zscore_mid > zscore_threshold) + | (zscore_high > zscore_threshold) + ) + + # Criterion 2: 1/f violation (high freq band has more power than low freq band) + # This is unusual for normal EEG and suggests muscle artifact or bad contact + bad_by_1f_violation = band_power_high > band_power_low + + # Criterion 3: Abnormal band ratios compared to other channels + # Use small epsilon to avoid division by zero + eps = np.finfo(float).eps + ratio_low_mid = band_power_low / (band_power_mid + eps) + ratio_low_high = band_power_low / (band_power_high + eps) + ratio_mid_high = band_power_mid / (band_power_high + eps) + + zscore_ratio_low_mid = robust_zscore(ratio_low_mid) + zscore_ratio_low_high = robust_zscore(ratio_low_high) + zscore_ratio_mid_high = robust_zscore(ratio_mid_high) + + bad_by_ratio = ( + (np.abs(zscore_ratio_low_mid) > zscore_threshold) + | (np.abs(zscore_ratio_low_high) > zscore_threshold) + | (np.abs(zscore_ratio_mid_high) > zscore_threshold) + ) + + # Combine criteria (bad if ANY criterion is met) + # Note: bad_by_ratio is computed for diagnostics but not used in final + # decision as it tends to be overly sensitive and theoretically debatable + bad_by_psd_usable = bad_by_band | bad_by_1f_violation + + # Map back to original channel indices + psd_channel_mask = np.zeros(self.n_chans_original, dtype=bool) + psd_channel_mask[self.usable_idx] = bad_by_psd_usable + abnormal_psd_channels = self.ch_names_original[psd_channel_mask] + + # Compute combined z-score for reporting (max absolute z-score across 
bands) + psd_zscore = np.zeros(self.n_chans_original) + max_band_zscore = np.maximum( + np.abs(zscore_low), np.maximum(np.abs(zscore_mid), np.abs(zscore_high)) + ) + psd_zscore[self.usable_idx] = max_band_zscore + + # Update names of bad channels by abnormal PSD & save additional info + self.bad_by_psd = abnormal_psd_channels.tolist() + self._extra_info["bad_by_psd"].update( + { + "psd_zscore": psd_zscore, + "band_power_low": band_power_low, + "band_power_mid": band_power_mid, + "band_power_high": band_power_high, + "zscore_low": zscore_low, + "zscore_mid": zscore_mid, + "zscore_high": zscore_high, + "bad_by_band": bad_by_band, + "bad_by_1f_violation": bad_by_1f_violation, + "bad_by_ratio": bad_by_ratio, + } + ) + def find_bad_by_ransac( self, n_samples=50, @@ -598,26 +818,26 @@ def find_bad_by_ransac( Parameters ---------- - n_samples : int, optional + n_samples : int | None Number of random channel samples to use for RANSAC. Defaults to ``50``. - sample_prop : float, optional + sample_prop : float | None Proportion of total channels to use for signal prediction per RANSAC sample. This needs to be in the range [0, 1], where 0 would mean no channels would be used and 1 would mean all channels would be used (neither of which would be useful values). Defaults to ``0.25`` (e.g., 16 channels per sample for a 64-channel dataset). - corr_thresh : float, optional + corr_thresh : float | None The minimum predicted vs. actual signal correlation for a channel to be considered good within a given RANSAC window. Defaults to ``0.75``. - frac_bad : float, optional + frac_bad : float | None The minimum fraction of bad (i.e., below-threshold) RANSAC windows for a channel to be considered bad-by-RANSAC. Defaults to ``0.4``. - corr_window_secs : float, optional + corr_window_secs : float | None The duration (in seconds) of each RANSAC correlation window. Defaults to 5 seconds. 
- channel_wise : bool, optional + channel_wise : bool | None Whether RANSAC should predict signals for chunks of channels over the entire signal length ("channel-wise RANSAC", see `max_chunk_size` parameter). If ``False``, RANSAC will instead predict signals for all @@ -626,7 +846,7 @@ def find_bad_by_ransac( RANSAC generally has higher RAM demands than window-wise RANSAC (especially if `max_chunk_size` is ``None``), but can be faster on systems with lots of RAM to spare. Defaults to ``False``. - max_chunk_size : {int, None}, optional + max_chunk_size : {int, None} | None The maximum number of channels to predict at once during channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk size that will fit into the available RAM, which may slow down diff --git a/pyprep/prep_pipeline.py b/pyprep/prep_pipeline.py index 73e29a1..bcf038a 100644 --- a/pyprep/prep_pipeline.py +++ b/pyprep/prep_pipeline.py @@ -39,15 +39,15 @@ class PrepPipeline: For example, for 60Hz you may specify ``np.arange(60, sfreq / 2, 60)``. Specify an empty list to skip the line noise removal step. - - max_iterations : int, optional + - max_iterations : int | None - The maximum number of iterations of noisy channel removal to perform during robust referencing. Defaults to ``4``. montage : mne.channels.DigMontage Digital montage of EEG data. - ransac : bool, optional + ransac : bool | None Whether or not to use RANSAC for noisy channel detection in addition to the other methods in :class:`~pyprep.NoisyChannels`. Defaults to True. - channel_wise : bool, optional + channel_wise : bool | None Whether RANSAC should predict signals for chunks of channels over the entire signal length ("channel-wise RANSAC", see `max_chunk_size` parameter). If ``False``, RANSAC will instead predict signals for all @@ -57,24 +57,32 @@ class PrepPipeline: (especially if `max_chunk_size` is ``None``), but can be faster on systems with lots of RAM to spare. Has no effect if not using RANSAC. Defaults to ``False``. 
- max_chunk_size : {int, None}, optional + max_chunk_size : {int, None} | None The maximum number of channels to predict at once during channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk size that will fit into the available RAM, which may slow down other programs on the host system. If using window-wise RANSAC (the default) or not using RANSAC at all, this parameter has no effect. Defaults to ``None``. - random_state : {int, None, np.random.RandomState}, optional + random_state : {int, None, np.random.RandomState} | None The random seed at which to initialize the class. If random_state is an int, it will be used as a seed for RandomState. If None, the seed will be obtained from the operating system (see RandomState for details). Default is None. - filter_kwargs : {dict, None}, optional + filter_kwargs : {dict, None} | None Optional keywords arguments to be passed on to mne.filter.notch_filter. Do not set the "x", Fs", and "freqs" arguments via the filter_kwargs parameter, but use the "raw" and "prep_params" parameters instead. If None is passed, the pyprep default settings for filtering are used instead. - matlab_strict : bool, optional + reject_by_annotation : {None, 'omit'} | None + How to handle BAD-annotated time segments (annotations starting with + "BAD" or "bad") during channel quality assessment. If ``'omit'``, + annotated segments are excluded from analysis (clean segments are + concatenated). If ``None`` (default), annotations are ignored and the + full recording is used. This is useful when recordings contain breaks + or movement artifacts that shouldn't influence channel rejection + decisions. + matlab_strict : bool | None Whether or not PyPREP should strictly follow MATLAB PREP's internal math, ignoring any improvements made in PyPREP over the original code (see :ref:`matlab-diffs` for more details). Defaults to False. 
@@ -128,6 +136,7 @@ def __init__( max_chunk_size=None, random_state=None, filter_kwargs=None, + reject_by_annotation=None, matlab_strict=False, ): """Initialize PREP class.""" @@ -167,6 +176,7 @@ def __init__( "ransac": ransac, "channel_wise": channel_wise, "max_chunk_size": max_chunk_size, + "reject_by_annotation": reject_by_annotation, } self.random_state = check_random_state(random_state) self.filter_kwargs = filter_kwargs diff --git a/pyprep/ransac.py b/pyprep/ransac.py index e16bc05..c077ebc 100644 --- a/pyprep/ransac.py +++ b/pyprep/ransac.py @@ -60,24 +60,24 @@ def find_bad_by_ransac( exclude : list Labels of channels to exclude as signal predictors during RANSAC (i.e., channels already flagged as bad by metrics other than HF noise). - n_samples : int, optional + n_samples : int | None Number of random channel samples to use for RANSAC. Defaults to ``50``. - sample_prop : float, optional + sample_prop : float | None Proportion of total channels to use for signal prediction per RANSAC sample. This needs to be in the range [0, 1], where 0 would mean no channels would be used and 1 would mean all channels would be used (neither of which would be useful values). Defaults to ``0.25`` (e.g., 16 channels per sample for a 64-channel dataset). - corr_thresh : float, optional + corr_thresh : float | None The minimum predicted vs. actual signal correlation for a channel to be considered good within a given RANSAC window. Defaults to ``0.75``. - frac_bad : float, optional + frac_bad : float | None The minimum fraction of bad (i.e., below-threshold) RANSAC windows for a channel to be considered bad-by-RANSAC. Defaults to ``0.4``. - corr_window_secs : float, optional + corr_window_secs : float | None The duration (in seconds) of each RANSAC correlation window. Defaults to 5 seconds. 
- channel_wise : bool, optional + channel_wise : bool | None Whether RANSAC should predict signals for chunks of channels over the entire signal length ("channel-wise RANSAC", see `max_chunk_size` parameter). If ``False``, RANSAC will instead predict signals for all @@ -86,18 +86,18 @@ def find_bad_by_ransac( RANSAC generally has higher RAM demands than window-wise RANSAC (especially if `max_chunk_size` is ``None``), but can be faster on systems with lots of RAM to spare. Defaults to ``False``. - max_chunk_size : {int, None}, optional + max_chunk_size : {int, None} | None The maximum number of channels to predict at once during channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk size that will fit into the available RAM, which may slow down other programs on the host system. If using window-wise RANSAC (the default), this parameter has no effect. Defaults to ``None``. - random_state : {int, None, np.random.RandomState}, optional + random_state : {int, None, np.random.RandomState} | None The random seed with which to generate random samples of channels during RANSAC. If random_state is an int, it will be used as a seed for RandomState. If ``None``, the seed will be obtained from the operating system (see RandomState for details). Defaults to ``None``. - matlab_strict : bool, optional + matlab_strict : bool | None Whether or not RANSAC should strictly follow MATLAB PREP's internal math, ignoring any improvements made in PyPREP over the original code (see :ref:`matlab-diffs` for more details). Defaults to ``False``. diff --git a/pyprep/reference.py b/pyprep/reference.py index 99e7af0..458eede 100644 --- a/pyprep/reference.py +++ b/pyprep/reference.py @@ -32,10 +32,10 @@ class Reference: Parameters of PREP which include at least the following keys: - ``ref_chs`` - ``reref_chs`` - ransac : bool, optional + ransac : bool | None Whether or not to use RANSAC for noisy channel detection in addition to the other methods in :class:`~pyprep.NoisyChannels`. 
Defaults to True. - channel_wise : bool, optional + channel_wise : bool | None Whether RANSAC should predict signals for chunks of channels over the entire signal length ("channel-wise RANSAC", see `max_chunk_size` parameter). If ``False``, RANSAC will instead predict signals for all @@ -45,18 +45,22 @@ class Reference: (especially if `max_chunk_size` is ``None``), but can be faster on systems with lots of RAM to spare. Has no effect if not using RANSAC. Defaults to ``False``. - max_chunk_size : {int, None}, optional + max_chunk_size : {int, None} | None The maximum number of channels to predict at once during channel-wise RANSAC. If ``None``, RANSAC will use the largest chunk size that will fit into the available RAM, which may slow down other programs on the host system. If using window-wise RANSAC (the default) or not using RANSAC at all, this parameter has no effect. Defaults to ``None``. - random_state : {int, None, np.random.RandomState}, optional + random_state : {int, None, np.random.RandomState} | None The random seed at which to initialize the class. If random_state is an int, it will be used as a seed for RandomState. If None, the seed will be obtained from the operating system (see RandomState for details). Default is None. - matlab_strict : bool, optional + reject_by_annotation : {None, 'omit'} | None + How to handle BAD-annotated time segments (annotations starting with + "BAD" or "bad") during channel quality assessment. If ``'omit'``, + annotated segments are excluded. Defaults to ``None`` (ignore). + matlab_strict : bool | None Whether or not PyPREP should strictly follow MATLAB PREP's internal math, ignoring any improvements made in PyPREP over the original code. Defaults to False. 
@@ -77,6 +81,7 @@ def __init__( channel_wise=False, max_chunk_size=None, random_state=None, + reject_by_annotation=None, matlab_strict=False, ): """Initialize the class.""" @@ -94,6 +99,7 @@ def __init__( "ransac": ransac, "channel_wise": channel_wise, "max_chunk_size": max_chunk_size, + "reject_by_annotation": reject_by_annotation, } self.random_state = check_random_state(random_state) self._extra_info = {} @@ -104,7 +110,7 @@ def perform_reference(self, max_iterations=4): Parameters ---------- - max_iterations : int, optional + max_iterations : int | None The maximum number of iterations of noisy channel removal to perform during robust referencing. Defaults to ``4``. @@ -142,7 +148,10 @@ def perform_reference(self, max_iterations=4): # Phase 2: Find the bad channels and interpolate self.raw._data = self.EEG noisy_detector = NoisyChannels( - self.raw, random_state=self.random_state, matlab_strict=self.matlab_strict + self.raw, + random_state=self.random_state, + matlab_strict=self.matlab_strict, + reject_by_annotation=self.ransac_settings.get("reject_by_annotation"), ) noisy_detector.find_all_bads(**self.ransac_settings) @@ -174,7 +183,10 @@ def perform_reference(self, max_iterations=4): # Still noisy channels after interpolation self.interpolated_channels = bad_channels noisy_detector = NoisyChannels( - self.raw, random_state=self.random_state, matlab_strict=self.matlab_strict + self.raw, + random_state=self.random_state, + matlab_strict=self.matlab_strict, + reject_by_annotation=self.ransac_settings.get("reject_by_annotation"), ) noisy_detector.find_all_bads(**self.ransac_settings) self.still_noisy_channels = noisy_detector.get_bads() @@ -192,7 +204,7 @@ def robust_reference(self, max_iterations=4): Parameters ---------- - max_iterations : int, optional + max_iterations : int | None The maximum number of iterations of noisy channel removal to perform during robust referencing. Defaults to ``4``. 
@@ -216,6 +228,7 @@ def robust_reference(self, max_iterations=4): do_detrend=False, random_state=self.random_state, matlab_strict=self.matlab_strict, + reject_by_annotation=self.ransac_settings.get("reject_by_annotation"), ) noisy_detector.find_all_bads(**self.ransac_settings) self.noisy_channels_original = noisy_detector.get_bads(as_dict=True) @@ -238,6 +251,7 @@ def robust_reference(self, max_iterations=4): "bad_by_correlation": [], "bad_by_SNR": [], "bad_by_dropout": [], + "bad_by_psd": [], "bad_by_ransac": [], "bad_by_manual": self.bads_manual, "bad_all": [], @@ -265,6 +279,7 @@ def robust_reference(self, max_iterations=4): do_detrend=False, random_state=self.random_state, matlab_strict=self.matlab_strict, + reject_by_annotation=self.ransac_settings.get("reject_by_annotation"), ) # Detrend applied at the beginning of the function. @@ -338,7 +353,7 @@ def remove_reference(signal, reference, index=None): The original EEG signal. reference : np.ndarray, shape(times,) The reference signal. - index : {list, None}, optional + index : {list, None} | None A list of channel indices from which the reference signal should be subtracted. Defaults to all channels in `signal`. diff --git a/pyprep/removeTrend.py b/pyprep/removeTrend.py index 0b9e345..cd6c12c 100644 --- a/pyprep/removeTrend.py +++ b/pyprep/removeTrend.py @@ -27,16 +27,16 @@ def removeTrend( A 2-D array of EEG data to detrend. sample_rate : float The sample rate (in Hz) of the input EEG data. - detrendType : str, optional + detrendType : str | None Type of detrending to be performed: must be one of 'high pass', 'high pass sinc, or 'local detrend'. Defaults to 'high pass'. - detrendCutoff : float, optional + detrendCutoff : float | None The high-pass cutoff frequency (in Hz) to use for detrending. Defaults to 1.0 Hz. - detrendChannels : {list, None}, optional + detrendChannels : {list, None} | None List of the indices of all channels that require detrending/filtering. 
         If ``None``, all channels are used (default).
-    matlab_strict : bool, optional
+    matlab_strict : bool
         Whether or not detrending should strictly follow MATLAB PREP's internal
         math, ignoring any improvements made in PyPREP over the original code
         (see :ref:`matlab-diffs` for more details). Defaults to ``False``.
diff --git a/pyprep/utils.py b/pyprep/utils.py
index 9ec3390..c10d536 100644
--- a/pyprep/utils.py
+++ b/pyprep/utils.py
@@ -56,7 +56,7 @@ def _mat_quantile(arr, q, axis=None):
     q : float
         The quantile to calculate for the input data. Must be between 0 and 1,
         inclusive.
-    axis : {int, tuple of int, None}, optional
+    axis : int | tuple of int | None
         Axis along which quantile values should be calculated. Defaults to
         calculating the value at the given quantile for the entire array.
@@ -130,7 +130,7 @@ def _mat_iqr(arr, axis=None):
     ----------
     arr : np.ndarray
         Input array containing samples from the distribution to summarize.
-    axis : {int, tuple of int, None}, optional
+    axis : int | tuple of int | None
         Axis along which IQRs should be calculated. Defaults to calculating
         the IQR for the entire array.
@@ -435,7 +435,7 @@ def _correlate_arrays(a, b, matlab_strict=False):
         A 2-D array to correlate with `a`.
     b : np.ndarray
         A 2-D array to correlate with `b`.
-    matlab_strict : bool, optional
+    matlab_strict : bool
         Whether or not correlations should be calculated identically to MATLAB
         PREP (i.e., without mean subtraction) instead of by traditional
         Pearson product-moment correlation (see Notes for details).
Defaults to
diff --git a/tests/test_find_noisy_channels.py b/tests/test_find_noisy_channels.py
index ee18675..ca225eb 100644
--- a/tests/test_find_noisy_channels.py
+++ b/tests/test_find_noisy_channels.py
@@ -217,6 +217,84 @@ def test_bad_by_SNR(raw_tmp):
     assert nd.bad_by_SNR == [raw_tmp.ch_names[low_snr_idx]]
 
 
+def test_bad_by_PSD(raw_tmp):
+    """Test detection of channels with abnormal power spectral density."""
+    # set scaling factors for high and low PSD test channels
+    low_psd_factor = 0.05
+    high_psd_factor = 20.0
+
+    # make the signal for a random channel have very high power (high PSD)
+    n_chans = raw_tmp.get_data().shape[0]
+    high_psd_idx = int(rng.integers(0, n_chans, 1)[0])
+    raw_tmp._data[high_psd_idx, :] *= high_psd_factor
+
+    # test detection of abnormally high-PSD channels
+    nd = NoisyChannels(raw_tmp, do_detrend=False)
+    nd.find_bad_by_PSD()
+    assert raw_tmp.ch_names[high_psd_idx] in nd.bad_by_psd
+
+    # verify that extra_info is populated correctly with band-based metrics
+    extra = nd._extra_info["bad_by_psd"]
+    assert "psd_zscore" in extra
+    assert len(extra["psd_zscore"]) == n_chans
+    # Check band power arrays
+    assert "band_power_low" in extra
+    assert "band_power_mid" in extra
+    assert "band_power_high" in extra
+    # Check per-band z-scores
+    assert "zscore_low" in extra
+    assert "zscore_mid" in extra
+    assert "zscore_high" in extra
+    # Check detection criteria flags
+    assert "bad_by_band" in extra
+    assert "bad_by_1f_violation" in extra
+    assert "bad_by_ratio" in extra
+
+    # make the signal for a different channel have very low power (low PSD)
+    low_psd_idx = (high_psd_idx - 1) if high_psd_idx > 0 else 1
+    raw_tmp._data[low_psd_idx, :] *= low_psd_factor
+
+    # test detection of both abnormally high and low PSD channels
+    nd = NoisyChannels(raw_tmp, do_detrend=False)
+    nd.find_bad_by_PSD()
+    assert raw_tmp.ch_names[high_psd_idx] in nd.bad_by_psd
+    assert (
+        raw_tmp.ch_names[low_psd_idx] not in nd.bad_by_psd
+    )  # the low PSD criterion was omitted
+ + # verify that bad_by_psd is included in get_bads() output + all_bads = nd.get_bads(as_dict=True) + assert "bad_by_psd" in all_bads + assert raw_tmp.ch_names[high_psd_idx] in all_bads["bad_all"] + + +def test_bad_by_PSD_1f_violation(raw_tmp): + """Test detection of channels violating the 1/f spectral profile.""" + n_chans = raw_tmp.get_data().shape[0] + bad_idx = int(rng.integers(0, n_chans, 1)[0]) + + # Replace channel with high-frequency dominated signal (violates 1/f) + # Normal EEG has more power in low frequencies than high frequencies + # This channel will have more power in 30-45 Hz than in 1-15 Hz + high_freq_signal = _generate_signal(32, 44, raw_tmp.times, fcount=10) + raw_tmp._data[bad_idx, :] = high_freq_signal * 50 # Strong high-freq signal + + nd = NoisyChannels(raw_tmp, do_detrend=False) + nd.find_bad_by_PSD() + + # Channel should be flagged due to 1/f violation + assert raw_tmp.ch_names[bad_idx] in nd.bad_by_psd + + # Verify the 1/f violation was detected + extra = nd._extra_info["bad_by_psd"] + # Find the index in usable channels (convert boolean mask to int indices) + usable_int_idx = np.where(nd.usable_idx)[0] + usable_names = [raw_tmp.ch_names[i] for i in usable_int_idx] + if raw_tmp.ch_names[bad_idx] in usable_names: + usable_pos = usable_names.index(raw_tmp.ch_names[bad_idx]) + assert extra["bad_by_1f_violation"][usable_pos] + + def test_find_bad_by_ransac(raw_tmp): """Test the RANSAC component of NoisyChannels.""" # Set a consistent random seed for all RANSAC runs @@ -304,3 +382,67 @@ def test_find_bad_by_ransac_err(raw_tmp): nd = NoisyChannels(raw_tmp, do_detrend=False) with pytest.raises(IOError): nd.find_bad_by_ransac() + + +# Tests for reject_by_annotation functionality + + +def test_reject_by_annotation_omit(raw_tmp): + """Test that 'omit' mode excludes annotated segments.""" + # Add a BAD annotation covering 10% of the recording + duration = raw_tmp.times[-1] + raw_tmp.annotations.append( + onset=duration * 0.4, + duration=duration * 
0.1, + description="BAD_test", + ) + + original_samples = raw_tmp.get_data().shape[1] + + # With 'omit', sample count should be reduced + nd = NoisyChannels(raw_tmp, do_detrend=False, reject_by_annotation="omit") + assert nd.n_samples_original == original_samples + assert nd.n_samples < original_samples + + # Without rejection, sample count should be unchanged + nd_no_reject = NoisyChannels(raw_tmp, do_detrend=False, reject_by_annotation=None) + assert nd_no_reject.n_samples == original_samples + + +def test_reject_by_annotation_invalid(raw_tmp): + """Test that invalid reject_by_annotation values raise ValueError.""" + with pytest.raises(ValueError, match="reject_by_annotation must be"): + NoisyChannels(raw_tmp, do_detrend=False, reject_by_annotation="invalid") + + with pytest.raises(ValueError, match="reject_by_annotation must be"): + NoisyChannels(raw_tmp, do_detrend=False, reject_by_annotation="skip") + + +def test_reject_by_annotation_data_extraction(raw_tmp): + """Test that reject_by_annotation correctly filters data for analysis.""" + # Add a BAD annotation for the first 30% of the recording + duration = raw_tmp.times[-1] + bad_duration = duration * 0.3 + + raw_tmp.annotations.append( + onset=0.0, + duration=bad_duration, + description="BAD_movement", + ) + + # Create two NoisyChannels instances - one with and one without annotation rejection + nd_no_reject = NoisyChannels(raw_tmp.copy(), do_detrend=False) + nd_with_reject = NoisyChannels( + raw_tmp.copy(), do_detrend=False, reject_by_annotation="omit" + ) + + # With annotation rejection, the data should have fewer samples + assert nd_with_reject.n_samples < nd_no_reject.n_samples + + # The reduction should be approximately 30% (the annotated portion) + expected_reduction = 0.3 + actual_reduction = 1 - (nd_with_reject.n_samples / nd_no_reject.n_samples) + assert abs(actual_reduction - expected_reduction) < 0.05 # 5% tolerance + + # Both should have the same original sample count stored + assert 
nd_no_reject.n_samples_original == nd_with_reject.n_samples_original