117 changes: 66 additions & 51 deletions sefef/evaluation.py
@@ -113,7 +113,7 @@ def split(self, dataset, iteratively=False, plot=False, extend_final_test_set=False):
dataset_lead_sz = self._get_lead_sz_dataset(dataset)

if self.initial_train_duration is None:
- total_recorded_duration = dataset_lead_sz.files_metadata['total_duration'].sum()
+ total_recorded_duration = dataset_lead_sz.metadata['duration'].sum()
if total_recorded_duration == 0:
raise ValueError(f"Dataset is empty.")
self.initial_train_duration = (1/3) * total_recorded_duration
@@ -122,7 +122,7 @@ def split(self, dataset, iteratively=False, plot=False, extend_final_test_set=False):
self.test_duration = (1/2) * self.initial_train_duration

# Check basic conditions
- if dataset_lead_sz.files_metadata['total_duration'].sum() < self.initial_train_duration + self.test_duration:
+ if dataset_lead_sz.metadata['duration'].sum() < self.initial_train_duration + self.test_duration:
raise ValueError(
f"Dataset does not contain enough data to do this split. Just give up (or decrease 'initial_train_duration' ({self.initial_train_duration}) and/or 'test_duration' ({self.test_duration})).")

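Note on the defaults above: when neither duration is given, the initial train set takes one third of the total recorded duration and the test set half of that, i.e. one sixth of the recording. A minimal sketch of the fallback arithmetic (the DataFrame is a stand-in for `dataset_lead_sz.metadata`):

```python
import pandas as pd

# Stand-in for dataset_lead_sz.metadata: per-sample durations in seconds.
metadata = pd.DataFrame({'duration': [3600] * 12})  # 12 hours of recording

total_recorded_duration = metadata['duration'].sum()        # 43200 s
initial_train_duration = (1 / 3) * total_recorded_duration  # 14400 s (4 h)
test_duration = (1 / 2) * initial_train_duration            # 7200 s (2 h)

# The split is only feasible when both budgets fit in the recording.
assert total_recorded_duration >= initial_train_duration + test_duration
```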
@@ -167,7 +167,7 @@ def _expanding_window_split(self, dataset, initial_cutoff_ts):
test_start_ts = test_end_ts

try:
- test_end_ts = after_train_set.index[after_train_set['total_duration'].cumsum() >= self.test_duration].tolist()[
+ test_end_ts = after_train_set.index[after_train_set['duration'].cumsum() >= self.test_duration].tolist()[
0]
test_end_ts = self._check_criteria_split(after_train_set, test_end_ts)
split_ind_ts += [[train_start_ts, test_start_ts, test_end_ts]]
@@ -185,7 +185,7 @@ def _sliding_window_split(self):

def _get_cutoff_ts(self, dataset):
"""Internal method for getting the first iteration of the cutoff timestamp based on 'self.initial_train_duration'."""
- cutoff_ts = dataset.metadata.index[dataset.metadata['total_duration'].cumsum() > self.initial_train_duration].tolist()[
+ cutoff_ts = dataset.metadata.index[dataset.metadata['duration'].cumsum() > self.initial_train_duration].tolist()[
0]
return cutoff_ts

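`_get_cutoff_ts` relies on the index holding Unix start timestamps, so the cutoff is simply the first timestamp at which the accumulated duration exceeds the training budget. A self-contained sketch of that idiom with toy values:

```python
import pandas as pd

# Index: Unix start timestamps; column: per-sample duration in seconds.
metadata = pd.DataFrame(
    {'duration': [600, 600, 600, 600]},
    index=pd.Index([0, 600, 1200, 1800], dtype='int64'),
)
initial_train_duration = 1500

# cumsum: 600, 1200, 1800, 2400 -> first index past the budget is 1200.
cutoff_ts = metadata.index[
    metadata['duration'].cumsum() > initial_train_duration
].tolist()[0]
print(cutoff_ts)  # 1200
```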
@@ -221,10 +221,10 @@ def _check_criteria_initial_split(self, dataset, initial_cutoff_ts):

# Criteria 1: min number of events in train
criteria_check[0] = ((initial_train_set['sz_onset'].sum() >= self.n_min_events_train) &
- (self._check_if_preictal(initial_train_set) >= self.n_min_events_train))
+ (self._get_preictal_counts(initial_train_set) >= self.n_min_events_train))
# Criteria 2: min number of events in test
criteria_check[1] = ((after_train_set['sz_onset'].sum() >= self.n_min_events_test) &
- (self._check_if_preictal(after_train_set) >= self.n_min_events_test))
+ (self._get_preictal_counts(after_train_set) >= self.n_min_events_test))

if not all(criteria_check):
print(
@@ -241,7 +241,7 @@ def _check_criteria_initial_split(self, dataset, initial_cutoff_ts):
t += 1

# Check if there's enough data left for at least one test set
- if after_train_set['total_duration'].sum() < self.test_duration:
+ if after_train_set['duration'].sum() < self.test_duration:
raise ValueError(
f"Dataset does not comply with the conditions for this split. Just give up (or decrease 'n_min_events_train' ({self.n_min_events_train}), 'initial_train_duration' ({self.initial_train_duration}), and/or 'test_duration' ({self.test_duration})).")

@@ -252,20 +252,23 @@ def _check_criteria_initial_split(self, dataset, initial_cutoff_ts):

return dataset.metadata.iloc[initial_cutoff_ind].name

- def _check_if_preictal(self, metadata):
+ def _get_preictal_counts(self, metadata):
+ nb_preictal_samples = self._check_if_preictal(metadata, metadata[metadata['sz_onset'] == 1].index.to_numpy())
+ return np.count_nonzero(nb_preictal_samples)
+
+ def _check_if_preictal(self, metadata, sz_onsets):
'''Internal method that, for each seizure onset, counts the number of samples falling within its preictal period.'''

- preictal_starts = metadata[metadata['sz_onset'] == 1].index.to_numpy(
- ) - self.preictal_duration - self.prediction_latency
- preictal_ends = metadata[metadata['sz_onset'] == 1].index.to_numpy() - self.prediction_latency
+ preictal_starts = sz_onsets - self.preictal_duration - self.prediction_latency
+ preictal_ends = sz_onsets - self.prediction_latency

# For each seizure onset, count number of samples within preictal period
nb_preictal_samples = np.sum(np.logical_and(
metadata.index.to_numpy()[:, np.newaxis] >= preictal_starts[np.newaxis, :],
metadata.index.to_numpy()[:, np.newaxis] < preictal_ends[np.newaxis, :],
), axis=0)

- return np.count_nonzero(nb_preictal_samples)
+ return nb_preictal_samples

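The refactor splits the old check in two: `_check_if_preictal` now returns a per-onset count of samples whose start falls inside `[onset - preictal_duration - prediction_latency, onset - prediction_latency)`, and `_get_preictal_counts` reduces that to the number of onsets with at least one preictal sample. A toy sketch of the broadcasting (all values in seconds, names chosen for illustration):

```python
import numpy as np

sample_starts = np.array([0, 600, 1200, 1800])  # metadata.index
sz_onsets = np.array([2000, 5000])              # seizure onset timestamps
preictal_duration, prediction_latency = 900, 60

preictal_starts = sz_onsets - preictal_duration - prediction_latency  # [1040, 4040]
preictal_ends = sz_onsets - prediction_latency                        # [1940, 4940]

# (n_samples, n_onsets) mask, reduced over samples -> count per onset.
nb_preictal_samples = np.sum(
    (sample_starts[:, np.newaxis] >= preictal_starts[np.newaxis, :])
    & (sample_starts[:, np.newaxis] < preictal_ends[np.newaxis, :]),
    axis=0,
)
print(nb_preictal_samples)                    # [2 0]: onset 5000 has no coverage
print(np.count_nonzero(nb_preictal_samples))  # 1 onset has preictal samples
```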
def _check_criteria_split(self, metadata, cutoff_ts):
"""Internal method for iterating the cutoff timestamp for n>1 folds in order to respect the condition on the minimum number of seizures in test."""
@@ -281,7 +284,7 @@ def _check_criteria_split(self, metadata, cutoff_ts):
criteria_check[0] = cutoff_ind <= len(metadata)
# Criteria 2: min number of events in test
criteria_check[1] = ((test_set['sz_onset'].sum() >= self.n_min_events_test) &
- (self._check_if_preictal(test_set) >= self.n_min_events_test))
+ (self._get_preictal_counts(test_set) >= self.n_min_events_test))

if not all(criteria_check):
print(
@@ -323,13 +326,21 @@ def plot(self, dataset, folder_path=None, filename=None, mode='lines'):

fig = go.Figure()

- file_duration = dataset.metadata['total_duration'].iloc[0]
+ file_duration = dataset.metadata['duration'].iloc[0]

for ifold in range(self.n_folds):

train_set = dataset.metadata.loc[self.split_ind_ts[ifold, 0]: self.split_ind_ts[ifold, 1]]
test_set = dataset.metadata.loc[self.split_ind_ts[ifold, 1]: self.split_ind_ts[ifold, 2]]

+ ts_lead_sz = self._get_lead_seizures(train_set[train_set['sz_onset'] == 1].index.to_numpy())
+ ts_lead_sz = ts_lead_sz[np.nonzero(self._check_if_preictal(train_set, ts_lead_sz))]
+ ts_lead_sz = pd.to_datetime(ts_lead_sz, unit='s').to_numpy()
+
+ ts_preictal_sz_test = test_set[test_set['sz_onset'] == 1].index.to_numpy()[np.nonzero(self._check_if_preictal(
+ test_set, test_set[test_set['sz_onset'] == 1].index.to_numpy()))]
+ ts_preictal_sz_test = pd.to_datetime(ts_preictal_sz_test, unit='s').to_numpy()

# handle missing data between files
train_set = self._handle_missing_data(train_set, ifold+1, file_duration)
test_set = self._handle_missing_data(test_set, ifold+1, file_duration)
@@ -341,24 +352,28 @@ def plot(self, dataset, folder_path=None, filename=None, mode='lines'):
test_set, color=COLOR_PALETTE[1], mode=mode, name='Test', showlegend=(ifold == 0)))

# add seizures
- ts_lead_sz = self._get_lead_seizures(train_set[train_set['sz_onset'] == 1].index.to_numpy())
fig.add_trace(self._get_scatter_plot_sz(
train_set.loc[ts_lead_sz],
color=COLOR_PALETTE[0]
))
ts_non_lead_sz = train_set[train_set['sz_onset'] == 1].index.to_numpy()[~np.any(
train_set[train_set['sz_onset'] == 1].index.to_numpy()[:, np.newaxis] == ts_lead_sz[np.newaxis, :], axis=1)]
+ ts_no_preictal_sz_test = test_set[test_set['sz_onset'] == 1].index.to_numpy()[~np.any(
+ test_set[test_set['sz_onset'] == 1].index.to_numpy()[:, np.newaxis] == ts_preictal_sz_test[np.newaxis, :], axis=1)]
fig.add_trace(self._get_scatter_plot_sz(
train_set.loc[ts_non_lead_sz],
color=COLOR_PALETTE[0],
opacity=0.5
))
fig.add_trace(self._get_scatter_plot_sz(
- test_set[test_set['sz_onset'] == 1],
+ test_set.loc[ts_preictal_sz_test],
color=COLOR_PALETTE[1]
))

- # fig.add_trace(self._get_dummy_scatter(color=COLOR_PALETTE[2], showlegend=(ifold == 0)))
+ fig.add_trace(self._get_scatter_plot_sz(
+ test_set.loc[ts_no_preictal_sz_test],
+ color=COLOR_PALETTE[1],
+ opacity=0.5
+ ))

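The faded markers are selected with a broadcast-equality complement: onsets that match none of the "kept" timestamps. A sketch of the idiom, which is equivalent to `np.isin` with `invert=True` (toy arrays):

```python
import numpy as np

all_onsets = np.array([100, 200, 300, 400])
kept_onsets = np.array([100, 300])  # e.g. lead seizures with preictal data

# True where an onset equals any kept onset; ~ gives the complement.
is_kept = np.any(all_onsets[:, np.newaxis] == kept_onsets[np.newaxis, :], axis=1)
print(all_onsets[~is_kept])                           # [200 400] -> drawn faded
print(np.isin(all_onsets, kept_onsets, invert=True))  # same mask, more directly
```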
# Config plot layout
fig.update_yaxes(
@@ -369,11 +384,13 @@ def plot(self, dataset, folder_path=None, filename=None, mode='lines'):
ticktext=[f'Fold {i} ' for i in range(1, self.n_folds+1)],
tickfont=dict(size=12),
)
- fig.update_xaxes(title='Time')
+ fig.update_xaxes(title='Time',
+ tickfont=dict(size=12),
+ )
fig.update_layout(
title='Time Series Cross Validation',
- # showlegend=False,
- plot_bgcolor='white')
+ plot_bgcolor='white',
+ )
fig.show()

if folder_path is not None:
@@ -402,7 +419,7 @@ def _get_scatter_plot_sz(self, metadata, color, opacity=1, showlegend=False):
color='rgba' + str(hex_to_rgba(
h=color, alpha=opacity
)),
- size=12,
+ size=12, # 18
symbol='star',
),
)
@@ -561,60 +578,58 @@ class Dataset:

Attributes
----------
- files_metadata: pd.DataFrame
- Input DataFrame with the following columns:
- - 'filepath' (str): Path to each file containing data.
- - 'first_timestamp' (int64): The Unix-time timestamp (in seconds) of the first sample of each file.
- - 'total_duration' (int64): Total duration of file in seconds (equivalent to #samples * sampling_frequency)
- It is expected that data within each file is non-overlapping in time and that there are no time gaps between samples in the file.
+ timestamps : array-like, shape (#samples,)
+ The Unix-time timestamp (in seconds) of the start timestamp of each sample.
+ samples_duration : array-like, shape (#samples,)
+ Duration of samples in seconds.
sz_onsets: np.array
Contains the Unix-time timestamps (in seconds) corresponding to the onsets of seizures.
sampling_frequency: int
Frequency at which the data is stored in each file.
'''

- def __init__(self, files_metadata, sz_onsets):
- self.files_metadata = files_metadata.astype({'first_timestamp': 'int64', 'total_duration': 'int64'})
+ def __init__(self, timestamps, samples_duration, sz_onsets):
+ timestamps = np.array(timestamps, dtype='int64')
+ samples_duration = np.array(samples_duration, dtype='int64')
self.sz_onsets = np.array(sz_onsets)

- self.metadata = self._get_metadata()
- self.metadata = self.metadata.astype({'filepath': str})
+ self.metadata = self._get_metadata(timestamps, samples_duration)

- def _get_metadata(self):
+ def _get_metadata(self, timestamps, samples_duration):
"""Internal method that updates 'self.metadata' by placing each seizure onset within an acquisition file."""

- timestamps_file_start = self.files_metadata['first_timestamp'].to_numpy()
- timestamps_file_end = (self.files_metadata['first_timestamp'] +
- self.files_metadata['total_duration']).to_numpy()
+ timestamps_file_start = timestamps.copy()
+ timestamps_file_end = timestamps_file_start + samples_duration

# identify seizures within existing files
sz_onset_indx = np.argwhere((self.sz_onsets[:, np.newaxis] >= timestamps_file_start[np.newaxis, :]) & (
self.sz_onsets[:, np.newaxis] < timestamps_file_end[np.newaxis, :]))

- files_metadata = self.files_metadata.copy()
- files_metadata['sz_onset'] = 0
- files_metadata.loc[sz_onset_indx[:, 1], 'sz_onset'] = 1
+ metadata = pd.DataFrame({'timestamp': timestamps_file_start, 'duration': samples_duration, 'sz_onset': 0})
+ metadata.loc[sz_onset_indx[:, 1], 'sz_onset'] = 1

# identify seizures outside of existing files
sz_onset_indx = np.argwhere(~np.any(((self.sz_onsets[:, np.newaxis] >= timestamps_file_start[np.newaxis, :]) & (
self.sz_onsets[:, np.newaxis] < timestamps_file_end[np.newaxis, :])), axis=1)).flatten()
if len(sz_onset_indx) != 0:
- sz_onsets = pd.DataFrame({'first_timestamp': self.sz_onsets[sz_onset_indx], 'sz_onset': [
+ sz_onsets = pd.DataFrame({'timestamp': self.sz_onsets[sz_onset_indx], 'sz_onset': [
1]*len(sz_onset_indx)}, dtype='int64')
- files_metadata = pd.merge(files_metadata.reset_index(), sz_onsets.reset_index(),
- on='first_timestamp', how='outer', suffixes=('_df1', '_df2'))
- files_metadata['sz_onset'] = files_metadata['sz_onset_df1'].combine_first(
- files_metadata['sz_onset_df2']).fillna(0).astype('int64')
- files_metadata['total_duration'] = files_metadata['total_duration'].fillna(0).astype('int64')
+ metadata = pd.merge(metadata.reset_index(), sz_onsets.reset_index(),
+ on='timestamp', how='outer', suffixes=('_df1', '_df2'))
+ metadata['sz_onset'] = metadata['sz_onset_df1'].combine_first(
+ metadata['sz_onset_df2']).fillna(0).astype('int64')
+ metadata['duration'] = metadata['duration'].fillna(
+ 0).astype('int64')

- files_metadata.set_index(pd.Index(files_metadata['first_timestamp'].to_numpy(), dtype='int64'), inplace=True)
- files_metadata = files_metadata.loc[:, ['filepath', 'total_duration', 'sz_onset']]
+ metadata.set_index(pd.Index(
+ metadata['timestamp'].to_numpy(), dtype='int64'), inplace=True)
+ metadata = metadata.loc[:, ['duration', 'sz_onset']]

try:
- files_metadata = pd.concat((
- files_metadata, pd.DataFrame([[np.nan, 0, 0]], columns=files_metadata.columns, index=pd.Series(
- [files_metadata.iloc[-1].name+files_metadata.iloc[-1]['total_duration']], dtype='int64')),
+ metadata = pd.concat((
+ metadata, pd.DataFrame([[0, 0]], columns=metadata.columns, index=pd.Series(
+ [metadata.iloc[-1].name+metadata.iloc[-1]['duration']], dtype='int64')),
), ignore_index=False) # add empty row at the end for indexing
except IndexError:
pass
- return files_metadata
+ return metadata
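With the new signature, callers hand over per-sample timestamps and durations directly instead of a file-level DataFrame. A usage sketch under the signature shown above (toy values; the import path is assumed from this module's location):

```python
from sefef.evaluation import Dataset  # assumed import path

timestamps = [0, 600, 1200]         # Unix start time of each sample (s)
samples_duration = [600, 600, 600]  # duration of each sample (s)
sz_onsets = [700, 4100]             # one onset inside 600-1200, one in a gap

dataset = Dataset(timestamps, samples_duration, sz_onsets)
print(dataset.metadata)
# Expected: rows indexed by timestamp with 'duration' and 'sz_onset' columns;
# the out-of-file onset 4100 becomes a zero-duration row, and a trailing
# zero-duration row is appended for indexing (per the comment above).
```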
5 changes: 2 additions & 3 deletions sefef/scoring.py
@@ -129,7 +129,7 @@ def _get_counts(self, forecasts, timestamps_start_forecast, threshold):
(self.sz_onsets[:, np.newaxis] >= timestamps_start_forecast[np.newaxis, :])
& (self.sz_onsets[:, np.newaxis] <= timestamps_end_forecast[np.newaxis, :])
& (forecasts >= threshold),
- axis=1)
+ axis=0)

no_sz_forecasts = forecasts[~np.any(
(self.sz_onsets[:, np.newaxis] >= timestamps_start_forecast[np.newaxis, :])
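The `axis=1` → `axis=0` change flips which dimension of the `(n_onsets, n_forecasts)` mask the surrounding reduction collapses: axis 1 answers "was each onset covered by an alarmed forecast?", axis 0 answers "does each forecast overlap an onset?". A toy illustration (hypothetical arrays, not the real scorer):

```python
import numpy as np

sz_onsets = np.array([150, 900])   # n_onsets = 2
starts = np.array([0, 600, 1200])  # n_forecasts = 3
ends = starts + 600
forecasts = np.array([0.9, 0.2, 0.8])
threshold = 0.5

mask = (
    (sz_onsets[:, np.newaxis] >= starts[np.newaxis, :])
    & (sz_onsets[:, np.newaxis] <= ends[np.newaxis, :])
    & (forecasts >= threshold)
)                            # shape (2, 3)
print(np.any(mask, axis=1))  # per onset:    [ True False]
print(np.any(mask, axis=0))  # per forecast: [ True False False]
```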
@@ -177,7 +177,7 @@ def _get_bins_indx(self, forecasts, binning_method, num_bins):
num_bins = np.ceil(len(forecasts)**(1/3)).astype('int64')

if binning_method == 'uniform':
- bin_edges = np.linspace(0, 1, num_bins + 1)
+ bin_edges = np.linspace(min(forecasts), max(forecasts), num_bins + 1)
elif binning_method == 'quantile':
percentile = np.linspace(0, 100, num_bins + 1)
bin_edges = np.percentile(np.sort(forecasts), percentile)[1:] # remove edge corresponding to 0th percentile
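Anchoring the uniform edges to the observed forecast range (rather than the fixed `[0, 1]` interval) avoids empty bins at the extremes when forecasts cluster in a narrow band. A quick comparison using the cube-root fallback shown above:

```python
import numpy as np

forecasts = np.array([0.42, 0.45, 0.50, 0.55, 0.61, 0.64, 0.70, 0.72])
num_bins = int(np.ceil(len(forecasts) ** (1 / 3)))  # 2

fixed_edges = np.linspace(0, 1, num_bins + 1)
data_edges = np.linspace(min(forecasts), max(forecasts), num_bins + 1)
print(fixed_edges)  # [0.  0.5 1. ] -- below 0.42 / above 0.72 is dead space
print(data_edges)   # [0.42 0.57 0.72] -- edges track the data
```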
Expand Down Expand Up @@ -324,7 +324,6 @@ def reliability_diagram(self, forecasts, timestamps, binning_method, num_bins):
x=[0, 1],
y=[0, 1],
line=dict(width=3, color=COLOR_PALETTE[0], dash='dash'),
- # showlegend=False,
mode='lines',
name='Perfect reliability'
))
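For context, the dashed diagonal is the identity line y = x: a perfectly calibrated forecaster's per-bin mean forecast equals the observed event frequency. A hypothetical set of reliability points and their worst-case deviation from that diagonal:

```python
import numpy as np

# Hypothetical per-bin points: mean forecast vs observed event frequency.
mean_forecast = np.array([0.1, 0.3, 0.5, 0.7, 0.9])
observed_freq = np.array([0.08, 0.33, 0.47, 0.74, 0.88])

calibration_gap = observed_freq - mean_forecast
print(np.abs(calibration_gap).max())  # 0.04: largest departure from y = x
```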
35 changes: 5 additions & 30 deletions sefef/visualization.py
@@ -30,32 +30,6 @@ def hex_to_rgba(h, alpha):
return tuple([int(h.lstrip('#')[i:i+2], 16) for i in (0, 2, 4)] + [alpha])

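For reference, the helper kept above maps a hex color and an alpha value to an RGBA tuple, which the plotting code wraps in an `'rgba(...)'` string:

```python
def hex_to_rgba(h, alpha):
    # As defined above: parse the RR/GG/BB pairs and append the alpha.
    return tuple([int(h.lstrip('#')[i:i + 2], 16) for i in (0, 2, 4)] + [alpha])

print(hex_to_rgba('#FF9300', 0.5))                # (255, 147, 0, 0.5)
print('rgba' + str(hex_to_rgba('#FF9300', 0.5)))  # 'rgba(255, 147, 0, 0.5)'
```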

- def _color_fader(prob, thr=0.5, ll='#FFFFC7', lh='#FFC900', hl='#FF9300', hh='#FF0000'):
- ''' Fade (interpolate) from color c1 to c2 with a non-linear transformation, according to the provided threshold.
-
- Parameters
- ----------
- ll_color, lh_color, hl_color, hh_color : any format supported by matplotlib, e.g., 'blue', '#FF0000'
- prob : float64
- Value between 0 and 1 corresponding to the probability of a seizure happening.
- thr : float64
- Value between 0 and 1 corresponding to the threshold
-
- Returns
- -------
- A hex string representing the blended color.
- '''
- ll_color = np.array(mpl.colors.to_rgb(ll))
- lh_color = np.array(mpl.colors.to_rgb(lh))
- hl_color = np.array(mpl.colors.to_rgb(hl))
- hh_color = np.array(mpl.colors.to_rgb(hh))
-
- if prob <= thr:
- return mpl.colors.to_hex((1 - prob/thr) * ll_color + (prob/thr) * lh_color)
- else:
- return mpl.colors.to_hex((1 - ((prob-thr)/(1-thr))) * hl_color + ((prob-thr)/(1-thr)) * hh_color)


def plot_forecasts(forecasts, ts, sz_onsets, high_likelihood_thr, forecast_horizon, title='Seizure probability', folder_path=None, filename=None, show=True, return_plot=False, n_points=100):
''' Provide visualization of forecasts.

@@ -173,7 +147,7 @@ def aggregate_plots(figs, folder_path=None, filename=None, show=True,):
Figures to aggregate into a single plot.
'''
fig = make_subplots(rows=1, cols=len(
- figs), shared_yaxes=True, horizontal_spacing=0.005)
+ figs), shared_yaxes=True, horizontal_spacing=0.05) # 0.005
forecasts = []
for ifig, figure in enumerate(figs):
print(f'Aggregating forecast plots ({ifig+1}/{len(figs)})', end='\r')
@@ -209,17 +183,18 @@ def aggregate_plots(figs, folder_path=None, filename=None, show=True,):
annotations=[
dict(
text="Time",
- x=0.5, y=-0.15,
+ x=0.5, y=-0.11,
xref="paper",
yref="paper",
showarrow=False,
font=dict(size=14)
)],)
fig.update_xaxes(showgrid=False,)
+ non_nan_forecasts = forecasts[~np.isnan(forecasts)]
fig.update_yaxes(
showgrid=False,
- range=[np.max([0, np.min(forecasts) - np.std(forecasts)]),
- np.min([1, np.max(forecasts) + np.std(forecasts)])])
+ range=[np.max([0, np.min(non_nan_forecasts) - np.std(non_nan_forecasts)]),
+ np.min([1, np.max(non_nan_forecasts) + np.std(non_nan_forecasts)])])

if folder_path is not None:
if not os.path.exists(folder_path):
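The NaN filter matters because `np.min`/`np.max`/`np.std` propagate NaN, so a single NaN gap in the aggregated forecasts would silently produce an invalid y-range. A sketch of the failure and the fix:

```python
import numpy as np

forecasts = np.array([0.2, 0.4, np.nan, 0.6])

print(np.min(forecasts))  # nan -- plain reductions propagate NaN

# Masking first, as aggregate_plots now does, yields a valid padded range.
valid = forecasts[~np.isnan(forecasts)]
lower = max(0, np.min(valid) - np.std(valid))
upper = min(1, np.max(valid) + np.std(valid))
print(lower, upper)  # roughly 0.037 0.763
```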
2 changes: 1 addition & 1 deletion setup.py
@@ -12,7 +12,7 @@ def read_readme():

setup(
name='sefef',
- version='2.3.3',
+ version='3.0.0',
license="BSD 3-clause",
description='SeFEF: Seizure Forecasting Evaluation Framework',
long_description=read_readme(),