From 42e900bd46e85fb42c43c3a93e142a4cfdca25f0 Mon Sep 17 00:00:00 2001 From: anascacais Date: Fri, 28 Mar 2025 13:56:17 +0100 Subject: [PATCH 1/2] enlarge plots --- sefef/evaluation.py | 119 ++++++++++++++++++++++++----------------- sefef/scoring.py | 17 +++--- sefef/visualization.py | 27 +++++++--- 3 files changed, 99 insertions(+), 64 deletions(-) diff --git a/sefef/evaluation.py b/sefef/evaluation.py index e6389eb..9efcd4c 100644 --- a/sefef/evaluation.py +++ b/sefef/evaluation.py @@ -113,7 +113,7 @@ def split(self, dataset, iteratively=False, plot=False, extend_final_test_set=Fa dataset_lead_sz = self._get_lead_sz_dataset(dataset) if self.initial_train_duration is None: - total_recorded_duration = dataset_lead_sz.files_metadata['total_duration'].sum() + total_recorded_duration = dataset_lead_sz.metadata['duration'].sum() if total_recorded_duration == 0: raise ValueError(f"Dataset is empty.") self.initial_train_duration = (1/3) * total_recorded_duration @@ -122,7 +122,7 @@ def split(self, dataset, iteratively=False, plot=False, extend_final_test_set=Fa self.test_duration = (1/2) * self.initial_train_duration # Check basic conditions - if dataset_lead_sz.files_metadata['total_duration'].sum() < self.initial_train_duration + self.test_duration: + if dataset_lead_sz.metadata['duration'].sum() < self.initial_train_duration + self.test_duration: raise ValueError( f"Dataset does not contain enough data to do this split. 
Just give up (or decrease 'initial_train_duration' ({self.initial_train_duration}) and/or 'test_duration' ({self.test_duration})).") @@ -167,7 +167,7 @@ def _expanding_window_split(self, dataset, initial_cutoff_ts): test_start_ts = test_end_ts try: - test_end_ts = after_train_set.index[after_train_set['total_duration'].cumsum() >= self.test_duration].tolist()[ + test_end_ts = after_train_set.index[after_train_set['duration'].cumsum() >= self.test_duration].tolist()[ 0] test_end_ts = self._check_criteria_split(after_train_set, test_end_ts) split_ind_ts += [[train_start_ts, test_start_ts, test_end_ts]] @@ -185,7 +185,7 @@ def _sliding_window_split(self): def _get_cutoff_ts(self, dataset): """Internal method for getting the first iteration of the cutoff timestamp based on 'self.initial_train_duration'.""" - cutoff_ts = dataset.metadata.index[dataset.metadata['total_duration'].cumsum() > self.initial_train_duration].tolist()[ + cutoff_ts = dataset.metadata.index[dataset.metadata['duration'].cumsum() > self.initial_train_duration].tolist()[ 0] return cutoff_ts @@ -221,10 +221,10 @@ def _check_criteria_initial_split(self, dataset, initial_cutoff_ts): # Criteria 1: min number of events in train criteria_check[0] = ((initial_train_set['sz_onset'].sum() >= self.n_min_events_train) & - (self._check_if_preictal(initial_train_set) >= self.n_min_events_train)) + (self._get_preictal_counts(initial_train_set) >= self.n_min_events_train)) # Criteria 2: min number of events in test criteria_check[1] = ((after_train_set['sz_onset'].sum() >= self.n_min_events_test) & - (self._check_if_preictal(after_train_set) >= self.n_min_events_test)) + (self._get_preictal_counts(after_train_set) >= self.n_min_events_test)) if not all(criteria_check): print( @@ -241,7 +241,7 @@ def _check_criteria_initial_split(self, dataset, initial_cutoff_ts): t += 1 # Check if there's enough data left for at least one test set - if after_train_set['total_duration'].sum() < self.test_duration: + if 
after_train_set['duration'].sum() < self.test_duration: raise ValueError( f"Dataset does not comply with the conditions for this split. Just give up (or decrease 'n_min_events_train' ({self.n_min_events_train}), 'initial_train_duration' ({self.initial_train_duration}), and/or 'test_duration' ({self.test_duration})).") @@ -252,12 +252,15 @@ def _check_criteria_initial_split(self, dataset, initial_cutoff_ts): return dataset.metadata.iloc[initial_cutoff_ind].name - def _check_if_preictal(self, metadata): + def _get_preictal_counts(self, metadata): + nb_preictal_samples = self._check_if_preictal(metadata, metadata[metadata['sz_onset'] == 1].index.to_numpy()) + return np.count_nonzero(nb_preictal_samples) + + def _check_if_preictal(self, metadata, sz_onsets): '''Internal method that counts the number of seizure onsets for which there exist preictal samples.''' - preictal_starts = metadata[metadata['sz_onset'] == 1].index.to_numpy( - ) - self.preictal_duration - self.prediction_latency - preictal_ends = metadata[metadata['sz_onset'] == 1].index.to_numpy() - self.prediction_latency + preictal_starts = sz_onsets - self.preictal_duration - self.prediction_latency + preictal_ends = sz_onsets - self.prediction_latency # For each seizure onset, count number of samples within preictal period nb_preictal_samples = np.sum(np.logical_and( @@ -265,7 +268,7 @@ def _check_if_preictal(self, metadata): metadata.index.to_numpy()[:, np.newaxis] < preictal_ends[np.newaxis, :], ), axis=0) - return np.count_nonzero(nb_preictal_samples) + return nb_preictal_samples def _check_criteria_split(self, metadata, cutoff_ts): """Internal method for iterating the cutoff timestamp for n>1 folds in order to respect the condition on the minimum number of seizures in test.""" @@ -281,7 +284,7 @@ def _check_criteria_split(self, metadata, cutoff_ts): criteria_check[0] = cutoff_ind <= len(metadata) # Criteria 2: min number of events in test criteria_check[1] = ((test_set['sz_onset'].sum() >= 
self.n_min_events_test) & - (self._check_if_preictal(test_set) >= self.n_min_events_test)) + (self._get_preictal_counts(test_set) >= self.n_min_events_test)) if not all(criteria_check): print( @@ -323,13 +326,21 @@ def plot(self, dataset, folder_path=None, filename=None, mode='lines'): fig = go.Figure() - file_duration = dataset.metadata['total_duration'].iloc[0] + file_duration = dataset.metadata['duration'].iloc[0] for ifold in range(self.n_folds): train_set = dataset.metadata.loc[self.split_ind_ts[ifold, 0]: self.split_ind_ts[ifold, 1]] test_set = dataset.metadata.loc[self.split_ind_ts[ifold, 1]: self.split_ind_ts[ifold, 2]] + ts_lead_sz = self._get_lead_seizures(train_set[train_set['sz_onset'] == 1].index.to_numpy()) + ts_lead_sz = ts_lead_sz[np.nonzero(self._check_if_preictal(train_set, ts_lead_sz))] + ts_lead_sz = pd.to_datetime(ts_lead_sz, unit='s').to_numpy() + + ts_preictal_sz_test = test_set[test_set['sz_onset'] == 1].index.to_numpy()[np.nonzero(self._check_if_preictal( + test_set, test_set[test_set['sz_onset'] == 1].index.to_numpy()))] + ts_preictal_sz_test = pd.to_datetime(ts_preictal_sz_test, unit='s').to_numpy() + # handle missing data between files train_set = self._handle_missing_data(train_set, ifold+1, file_duration) test_set = self._handle_missing_data(test_set, ifold+1, file_duration) @@ -341,22 +352,28 @@ def plot(self, dataset, folder_path=None, filename=None, mode='lines'): test_set, color=COLOR_PALETTE[1], mode=mode, name='Test', showlegend=(ifold == 0))) # add seizures - ts_lead_sz = self._get_lead_seizures(train_set[train_set['sz_onset'] == 1].index.to_numpy()) fig.add_trace(self._get_scatter_plot_sz( train_set.loc[ts_lead_sz], color=COLOR_PALETTE[0] )) ts_non_lead_sz = train_set[train_set['sz_onset'] == 1].index.to_numpy()[~np.any( train_set[train_set['sz_onset'] == 1].index.to_numpy()[:, np.newaxis] == ts_lead_sz[np.newaxis, :], axis=1)] + ts_no_preictal_sz_test = test_set[test_set['sz_onset'] == 1].index.to_numpy()[~np.any( + 
test_set[test_set['sz_onset'] == 1].index.to_numpy()[:, np.newaxis] == ts_lead_sz[np.newaxis, :], axis=1)] fig.add_trace(self._get_scatter_plot_sz( train_set.loc[ts_non_lead_sz], color=COLOR_PALETTE[0], opacity=0.5 )) fig.add_trace(self._get_scatter_plot_sz( - test_set[test_set['sz_onset'] == 1], + test_set.loc[ts_preictal_sz_test], color=COLOR_PALETTE[1] )) + fig.add_trace(self._get_scatter_plot_sz( + test_set.loc[ts_no_preictal_sz_test], + color=COLOR_PALETTE[1], + opacity=0.5 + )) # fig.add_trace(self._get_dummy_scatter(color=COLOR_PALETTE[2], showlegend=(ifold == 0))) @@ -367,13 +384,19 @@ def plot(self, dataset, folder_path=None, filename=None, mode='lines'): autorange='reversed', tickvals=list(range(1, self.n_folds+1)), ticktext=[f'Fold {i} ' for i in range(1, self.n_folds+1)], - tickfont=dict(size=12), + tickfont=dict(size=12), # comment ) - fig.update_xaxes(title='Time') + fig.update_xaxes(title='Time', + tickfont=dict(size=12), # comment + ) fig.update_layout( title='Time Series Cross Validation', # showlegend=False, - plot_bgcolor='white') + plot_bgcolor='white', + # font=dict(size=24), # uncomment + # width=1063, + # height=int(1063 / (4/3)) + ) fig.show() if folder_path is not None: @@ -402,7 +425,7 @@ def _get_scatter_plot_sz(self, metadata, color, opacity=1, showlegend=False): color='rgba' + str(hex_to_rgba( h=color, alpha=opacity )), - size=12, + size=12, # 18 symbol='star', ), ) @@ -561,60 +584,58 @@ class Dataset: Attributes ---------- - files_metadata: pd.DataFrame - Input DataFrame with the following columns: - - 'filepath' (str): Path to each file containing data. - - 'first_timestamp' (int64): The Unix-time timestamp (in seconds) of the first sample of each file. - - 'total_duration' (int64): Total duration of file in seconds (equivalent to #samples * sampling_frequency) - It is expected that data within each file is non-overlapping in time and that there are no time gaps between samples in the file. 
+ timestamps : array-like, shape (#samples,) + The Unix-time timestamp (in seconds) of the start timestamp of each sample. + samples_duration : array-like, shape (#samples,) + Duration of samples in seconds. sz_onsets: np.array Contains the Unix-time timestamps (in seconds) corresponding to the onsets of seizures. sampling_frequency: int Frequency at which the data is stored in each file. ''' - def __init__(self, files_metadata, sz_onsets): - self.files_metadata = files_metadata.astype({'first_timestamp': 'int64', 'total_duration': 'int64'}) + def __init__(self, timestamps, samples_duration, sz_onsets): + timestamps = np.array(timestamps, dtype='int64') + samples_duration = np.array(samples_duration, dtype='int64') self.sz_onsets = np.array(sz_onsets) - self.metadata = self._get_metadata() - self.metadata = self.metadata.astype({'filepath': str}) + self.metadata = self._get_metadata(timestamps, samples_duration) - def _get_metadata(self): + def _get_metadata(self, timestamps, samples_duration): """Internal method that updates 'self.metadata' by placing each seizure onset within an acquisition file.""" - timestamps_file_start = self.files_metadata['first_timestamp'].to_numpy() - timestamps_file_end = (self.files_metadata['first_timestamp'] + - self.files_metadata['total_duration']).to_numpy() + timestamps_file_start = timestamps.copy() + timestamps_file_end = timestamps_file_start + samples_duration # identify seizures within existant files sz_onset_indx = np.argwhere((self.sz_onsets[:, np.newaxis] >= timestamps_file_start[np.newaxis, :]) & ( self.sz_onsets[:, np.newaxis] < timestamps_file_end[np.newaxis, :])) - files_metadata = self.files_metadata.copy() - files_metadata['sz_onset'] = 0 - files_metadata.loc[sz_onset_indx[:, 1], 'sz_onset'] = 1 + metadata = pd.DataFrame({'timestamp': timestamps_file_start, 'duration': samples_duration, 'sz_onset': 0}) + metadata.loc[sz_onset_indx[:, 1], 'sz_onset'] = 1 # identify seizures outside of existant files sz_onset_indx = 
np.argwhere(~np.any(((self.sz_onsets[:, np.newaxis] >= timestamps_file_start[np.newaxis, :]) & ( self.sz_onsets[:, np.newaxis] < timestamps_file_end[np.newaxis, :])), axis=1)).flatten() if len(sz_onset_indx) != 0: - sz_onsets = pd.DataFrame({'first_timestamp': self.sz_onsets[sz_onset_indx], 'sz_onset': [ + sz_onsets = pd.DataFrame({'timestamp': self.sz_onsets[sz_onset_indx], 'sz_onset': [ 1]*len(sz_onset_indx)}, dtype='int64') - files_metadata = pd.merge(files_metadata.reset_index(), sz_onsets.reset_index(), - on='first_timestamp', how='outer', suffixes=('_df1', '_df2')) - files_metadata['sz_onset'] = files_metadata['sz_onset_df1'].combine_first( - files_metadata['sz_onset_df2']).fillna(0).astype('int64') - files_metadata['total_duration'] = files_metadata['total_duration'].fillna(0).astype('int64') + metadata = pd.merge(metadata.reset_index(), sz_onsets.reset_index(), + on='timestamp', how='outer', suffixes=('_df1', '_df2')) + metadata['sz_onset'] = metadata['sz_onset_df1'].combine_first( + metadata['sz_onset_df2']).fillna(0).astype('int64') + metadata['duration'] = metadata['duration'].fillna( + 0).astype('int64') - files_metadata.set_index(pd.Index(files_metadata['first_timestamp'].to_numpy(), dtype='int64'), inplace=True) - files_metadata = files_metadata.loc[:, ['filepath', 'total_duration', 'sz_onset']] + metadata.set_index(pd.Index( + metadata['timestamp'].to_numpy(), dtype='int64'), inplace=True) + metadata = metadata.loc[:, ['duration', 'sz_onset']] try: - files_metadata = pd.concat(( - files_metadata, pd.DataFrame([[np.nan, 0, 0]], columns=files_metadata.columns, index=pd.Series( - [files_metadata.iloc[-1].name+files_metadata.iloc[-1]['total_duration']], dtype='int64')), + metadata = pd.concat(( + metadata, pd.DataFrame([[0, 0]], columns=metadata.columns, index=pd.Series( + [metadata.iloc[-1].name+metadata.iloc[-1]['duration']], dtype='int64')), ), ignore_index=False) # add empty row at the end for indexing except IndexError: pass - return files_metadata 
+ return metadata diff --git a/sefef/scoring.py b/sefef/scoring.py index 9c1df93..db9788f 100644 --- a/sefef/scoring.py +++ b/sefef/scoring.py @@ -129,7 +129,7 @@ def _get_counts(self, forecasts, timestamps_start_forecast, threshold): (self.sz_onsets[:, np.newaxis] >= timestamps_start_forecast[np.newaxis, :]) & (self.sz_onsets[:, np.newaxis] <= timestamps_end_forecast[np.newaxis, :]) & (forecasts >= threshold), - axis=1) + axis=0) no_sz_forecasts = forecasts[~np.any( (self.sz_onsets[:, np.newaxis] >= timestamps_start_forecast[np.newaxis, :]) @@ -177,7 +177,7 @@ def _get_bins_indx(self, forecasts, binning_method, num_bins): num_bins = np.ceil(len(forecasts)**(1/3)).astype('int64') if binning_method == 'uniform': - bin_edges = np.linspace(0, 1, num_bins + 1) + bin_edges = np.linspace(min(forecasts), max(forecasts), num_bins + 1) elif binning_method == 'quantile': percentile = np.linspace(0, 100, num_bins + 1) bin_edges = np.percentile(np.sort(forecasts), percentile)[1:] # remove edge corresponding to 0th percentile @@ -308,7 +308,7 @@ def reliability_diagram(self, forecasts, timestamps, binning_method, num_bins): x=diagram_data.loc[:, 'forecasted_proba'], y=diagram_data.loc[:, 'observed_proba'], mode='lines', - line=dict(width=3, color=COLOR_PALETTE[1]), + line=dict(width=3, color=COLOR_PALETTE[1]), # width 5 name='Reliability curve' )) @@ -323,7 +323,7 @@ def reliability_diagram(self, forecasts, timestamps, binning_method, num_bins): fig.add_trace(go.Scatter( x=[0, 1], y=[0, 1], - line=dict(width=3, color=COLOR_PALETTE[0], dash='dash'), + line=dict(width=3, color=COLOR_PALETTE[0], dash='dash'), # width 5 # showlegend=False, mode='lines', name='Perfect reliability' @@ -332,7 +332,7 @@ def reliability_diagram(self, forecasts, timestamps, binning_method, num_bins): fig.add_trace(go.Scatter( x=[0, 1], y=[y_avg, y_avg], - line=dict(width=3, color='lightgrey', dash='dash'), + line=dict(width=3, color='lightgrey', dash='dash'), # width 5 mode='lines', name='No resolution' 
)) @@ -340,14 +340,14 @@ def reliability_diagram(self, forecasts, timestamps, binning_method, num_bins): # Config plot layout fig.update_yaxes( title='observed probability', - tickfont=dict(size=12), + tickfont=dict(size=12), # comment showline=True, linewidth=2, linecolor=COLOR_PALETTE[2], showgrid=False, range=[diagram_data.min().min(), diagram_data.max().max()] ) fig.update_xaxes( title='forecasted probability', - tickfont=dict(size=12), + tickfont=dict(size=12), # comment showline=True, linewidth=2, linecolor=COLOR_PALETTE[2], showgrid=False, range=[diagram_data.min().min(), diagram_data.max().max()], @@ -356,5 +356,8 @@ def reliability_diagram(self, forecasts, timestamps, binning_method, num_bins): title=f'Reliability diagram (binning method: {binning_method})', showlegend=True, plot_bgcolor='white', + # font=dict(size=24), # uncomment + # width=1063, + # height=int(1063 / (4/3)), ) fig.show() diff --git a/sefef/visualization.py b/sefef/visualization.py index 89744c5..28e2a00 100644 --- a/sefef/visualization.py +++ b/sefef/visualization.py @@ -103,7 +103,7 @@ def plot_forecasts(forecasts, ts, sz_onsets, high_likelihood_thr, forecast_horiz x=pd.to_datetime(df['ts'], unit='s'), y=df['forecasts'], mode='lines', line_color='white', - line_width=3, + line_width=3, # 5 name='Forecast' )) @@ -112,7 +112,7 @@ def plot_forecasts(forecasts, ts, sz_onsets, high_likelihood_thr, forecast_horiz mode='markers', marker=dict( color=COLOR_PALETTE[1], - size=12, + size=12, # 18 symbol='star', ), name='Seizure' @@ -125,7 +125,7 @@ def plot_forecasts(forecasts, ts, sz_onsets, high_likelihood_thr, forecast_horiz fig.update_yaxes( title='Probability', showgrid=False, - tickfont=dict(size=12), + tickfont=dict(size=12), # comment range=[np.max([0, np.min(non_nan_forecasts) - np.std(non_nan_forecasts)]), np.min([1, np.max(non_nan_forecasts) + np.std(non_nan_forecasts)])], ) @@ -134,6 +134,9 @@ def plot_forecasts(forecasts, ts, sz_onsets, high_likelihood_thr, forecast_horiz 
showgrid=False, ) fig.update_layout( + # font=dict(size=24), # uncomment + # width=1063, + # height=1000, title=title, showlegend=True, plot_bgcolor='white', @@ -173,7 +176,7 @@ def aggregate_plots(figs, folder_path=None, filename=None, show=True,): Figures to aggregate into a single plot. ''' fig = make_subplots(rows=1, cols=len( - figs), shared_yaxes=True, horizontal_spacing=0.005) + figs), shared_yaxes=True, horizontal_spacing=0.05) # 0.005 forecasts = [] for ifig, figure in enumerate(figs): print(f'Aggregating forecast plots ({ifig+1}/{len(figs)})', end='\r') @@ -195,6 +198,10 @@ def aggregate_plots(figs, folder_path=None, filename=None, show=True,): fig.add_annotation(annotation) fig.update_xaxes( range=[x0, x1], row=1, col=(ifig+1)) + # Apply intercalated tick angle (0° for even-indexed subplots, 45° for odd) + # for i in range(1, len(figs) + 1): + # angle = 0 if i % 2 == 0 else -90 # Alternating angles + # fig.update_xaxes(tickangle=angle, row=(i-1)//2 + 1, col=(i-1) % 2 + 1) fig.update_yaxes( showticklabels=(ifig == 0), row=1, col=(ifig+1)) @@ -203,23 +210,27 @@ def aggregate_plots(figs, folder_path=None, filename=None, show=True,): # Update layout forecasts = np.concat(forecasts) fig.update_layout( + # font=dict(size=24), # uncomment + # width=1063, + # height=int(1063 / (4/3)), title_text=figure['layout']['title']['text'], showlegend=True, plot_bgcolor='white', legend_bgcolor='whitesmoke', yaxis_title='Probability', annotations=[ dict( text="Time", - x=0.5, y=-0.15, + x=0.5, y=-0.11, xref="paper", yref="paper", showarrow=False, - font=dict(size=14) + font=dict(size=24) # dict(size=14) )],) fig.update_xaxes(showgrid=False,) + non_nan_forecasts = forecasts[~np.isnan(forecasts)] fig.update_yaxes( showgrid=False, - range=[np.max([0, np.min(forecasts) - np.std(forecasts)]), - np.min([1, np.max(forecasts) + np.std(forecasts)])]) + range=[np.max([0, np.min(non_nan_forecasts) - np.std(non_nan_forecasts)]), + np.min([1, np.max(non_nan_forecasts) + 
np.std(non_nan_forecasts)])]) if folder_path is not None: if not os.path.exists(folder_path): From 112027c69a88dc871a4cb13afe09d5694c321ce7 Mon Sep 17 00:00:00 2001 From: anascacais Date: Fri, 28 Mar 2025 14:04:41 +0100 Subject: [PATCH 2/2] simplify instantiation of Dataset --- sefef/evaluation.py | 10 ++-------- sefef/scoring.py | 14 +++++--------- sefef/visualization.py | 44 ++++-------------------------------------- setup.py | 2 +- 4 files changed, 12 insertions(+), 58 deletions(-) diff --git a/sefef/evaluation.py b/sefef/evaluation.py index 9efcd4c..ed0159c 100644 --- a/sefef/evaluation.py +++ b/sefef/evaluation.py @@ -375,8 +375,6 @@ def plot(self, dataset, folder_path=None, filename=None, mode='lines'): opacity=0.5 )) - # fig.add_trace(self._get_dummy_scatter(color=COLOR_PALETTE[2], showlegend=(ifold == 0))) - # Config plot layout fig.update_yaxes( title='TSCV folds', @@ -384,18 +382,14 @@ def plot(self, dataset, folder_path=None, filename=None, mode='lines'): autorange='reversed', tickvals=list(range(1, self.n_folds+1)), ticktext=[f'Fold {i} ' for i in range(1, self.n_folds+1)], - tickfont=dict(size=12), # comment + tickfont=dict(size=12), ) fig.update_xaxes(title='Time', - tickfont=dict(size=12), # comment + tickfont=dict(size=12), ) fig.update_layout( title='Time Series Cross Validation', - # showlegend=False, plot_bgcolor='white', - # font=dict(size=24), # uncomment - # width=1063, - # height=int(1063 / (4/3)) ) fig.show() diff --git a/sefef/scoring.py b/sefef/scoring.py index db9788f..2ecc037 100644 --- a/sefef/scoring.py +++ b/sefef/scoring.py @@ -308,7 +308,7 @@ def reliability_diagram(self, forecasts, timestamps, binning_method, num_bins): x=diagram_data.loc[:, 'forecasted_proba'], y=diagram_data.loc[:, 'observed_proba'], mode='lines', - line=dict(width=3, color=COLOR_PALETTE[1]), # width 5 + line=dict(width=3, color=COLOR_PALETTE[1]), name='Reliability curve' )) @@ -323,8 +323,7 @@ def reliability_diagram(self, forecasts, timestamps, binning_method, 
num_bins): fig.add_trace(go.Scatter( x=[0, 1], y=[0, 1], - line=dict(width=3, color=COLOR_PALETTE[0], dash='dash'), # width 5 - # showlegend=False, + line=dict(width=3, color=COLOR_PALETTE[0], dash='dash'), mode='lines', name='Perfect reliability' )) @@ -332,7 +331,7 @@ def reliability_diagram(self, forecasts, timestamps, binning_method, num_bins): fig.add_trace(go.Scatter( x=[0, 1], y=[y_avg, y_avg], - line=dict(width=3, color='lightgrey', dash='dash'), # width 5 + line=dict(width=3, color='lightgrey', dash='dash'), mode='lines', name='No resolution' )) @@ -340,14 +339,14 @@ def reliability_diagram(self, forecasts, timestamps, binning_method, num_bins): # Config plot layout fig.update_yaxes( title='observed probability', - tickfont=dict(size=12), # comment + tickfont=dict(size=12), showline=True, linewidth=2, linecolor=COLOR_PALETTE[2], showgrid=False, range=[diagram_data.min().min(), diagram_data.max().max()] ) fig.update_xaxes( title='forecasted probability', - tickfont=dict(size=12), # comment + tickfont=dict(size=12), showline=True, linewidth=2, linecolor=COLOR_PALETTE[2], showgrid=False, range=[diagram_data.min().min(), diagram_data.max().max()], @@ -356,8 +355,5 @@ def reliability_diagram(self, forecasts, timestamps, binning_method, num_bins): title=f'Reliability diagram (binning method: {binning_method})', showlegend=True, plot_bgcolor='white', - # font=dict(size=24), # uncomment - # width=1063, - # height=int(1063 / (4/3)), ) fig.show() diff --git a/sefef/visualization.py b/sefef/visualization.py index 28e2a00..f5f80d6 100644 --- a/sefef/visualization.py +++ b/sefef/visualization.py @@ -30,32 +30,6 @@ def hex_to_rgba(h, alpha): return tuple([int(h.lstrip('#')[i:i+2], 16) for i in (0, 2, 4)] + [alpha]) -def _color_fader(prob, thr=0.5, ll='#FFFFC7', lh='#FFC900', hl='#FF9300', hh='#FF0000'): - ''' Fade (interpolate) from color c1 to c2 with a non-linear transformation, according to the provided threshold. 
- - Parameters - ---------- - ll_color, lh_color, hl_color, hh_color : any format supported by matplotlib, e.g., 'blue', '#FF0000' - prob : float64 - Value between 0 and 1 corresponding to the probability of a seizure happening. - thr : float64 - Value between 0 and 1 corresponding to the threshold - - Returns - ------- - A hex string representing the blended color. - ''' - ll_color = np.array(mpl.colors.to_rgb(ll)) - lh_color = np.array(mpl.colors.to_rgb(lh)) - hl_color = np.array(mpl.colors.to_rgb(hl)) - hh_color = np.array(mpl.colors.to_rgb(hh)) - - if prob <= thr: - return mpl.colors.to_hex((1 - prob/thr) * ll_color + (prob/thr) * lh_color) - else: - return mpl.colors.to_hex((1 - ((prob-thr)/(1-thr))) * hl_color + ((prob-thr)/(1-thr)) * hh_color) - - def plot_forecasts(forecasts, ts, sz_onsets, high_likelihood_thr, forecast_horizon, title='Seizure probability', folder_path=None, filename=None, show=True, return_plot=False, n_points=100): ''' Provide visualization of forecasts. @@ -103,7 +77,7 @@ def plot_forecasts(forecasts, ts, sz_onsets, high_likelihood_thr, forecast_horiz x=pd.to_datetime(df['ts'], unit='s'), y=df['forecasts'], mode='lines', line_color='white', - line_width=3, # 5 + line_width=3, name='Forecast' )) @@ -112,7 +86,7 @@ def plot_forecasts(forecasts, ts, sz_onsets, high_likelihood_thr, forecast_horiz mode='markers', marker=dict( color=COLOR_PALETTE[1], - size=12, # 18 + size=12, symbol='star', ), name='Seizure' @@ -125,7 +99,7 @@ def plot_forecasts(forecasts, ts, sz_onsets, high_likelihood_thr, forecast_horiz fig.update_yaxes( title='Probability', showgrid=False, - tickfont=dict(size=12), # comment + tickfont=dict(size=12), range=[np.max([0, np.min(non_nan_forecasts) - np.std(non_nan_forecasts)]), np.min([1, np.max(non_nan_forecasts) + np.std(non_nan_forecasts)])], ) @@ -134,9 +108,6 @@ def plot_forecasts(forecasts, ts, sz_onsets, high_likelihood_thr, forecast_horiz showgrid=False, ) fig.update_layout( - # font=dict(size=24), # uncomment - # 
width=1063, - # height=1000, title=title, showlegend=True, plot_bgcolor='white', @@ -198,10 +169,6 @@ def aggregate_plots(figs, folder_path=None, filename=None, show=True,): fig.add_annotation(annotation) fig.update_xaxes( range=[x0, x1], row=1, col=(ifig+1)) - # Apply intercalated tick angle (0° for even-indexed subplots, 45° for odd) - # for i in range(1, len(figs) + 1): - # angle = 0 if i % 2 == 0 else -90 # Alternating angles - # fig.update_xaxes(tickangle=angle, row=(i-1)//2 + 1, col=(i-1) % 2 + 1) fig.update_yaxes( showticklabels=(ifig == 0), row=1, col=(ifig+1)) @@ -210,9 +177,6 @@ def aggregate_plots(figs, folder_path=None, filename=None, show=True,): # Update layout forecasts = np.concat(forecasts) fig.update_layout( - # font=dict(size=24), # uncomment - # width=1063, - # height=int(1063 / (4/3)), title_text=figure['layout']['title']['text'], showlegend=True, plot_bgcolor='white', legend_bgcolor='whitesmoke', yaxis_title='Probability', @@ -223,7 +187,7 @@ def aggregate_plots(figs, folder_path=None, filename=None, show=True,): xref="paper", yref="paper", showarrow=False, - font=dict(size=24) # dict(size=14) + font=dict(size=14) )],) fig.update_xaxes(showgrid=False,) non_nan_forecasts = forecasts[~np.isnan(forecasts)] diff --git a/setup.py b/setup.py index 357f543..3e3a5fb 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ def read_readme(): setup( name='sefef', - version='2.3.3', + version='3.0.0', license="BSD 3-clause", description='SeFEF: Seizure Forecasting Evaluation Framework', long_description=read_readme(),