From 54aac773a4dc4925baa2b82892d4d3e5fffe006b Mon Sep 17 00:00:00 2001 From: fedepup Date: Mon, 17 Feb 2025 14:53:41 +0100 Subject: [PATCH 1/2] minor additions --- conda_recipe/meta.yaml | 3 +- docs/_static/bench_table.csv | 53 +++-- docs/selfeeg.augmentation.rst | 1 + extra_material/Augmentation_benchmark.py | 87 ++++----- extra_material/bench_table.csv | 50 ++--- selfeeg/augmentation/__init__.py | 2 +- selfeeg/augmentation/compose.py | 88 ++++++++- selfeeg/augmentation/functional.py | 175 +++++++++-------- selfeeg/dataloading/load.py | 193 +++++++++---------- selfeeg/models/encoders.py | 47 +++-- selfeeg/models/zoo.py | 10 +- setup.py | 1 - test/EEGself/augmentation/compose_test.py | 17 +- test/EEGself/augmentation/functional_test.py | 4 +- 14 files changed, 427 insertions(+), 304 deletions(-) diff --git a/conda_recipe/meta.yaml b/conda_recipe/meta.yaml index 44d84fb..d18df37 100644 --- a/conda_recipe/meta.yaml +++ b/conda_recipe/meta.yaml @@ -11,7 +11,7 @@ source: build: noarch: python - script: {{ PYTHON }} -m pip install . -vv --no-deps --no-build-isolation + script: {{ PYTHON }} -m pip install . -vv number: 0 requirements: @@ -26,7 +26,6 @@ requirements: - scipy >=1.10.1 - pytorch >=2.0.0 - torchaudio >=2.0.2 - - torchvision >=0.15.2 - tqdm test: diff --git a/docs/_static/bench_table.csv b/docs/_static/bench_table.csv index 3cf902c..126bd73 100644 --- a/docs/_static/bench_table.csv +++ b/docs/_static/bench_table.csv @@ -1,27 +1,26 @@ -,Numpy,Numpy,Torch,Torch,Torch GPU,Torch GPU - ,**no BE**,**BE**,**no BE**,**BE**,**no BE**,**BE** -add_band_noise,,12.683,,1.937,,0.364 -add_eeg_artifact,15.919,15.475,18.084,8.774,10.680,0.544 -add_gaussian_noise,,71.040,,35.011,,0.070 -add_noise_SNR,,73.885,,37.647,,12.803 -change_ref,,9.273,,19.219,,4.698 -channel_dropout,7.088,3.348,17.496,0.553,4.641,0.067 -crop_and_resize,107.009,168.657,383.806,39.491,101.634,1.962 -filter_bandpass,,45.507,,152.942,,2.958 -filter_bandstop,,44.845,,145.338,,1.608 -filter_highpass,,37.997,,138.716,,1.111 -filter_lowpass,,46.904,,129.532,,1.201 -flip_horizontal,,0.001,,0.023,,0.009 -flip_vertical,,3.365,,0.0253,,0.008 -masking,8.351,5.401,10.081,1.960,4.046,0.051 -moving_avg,,27.874,,63.697,,0.158 -permutation_signal,12.659,26.512,26.081,1.141,5.567,0.070 -permute_channels_net,43.501,8.773,166.101,2.974,43.669,0.664 -permute_channels,9.332,8.783,34.481,3.281,10.940,0.155 -random_FT_phase,38.544,47.198,90.381,8.5997,12.139,0.258 -random_slope_scale,44.880,20.625,39.141,1.667,0.197,0.093 -scaling,5.069,5.061,0.406,0.415,0.022,0.054 -shift_frequency,47.336,69.527,97.961,12.574,6.523,0.290 -shift_horizontal,7.665,7.356,10.154,0.682,2.694,0.035 -shift_vertical,,6.996,,0.057,,0.070 -warp_signal,247.153,410.710,947.214,86.178,221.647,4.442 +,Numpy Array no BE,Numpy Array BE,Torch Tensor no BE,Torch Tensor BE,Torch Tensor GPU no BE,Torch Tensor GPU BE +add_band_noise,,1.419,,0.133,,0.046 +add_eeg_artifact,1.595,1.724,1.76,0.396,1.119,0.058 +add_gaussian_noise,,7.395,,1.68,,0.007 +add_noise_SNR,,7.695,,1.938,,1.344 +change_ref,,0.767,,0.729,,0.498 +channel_dropout,0.779,0.363,0.455,0.028,0.502,0.007 +crop_and_resize,11.479,17.503,15.851,3.434,10.389,0.2 +filter_bandpass,,4.716,,5.786,,1.035 +filter_bandstop,,4.793,,5.93,,0.166 +filter_highpass,,3.431,,5.961,,0.107 +filter_lowpass,,4.365,,6.036,,0.115 +flip_horizontal,,0.0,,0.022,,0.001 +flip_vertical,,0.358,,0.023,,0.001 +masking,0.815,0.548,0.495,0.039,0.432,0.006 +moving_avg,,3.066,,2.661,,0.009 +permutation_signal,1.055,1.305,0.951,0.242,0.599,0.008 +permute_channels_network,4.667,0.954,6.435,0.157,4.698,0.072 +permute_channels,0.905,0.877,1.237,0.075,1.148,0.016 +random_FT_phase,4.014,4.963,2.819,0.485,1.272,0.026 +random_slope_scale,3.849,2.251,1.783,0.2,0.012,0.009 +scaling,0.562,0.56,0.046,0.046,0.002,0.002 +shift_frequency,5.155,7.262,3.173,0.757,0.692,0.029 +shift_horizontal,0.797,0.787,0.389,0.056,0.279,0.004 +shift_vertical,,0.712,,0.061,,0.002 +warp_signal,25.965,43.759,35.951,7.205,23.263,0.467 diff --git a/docs/selfeeg.augmentation.rst b/docs/selfeeg.augmentation.rst index 393c129..97a3e25 100644 --- a/docs/selfeeg.augmentation.rst +++ b/docs/selfeeg.augmentation.rst @@ -20,6 +20,7 @@ Classes :nosignatures: :template: classtemplate.rst + CircularAug DynamicSingleAug RandomAug SequentialAug diff --git a/extra_material/Augmentation_benchmark.py b/extra_material/Augmentation_benchmark.py index 8faf415..3e84aa5 100644 --- a/extra_material/Augmentation_benchmark.py +++ b/extra_material/Augmentation_benchmark.py @@ -31,7 +31,7 @@ 5. Torch Tensor on GPU with no batch equal 6. Torch Tensor on GPU with batch equal -Each augmentation will be run 10 times, while the number of repetitions of the +Each augmentation will be run 100 times, while the number of repetitions of the timeit function (so the total is 10*repetition) can be parsed. For example: @@ -46,13 +46,12 @@ type=int, nargs="?", const=1, - help="""an integer for the number of times - an augmentation is called 10 times""", + help="an integer for the number of times an augmentation is called 100 times", ) args = parser.parse_args() n = args.repetition -print("start benchmark with ", str(n), " repetition of 10 calls") +print("start benchmark with ", str(n), " repetition of 100 calls") device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu") if device.type == "cpu": device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") @@ -149,7 +148,7 @@ print("evaluating add_band_noise") s = """ -for i in range(10): +for i in range(100): xaug = aug.add_band_noise(x, ['theta',(10,20),50], 128) """ bench_dict["add_band_noise"][1] = timeit.timeit(s, sup_np, number=n) @@ -160,11 +159,11 @@ print("evaluating add_eeg_artifact") s_false = """ -for i in range(10): +for i in range(100): xaug = aug.add_eeg_artifact(x, 128 , batch_equal=False) """ s_true = """ -for i in range(10): +for i in range(100): xaug = aug.add_eeg_artifact(x, 128 , batch_equal=True) """ bench_dict["add_eeg_artifact"][0] = timeit.timeit(s_false, sup_np, number=n) @@ -178,18 +177,17 @@ print("evaluating add_gaussian_noise") s = """ -for i in range(10): +for i in range(100): xaug = aug.add_gaussian_noise(x) """ bench_dict["add_gaussian_noise"][1] = timeit.timeit(s, sup_np, number=n) bench_dict["add_gaussian_noise"][3] = timeit.timeit(s, sup_torch, number=n) if device.type != "cpu": bench_dict["add_gaussian_noise"][5] = timeit.timeit(s, sup_torch_gpu, number=n) -print(bench_dict["add_gaussian_noise"]) s = """ -for i in range(10): +for i in range(100): xaug = aug.add_noise_SNR(x) """ bench_dict["add_noise_SNR"][1] = timeit.timeit(s, sup_np, number=n) @@ -200,7 +198,7 @@ print("evaluating change_ref") s = """ -for i in range(10): +for i in range(100): xaug = aug.change_ref(x) """ bench_dict["change_ref"][1] = timeit.timeit(s, sup_np, number=n) @@ -211,11 +209,11 @@ print("evaluating channel_dropout") s_false = """ -for i in range(10): +for i in range(100): xaug = aug.channel_dropout(x, 8 , batch_equal=False) """ s_true = """ -for i in range(10): +for i in range(100): xaug = aug.channel_dropout(x, 8 , batch_equal=True) """ bench_dict["channel_dropout"][0] = timeit.timeit(s_false, sup_np, number=n) @@ -229,11 +227,11 @@ print("evaluating crop_and_resize") s_false = """ -for i in range(10): +for i in range(100): xaug = aug.crop_and_resize(x, 10 , 2, batch_equal=False) """ s_true = """ -for i in range(10): +for i in range(100): xaug = aug.crop_and_resize(x, 10, 2, batch_equal=True) """ bench_dict["crop_and_resize"][0] = timeit.timeit(s_false, sup_np, number=n) @@ -247,7 +245,7 @@ print("evaluating filter_bandpass") s = """ -for i in range(10): +for i in range(100): xaug = aug.filter_bandpass(x, 128) """ bench_dict["filter_bandpass"][1] = timeit.timeit(s, sup_np, number=n) @@ -258,7 +256,7 @@ print("evaluating filter_bandstop") s = """ -for i in range(10): +for i in range(100): xaug = aug.filter_bandstop(x, 128) """ bench_dict["filter_bandstop"][1] = timeit.timeit(s, sup_np, number=n) @@ -269,7 +267,7 @@ print("evaluating filter_highpass") s = """ -for i in range(10): +for i in range(100): xaug = aug.filter_highpass(x, 128) """ bench_dict["filter_highpass"][1] = timeit.timeit(s, sup_np, number=n) @@ -280,7 +278,7 @@ print("evaluating filter_lowpass") s = """ -for i in range(10): +for i in range(100): xaug = aug.filter_lowpass(x, 128) """ bench_dict["filter_lowpass"][1] = timeit.timeit(s, sup_np, number=n) @@ -291,7 +289,7 @@ print("evaluating flip_horizontal") s = """ -for i in range(10): +for i in range(100): xaug = aug.flip_horizontal(x) """ bench_dict["flip_horizontal"][1] = timeit.timeit(s, sup_np, number=n) @@ -302,7 +300,7 @@ print("evaluating flip_vertical") s = """ -for i in range(10): +for i in range(100): xaug = aug.flip_vertical(x) """ bench_dict["flip_vertical"][1] = timeit.timeit(s, sup_np, number=n) @@ -313,11 +311,11 @@ print("evaluating masking") s_false = """ -for i in range(10): +for i in range(100): xaug = aug.masking(x, 4 , 0.4, batch_equal=False) """ s_true = """ -for i in range(10): +for i in range(100): xaug = aug.masking(x, 4, 0.4 , batch_equal=True) """ bench_dict["masking"][0] = timeit.timeit(s_false, sup_np, number=n) @@ -331,7 +329,7 @@ print("evaluating moving_avg") s = """ -for i in range(10): +for i in range(100): xaug = aug.moving_avg(x) """ bench_dict["moving_avg"][1] = timeit.timeit(s, sup_np, number=n) @@ -342,11 +340,11 @@ print("evaluating permutation_signal") s_false = """ -for i in range(10): +for i in range(100): xaug = aug.permutation_signal(x, 15 , 5, batch_equal=False) """ s_true = """ -for i in range(10): +for i in range(100): xaug = aug.permutation_signal(x, 15, 5 , batch_equal=True) """ bench_dict["permutation_signal"][0] = timeit.timeit(s_false, sup_np, number=n) @@ -360,11 +358,11 @@ print("evaluating permute_channels_network") s_false = """ -for i in range(10): +for i in range(100): xaug = aug.permute_channels(x, 35 , 'network', batch_equal=False) """ s_true = """ -for i in range(10): +for i in range(100): xaug = aug.permute_channels(x, 35 , 'network' , batch_equal=True) """ bench_dict["permute_channels_network"][0] = timeit.timeit(s_false, sup_np, number=n) @@ -378,11 +376,11 @@ print("evaluating permute_channels") s_false = """ -for i in range(10): +for i in range(100): xaug = aug.permute_channels(x, 35, batch_equal=False) """ s_true = """ -for i in range(10): +for i in range(100): xaug = aug.permute_channels(x, 35, batch_equal=True) """ bench_dict["permute_channels"][0] = timeit.timeit(s_false, sup_np, number=n) @@ -396,11 +394,11 @@ print("evaluating random_FT_phase") s_false = """ -for i in range(10): +for i in range(100): xaug = aug.random_FT_phase(x, 0.2, batch_equal=False) """ s_true = """ -for i in range(10): +for i in range(100): xaug = aug.random_FT_phase(x, 0.2, batch_equal=True) """ bench_dict["random_FT_phase"][0] = timeit.timeit(s_false, sup_np, number=n) @@ -414,11 +412,11 @@ print("evaluating random_slope_scale") s_false = """ -for i in range(10): +for i in range(100): xaug = aug.random_slope_scale(x, batch_equal=False) """ s_true = """ -for i in range(10): +for i in range(100): xaug = aug.random_slope_scale(x, batch_equal=True) """ bench_dict["random_slope_scale"][0] = timeit.timeit(s_false, sup_np, number=n) @@ -432,11 +430,11 @@ print("evaluating scaling") s_false = """ -for i in range(10): +for i in range(100): xaug = aug.scaling(x, batch_equal=False) """ s_true = """ -for i in range(10): +for i in range(100): xaug = aug.scaling(x, batch_equal=True) """ bench_dict["scaling"][0] = timeit.timeit(s_false, sup_np, number=n) @@ -450,11 +448,11 @@ print("evaluating shift_frequency") s_false = """ -for i in range(10): +for i in range(100): xaug = aug.shift_frequency(x, 4,128, batch_equal=False) """ s_true = """ -for i in range(10): +for i in range(100): xaug = aug.shift_frequency(x, 4, 128, batch_equal=True) """ bench_dict["shift_frequency"][0] = timeit.timeit(s_false, sup_np, number=n) @@ -468,11 +466,11 @@ print("evaluating shift_horizontal") s_false = """ -for i in range(10): +for i in range(100): xaug = aug.shift_horizontal(x, 0.2, 128, random_shift=True, batch_equal=False) """ s_true = """ -for i in range(10): +for i in range(100): xaug = aug.shift_horizontal(x, 0.2 , 128, batch_equal=True) """ bench_dict["shift_horizontal"][0] = timeit.timeit(s_false, sup_np, number=n) @@ -486,7 +484,7 @@ print("evaluating shift_vertical") s = """ -for i in range(10): +for i in range(100): xaug = aug.shift_vertical(x, 2) """ bench_dict["shift_vertical"][1] = timeit.timeit(s, sup_np, number=n) @@ -497,11 +495,11 @@ print("evaluating warp_signal") s_false = """ -for i in range(10): +for i in range(100): xaug = aug.warp_signal(x, 16, batch_equal=False) """ s_true = """ -for i in range(10): +for i in range(100): xaug = aug.warp_signal(x, 16, batch_equal=True) """ bench_dict["warp_signal"][0] = timeit.timeit(s_false, sup_np, number=n) @@ -525,4 +523,7 @@ "Torch Tensor GPU BE", ], ) +Bench_Table = Bench_Table.div(n) +Bench_Table = Bench_Table.round(decimals=3) Bench_Table.to_csv("bench_table.csv") +Bench_Table.to_csv("../docs/_static/bench_table.csv") diff --git a/extra_material/bench_table.csv b/extra_material/bench_table.csv index 307ee5b..126bd73 100644 --- a/extra_material/bench_table.csv +++ b/extra_material/bench_table.csv @@ -1,26 +1,26 @@ ,Numpy Array no BE,Numpy Array BE,Torch Tensor no BE,Torch Tensor BE,Torch Tensor GPU no BE,Torch Tensor GPU BE -add_band_noise,,12.683464657515287,,1.9375724727287889,,0.3644658047705889 -add_eeg_artifact,15.919340289197862,15.475980247370899,18.084918590262532,8.774911595508456,10.680491594597697,0.5442528845742345 -add_gaussian_noise,,71.04062628932297,,35.011875729076564,,0.07071956153959036 -add_noise_SNR,,73.8854307718575,,37.64778060372919,,12.803896361030638 -change_ref,,9.273256519809365,,19.219508921727538,,4.698239802382886 -channel_dropout,7.088237315416336,3.348068093881011,17.496242636814713,0.5531537905335426,4.6412657937034965,0.06793706305325031 -crop_and_resize,107.00998647324741,168.65701461862773,383.80653088726103,39.49100451823324,101.63478635903448,1.962239345535636 -filter_bandpass,,45.50764964334667,,152.94264073763043,,2.9582033222541213 -filter_bandstop,,44.845745482482016,,145.33897778298706,,1.608458423987031 -filter_highpass,,37.997497748583555,,138.71684326790273,,1.1115937577560544 -filter_lowpass,,46.904438538476825,,129.53218055609614,,1.201936817727983 -flip_horizontal,,0.0018146727234125137,,0.023218637332320213,,0.009426970034837723 -flip_vertical,,3.3653689613565803,,0.025393596850335598,,0.008693372830748558 -masking,8.351729959249496,5.401182125322521,10.081438167020679,1.9604026200249791,4.046475914306939,0.05191785003989935 -moving_avg,,27.874576395377517,,63.69746324699372,,0.15824944339692593 -permutation_signal,12.659091288223863,26.51213189586997,26.08147291932255,1.1414739051833749,5.5677090007811785,0.07013320829719305 -permute_channels_network,43.50118824560195,8.773427788168192,166.10135567001998,2.974851111881435,43.66938316076994,0.6647974560037255 -permute_channels,9.332423039712012,8.783952502533793,34.4811232117936,3.2817837204784155,10.940687504597008,0.15529456175863743 -random_FT_phase,38.54426374472678,47.19803645182401,90.38117585889995,8.599700069054961,12.139617369510233,0.2582742003723979 -random_slope_scale,44.880009615793824,20.625416637398303,39.141109311021864,1.667090612463653,0.19796890392899513,0.09318123385310173 -scaling,5.069442653097212,5.061666578985751,0.4061941262334585,0.41580748558044434,0.02233150787651539,0.05405097547918558 -shift_frequency,47.3368937689811,69.52789301145822,97.96175611764193,12.574184968136251,6.523466225713491,0.2909739715978503 -shift_horizontal,7.665743426419795,7.3564243000000715,10.154253432527184,0.6827773712575436,2.694342313334346,0.035021417774260044 -shift_vertical,,6.996130005456507,,0.057990497909486294,,0.07085594069212675 -warp_signal,247.15335492789745,410.7101323986426,947.21491935011,86.17811584100127,221.64730475377291,4.44210946932435 +add_band_noise,,1.419,,0.133,,0.046 +add_eeg_artifact,1.595,1.724,1.76,0.396,1.119,0.058 +add_gaussian_noise,,7.395,,1.68,,0.007 +add_noise_SNR,,7.695,,1.938,,1.344 +change_ref,,0.767,,0.729,,0.498 +channel_dropout,0.779,0.363,0.455,0.028,0.502,0.007 +crop_and_resize,11.479,17.503,15.851,3.434,10.389,0.2 +filter_bandpass,,4.716,,5.786,,1.035 +filter_bandstop,,4.793,,5.93,,0.166 +filter_highpass,,3.431,,5.961,,0.107 +filter_lowpass,,4.365,,6.036,,0.115 +flip_horizontal,,0.0,,0.022,,0.001 +flip_vertical,,0.358,,0.023,,0.001 +masking,0.815,0.548,0.495,0.039,0.432,0.006 +moving_avg,,3.066,,2.661,,0.009 +permutation_signal,1.055,1.305,0.951,0.242,0.599,0.008 +permute_channels_network,4.667,0.954,6.435,0.157,4.698,0.072 +permute_channels,0.905,0.877,1.237,0.075,1.148,0.016 +random_FT_phase,4.014,4.963,2.819,0.485,1.272,0.026 +random_slope_scale,3.849,2.251,1.783,0.2,0.012,0.009 +scaling,0.562,0.56,0.046,0.046,0.002,0.002 +shift_frequency,5.155,7.262,3.173,0.757,0.692,0.029 +shift_horizontal,0.797,0.787,0.389,0.056,0.279,0.004 +shift_vertical,,0.712,,0.061,,0.002 +warp_signal,25.965,43.759,35.951,7.205,23.263,0.467 diff --git a/selfeeg/augmentation/__init__.py b/selfeeg/augmentation/__init__.py index e7d09b7..f232afd 100644 --- a/selfeeg/augmentation/__init__.py +++ b/selfeeg/augmentation/__init__.py @@ -2,7 +2,7 @@ This is the import of the data augmentation module """ -from .compose import DynamicSingleAug, RandomAug, SequentialAug, StaticSingleAug +from .compose import CircularAug, DynamicSingleAug, RandomAug, SequentialAug, StaticSingleAug from .functional import ( add_band_noise, add_eeg_artifact, diff --git a/selfeeg/augmentation/compose.py b/selfeeg/augmentation/compose.py index 923f951..772ff9d 100644 --- a/selfeeg/augmentation/compose.py +++ b/selfeeg/augmentation/compose.py @@ -8,7 +8,13 @@ import numpy as np from numpy.typing import ArrayLike -__all__ = ["DynamicSingleAug", "RandomAug", "SequentialAug", "StaticSingleAug"] +__all__ = [ + "CircularAug", + "DynamicSingleAug", + "RandomAug", + "SequentialAug", + "StaticSingleAug", +] class StaticSingleAug: @@ -19,7 +25,7 @@ class StaticSingleAug: where the optional arguments are previously set and given during initialization. No random choice of the arguments is performed. The class accepts multiple set of optional arguments. In this case they are called individually at each class - call, in a sequential and cyclic manner. This means that the first call uses the first set of + call, in a circular manner. This means that the first call uses the first set of arguments, the second will use the second set of arguments, and so on. When the last set of arguments is used, the class will restart from the first set of arguments. @@ -369,7 +375,7 @@ class SequentialAug: Note ---- - If you provide an augmentation implemented outside of this this library be + If you provide an augmentation implemented outside of this this library, be sure that the function will return a single output with the element to pass to the next augmentation function of the list. @@ -444,6 +450,82 @@ def _search_for_random_aug_with_index(self): item._search_for_random_aug_with_index() +class CircularAug: + """ + Single Augmenter called sequentially from a list, following a circular order. + + ``CircularAug`` calls an Augmenter from a given sequence following the order. + Augmenters are called circularly. This means that the first call uses the first + Augmenter from the input list, the second call will use the second, and so on. + When the last Augmenter is called, the class will restart from the + first one. + + To perform an augmentation, simply call the instantiated class + (see provided example or check the introductory notebook) + + Parameters + ---------- + *augmentations: "callable objects" + The list of augmentations to apply at each call. + It can be any callable object, but the first argument to pass must be + the element to augment. It is suggested to give a sequence of + ``StaticSingleAug`` or ``DynamicSingleAug`` instantiations. + + Note + ---- + The function will automatically handle RandomAug instances with return_index + set to True. In this case, an internal deepcopy with return_index set to false + will be automatically created. + + Methods + ------- + perform_augmentation(X: ArrayLike) + Apply the augmentations with the given arguments and specified order. + __call__() will call this method. + + + Example + ------- + >>> import selfeeg.augmentation as aug + >>> import torch + >>> BatchEEG = torch.zeros(16,32,1024) + torch.sin(torch.linspace(0, 8*np.pi,1024)) + >>> Aug_eye = aug.StaticSingleAug( + ... aug.add_eeg_artifact,{'Fs': 64, 'artifact': 'eye', 'amplitude': 0.5}) + >>> Circular = aug.CircularAug(Aug_eye, aug.identity) + >>> EEGeye = Circular(BatchEEG) + >>> EEGid = Circular(BatchEEG) + + """ + + def __init__(self, *augmentations): + self.augs = [item for item in augmentations] + self._augcnt = 0 + self._augnumber = len(self.augs) + self._search_for_random_aug_with_index() + + def perform_augmentation(self, X: ArrayLike) -> ArrayLike: + Xaugs = self.augs[self._augcnt](X) + self._update_counter() + return Xaugs + + def __call__(self, X): + return self.perform_augmentation(X) + + def _update_counter(self): + self._augcnt += 1 + if self._augcnt == self._augnumber: + self._augcnt = 0 + + def _search_for_random_aug_with_index(self): + for idx, item in enumerate(self.augs): + if isinstance(item, RandomAug): + if self.augs[idx].return_index == True: + self.augs[idx] = copy.deepcopy(item) + self.augs[idx].return_index = False + elif isinstance(item, SequentialAug): + item._search_for_random_aug_with_index() + + class RandomAug: """ Random augmentation chosen from a given set. diff --git a/selfeeg/augmentation/functional.py b/selfeeg/augmentation/functional.py index 85c1e9d..d0a1dc3 100755 --- a/selfeeg/augmentation/functional.py +++ b/selfeeg/augmentation/functional.py @@ -603,24 +603,34 @@ def add_noise_SNR( """ + N = len(x.shape) + if N == 1: + axis = 0 + new_size = (x.shape[-1],) + elif N == 2: + axis = 1 + new_size = (1, x.shape[-1]) + else: + axis = -1 + new_size = (*x.shape[0:-2], 1, x.shape[-1]) + # get signal power. Not exactly true since we have an already noised signal - x_pow = x**2 + factor = 10 ** (-target_snr / 20) if isinstance(x, np.ndarray): - x_pow_avg = np.mean(x_pow) - x_db_avg = 10 * np.log10(x_pow_avg) - noise_db_avg = x_db_avg - target_snr - noise_pow_avg = 10 ** (noise_db_avg / 10) - noise = np.random.normal(0, noise_pow_avg**0.5, size=x.shape) - x_noise = x + noise - + xpow = np.power(x, 2) + xpow = np.mean(xpow, keepdims=True) + xpow = np.sqrt(xpow) + noise = np.random.randn(*new_size) else: - x_pow_avg = torch.mean(x_pow) - x_db_avg = 10 * torch.log10(x_pow_avg) - noise_db_avg = x_db_avg - target_snr - noise_pow_avg = 10 ** (noise_db_avg / 10) - noise = (noise_pow_avg**0.5) * (torch.randn(*x.shape).to(device=x.device)) - x_noise = x + noise + xpow = torch.pow(x, 2) + xpow = torch.mean(xpow, axis, keepdim=True) + xpow = torch.sqrt(xpow) + noise = torch.randn(size=new_size).to(device=x.device) + + noise = factor * noise + noise = xpow * noise + x_noise = x + noise if get_noise: return x_noise, noise @@ -873,7 +883,7 @@ def random_slope_scale( x: ArrayLike, min_scale: float = 0.9, max_scale: float = 1.2, - batch_equal: bool = False, + batch_equal: bool = True, keep_memory: bool = False, ) -> ArrayLike: """ @@ -907,7 +917,7 @@ def random_slope_scale( Whether to apply the same rescale to all EEGs in the batch or not. This apply only if x has more than 2 dimensions, i.e. more than 1 EEG. - Default: False + Default: True keep_memory: bool, optional Whether to keep memory of the previous changes in slope and accumulate them during the transformation or not. Basically, instead of using: @@ -1225,16 +1235,18 @@ def get_filter_coeff( Parameters ---------- - Wp: float - Bandpass in Hz. - Ws: float - Stopband in Hz. + Wp: float or ArrayLike + Passband edges in Hz. It can be a float for lowpass and highpass filters, + or a length 2 scalar vector for bandpass and stopband filters. + Ws: float or ArrayLike + Stopband edges in Hz. It can be a float for lowpass and highpass filters, + or a length 2 scalar vector for bandpass and stopband filters. rp: float, optional - Ripple at bandpass in decibel. + Ripple at bandpass in dB. Default = -20*log10(0.95) rs: float, optional - Ripple at stopband in decibel. + Ripple at stopband in dB. Default = -20*log10(0.15) btype: str, optional @@ -1251,7 +1263,7 @@ def get_filter_coeff( The order of the filter. Default = None - Wn: array_like, optional + Wn: ArrayLike, optional The critical frequency or frequencies. Default = None @@ -1390,19 +1402,19 @@ def filter_lowpass( Fs: float The sampling frequency in Hz. Wp: float, optional - Bandpass in Hz. + Passband edge in Hz. Default = 50 Ws: float, optional - Stopband in Hz. + Stopband edge in Hz. Default = 70 rp: float, optional - Ripple at bandpass in decibel. + Ripple at bandpass in dB. Default = -20*log10(0.95) rs: float, optional - Ripple at stopband in decibel. + Ripple at stopband in dB. Default = -20*log10(0.15) filter_type: str, optional @@ -1442,10 +1454,10 @@ def filter_lowpass( Note ---- - Lots of parameters are the ones used to call scipy's matlab style filters, - aside to **Wp** and **Ws** which you must give directly in Hz. - The normalization to [0,1] with respect to the half-cycles / sample - (i.e. Nyquist frequency) is done directly inside the ``get_filter_coeff`` + Many parameters are those used in scipy's implementation of Matlab-style + filters, except for **Wp** and **Ws**, which must be specified directly in Hz. + The normalization to [0,1] with respect to half-cycles/sample + (i.e., Nyquist frequency) is done directly inside the ``get_filter_coeff`` function. Note @@ -1466,13 +1478,13 @@ def filter_lowpass( >>> f, per1 = periodogram(x[0,0], 128) >>> xaug = aug.filter_lowpass(x, 128, 20, 30) >>> f, per2 = periodogram(xaug[0,0], 128) - >>> print(np.isclose(np.max(per2[f>30]), 0, rtol=1e-04, atol=1e-04)) #should return True + >>> print(np.isclose(np.max(per2[f>30]), 0, rtol=1e-04, atol=1e-04)) """ if filter_type not in ["butter", "ellip", "cheby1", "cheby2"]: raise ValueError( - "filter type not supported. Choose between butter, " "elliptic, cheby1, cheby2" + "filter type not supported. Choose between butter, elliptic, cheby1, cheby2" ) if (a is None) or (b is None): @@ -1535,19 +1547,19 @@ def filter_highpass( The last two dimensions must refer to the EEG recording (Channels x Samples). Wp: float, optional - Bandpass in Hz. + Passband edge in Hz. Default = 30 Ws: float, optional - Stopband in Hz. + Stopband edge in Hz. Default = 13 rp: float, optional - Ripple at bandpass in decibel. + Ripple at bandpass in dB. Default = -20*log10(0.95) rs: float, optional - Ripple at stopband in decibel. + Ripple at stopband in dB. Default = -20*log10(0.15) filter_type: str, optional @@ -1587,11 +1599,11 @@ def filter_highpass( Note ---- - Lots of parameters are the ones used to call scipy's matlab style filters, - aside to **Wp** and **Ws** which you must give directly in Hz. - The normalization to [0,1] with respect to the half-cycles / sample - (i.e. Nyquist frequency) is done directly inside the ``get_filter_coeff`` - function. + Many parameters are those used in scipy's implementation of Matlab-style + filters, except for **Wp** and **Ws**, which must be specified directly in Hz. + The normalization to [0,1] with respect to half-cycles/sample + (i.e., Nyquist frequency) is done directly inside the + ``get_filter_coeff`` function. Note ---- @@ -1674,8 +1686,8 @@ def filter_bandpass( Therefore the arguments closer to a and b in the scheme are used to get the filter coefficient. - If ``eeg_band`` is given, (Wp,Ws,rp,rs) are bypassed and instantiated according - to the eeg band specified. The priority order remains, so if (Wn,order) or + If ``eeg_band`` is given, (Wp,Ws,rp,rs) are ignored and instantiated according + to the EEG band specified. The priority order remains. So, if (Wn, order) or (a,b) are given, the filter will be created according to such argument. Parameters @@ -1686,20 +1698,20 @@ def filter_bandpass( (Channels x Samples). Fs: float the sampling frequency in Hz. - Wp: float, optional - Bandpass in Hz. + Wp: ArrayLike, optional + Passband edges in Hz. It must be a length 2 scalar vector. Default = None - Ws: float, optional - Stopband in Hz. + Ws: ArrayLike, optional + Stopband edges in Hz. It must be a length 2 scalar vector. Default = None rp: float, optional - Ripple at bandpass in decibel. + Ripple at bandpass in dB. Default = -20*log10(0.95) rs: float, optional - Ripple at stopband in decibel. + Ripple at stopband in dB. Default = -20*log10(0.15) filter_type: str, optional @@ -1711,15 +1723,15 @@ def filter_bandpass( The order of the filter. Default = None - Wn: array_like, optional + Wn: ArrayLike, optional The critical frequency or frequencies. Default = None - a: array_like, optional + a: ArrayLike, optional The denominator coefficients of the filter Default = None - b: array_like, optional + b: ArrayLike, optional The numerator coefficients of the filer Default = None @@ -1745,11 +1757,11 @@ def filter_bandpass( Note ---- - Lots of parameters are the ones used to call scipy's matlab style filters, - aside to **Wp** and **Ws** which you must give directly in Hz. - The normalization to [0,1] with respect to the half-cycles / sample - (i.e. Nyquist frequency) is done directly inside the ``get_filter_coeff`` - function. + Many parameters are those used in scipy's implementation of Matlab-style + filters, except for **Wp** and **Ws**, which must be specified directly in Hz. + The normalization to [0,1] with respect to half-cycles/sample + (i.e., Nyquist frequency) is done directly inside the + ``get_filter_coeff`` function. Note ---- @@ -1847,19 +1859,19 @@ def filter_bandstop( Fs: float The sampling frequency in Hz. Wp: float, optional - Bandpass in Hz. + Passband edges in Hz. It must be a length 2 scalar vector. Default = 30 Ws: float, optional - Stopband in Hz. + Stopband edges in Hz. It must be a length 2 scalar vector. Default = 13 rp: float, optional - Ripple at bandpass in decibel. + Ripple at bandpass in dB. Default = -20*log10(0.95) rs: float, optional - Ripple at stopband in decibel. + Ripple at stopband in dB. Default = -20*log10(0.15) filter_type: str, optional @@ -1871,15 +1883,15 @@ def filter_bandstop( The order of the filter. Default = None - Wn: array_like, optional + Wn: ArrayLike, optional The critical frequency or frequencies. Default = None - a: array_like, optional + a: ArrayLike, optional The denominator coefficients of the filter. Default = None - b: array_like, optional + b: ArrayLike, optional The numerator coefficients of the filer. Default = None @@ -1905,10 +1917,10 @@ def filter_bandstop( Note ---- - Lots of parameters are the ones used to call scipy's matlab style filters, - aside to **Wp** and **Ws** which you must give directly in Hz. - The normalization to [0,1] with respect to the half-cycles / sample - (i.e. Nyquist frequency) is done directly inside the + Many parameters are those used in scipy's implementation of Matlab-style + filters, except for **Wp** and **Ws**, which must be specified directly in Hz. + The normalization to [0,1] with respect to half-cycles/sample + (i.e., Nyquist frequency) is done directly inside the ``get_filter_coeff`` function. Note @@ -2202,7 +2214,7 @@ def permute_channels( mode: str = "random", channel_map: list = None, chan_net: list[str] = "all", - batch_equal: bool = False, + batch_equal: bool = True, ) -> ArrayLike: """ permutes the ArrayLike object along the EEG channel dimension. @@ -2265,7 +2277,7 @@ def permute_channels( If True, permute_signal is called recursively in order to permute each EEG differently. - Default = False + Default = True Returns ------- @@ -2408,7 +2420,7 @@ def permute_channels( def permutation_signal( - x: ArrayLike, segments: int = 10, seg_to_per: int = -1, batch_equal: bool = False + x: ArrayLike, segments: int = 10, seg_to_per: int = -1, batch_equal: bool = True ) -> ArrayLike: """ permutes some portions of the ArrayLike object along its last dimension. @@ -2439,7 +2451,7 @@ def permutation_signal( If True, the function is called recursively in order to apply a different permutation to all EEGs. - Default = False + Default = True Returns ------- @@ -2519,7 +2531,7 @@ def warp_signal( segments: int = 10, stretch_strength: float = 2.0, squeeze_strength: float = 0.5, - batch_equal: bool = False, + batch_equal: bool = True, ) -> ArrayLike: """ stretches and squeezes portions of the ArrayLike object. @@ -2555,7 +2567,7 @@ def warp_signal( batch_equal: bool, optional Whether to apply the same warp to all records or not. - Default = False + Default = True Returns ------- @@ -2634,7 +2646,7 @@ def crop_and_resize( x: ArrayLike, segments: int = 10, N_cut: int = 1, - batch_equal: bool = False, + batch_equal: bool = True, ) -> ArrayLike: """ crops some segments of the ArrayLike object. @@ -2667,7 +2679,8 @@ def crop_and_resize( Default = 1 batch_equal: bool, optional Whether to apply the same crop to all EEG record or not. - Default = False + + Default = True Returns ------- @@ -2866,7 +2879,7 @@ def change_ref( def masking( - x: ArrayLike, mask_number: int = 1, masked_ratio: float = 0.1, batch_equal: bool = False + x: ArrayLike, mask_number: int = 1, masked_ratio: float = 0.1, batch_equal: bool = True ) -> ArrayLike: """ puts to zero random portions of the ArrayLike object. @@ -2899,7 +2912,7 @@ def masking( Whether to apply the same masking to all elements in the batch or not. It does apply only if x has more than 2 dimensions. - Default = False + Default = True Returns ------- @@ -3043,7 +3056,7 @@ def add_eeg_artifact( line_at_60Hz: bool = True, lost_time: float = None, drift_slope: float = None, - batch_equal: bool = False, + batch_equal: bool = True, ) -> ArrayLike: """ add common EEG artifacts to the ArrayLike object. @@ -3112,7 +3125,7 @@ def add_eeg_artifact( Whether to apply the same masking to all elements in the batch or not. Does apply only if x has more than 2 dimensions - Default = False + Default = True Returns ------- diff --git a/selfeeg/dataloading/load.py b/selfeeg/dataloading/load.py index 52485c7..73bde9f 100644 --- a/selfeeg/dataloading/load.py +++ b/selfeeg/dataloading/load.py @@ -47,126 +47,116 @@ def get_eeg_partition_number( verbose: bool = False, ) -> pd.DataFrame: """ - finds the number of unique partitions from each EEG signal. - - The function is applied to each EEG stored inside a given input directory. - Some default parameters are designed to work with - the 'BIDSAlign' library. For more info, see [bids]_ . - To further check how to use this function see the introductory - notebook provided in the documentation. + Calculates the number of unique partitions in each EEG signal. + This function processes each EEG file stored in a specified input directory. + It is designed with default parameters that are compatible with the + 'BIDSAlign' library. For additional information, see [1]_. + For a comprehensive guide on how to use this function, refer to the + introductory notebook included in the documentation. Parameters ---------- EEGpath : str - Directory with all EEG files. If the last element of the string is not - "/", the character will be added automatically. + The directory containing all EEG files. + If the string does not end with a "/", + the character will be added automatically. freq : int or float, optional - EEG sampling rate. Must be the same for all EEG files. + The EEG sampling rate, which must be consistent across all EEG files. - Default = 250 + Default = 250. window : int or float, optional - The window length given in seconds. + The length of the time window, specified in seconds. - Default = 2 + Default = 2. overlap : float, optional - Same EEG contiguous partitions overlap in percentage. - Must be in the interval [0,1). + The percentage overlap between contiguous EEG partitions. + This value must be in the interval [0, 1). - Default = 0.1 + Default = 0.1. includePartial : bool, optional - Whether to also count the final EEG portions which could potentially cover - at least half of the time windows. In this case the overlap between the - last included partition and the previous one will increase in order - to fill the incomplete partition with real recorded values. - Note that this apply only if at - least half of such partition will include new values. + Indicates whether to count the final portions of the EEG that may cover + at least half of the time windows. If this option is enabled, the overlap + between the last included partition and the previous one will be adjusted + to incorporate real recorded values, provided at least half of the + partition includes new data. - Default = True + Default = True. file_format : str or list[str], optional - A string used to detect a set of specific EEG files inside the given - EEGpath. It is directly put after ``EEGpath`` during call of the - glob.glob() method. Therefore, it can contain shell-style wildcards - (see glob.glob() help for more info). This parameter might be helpful - if you have other files other than the EEGs in your directory. - Alternatively, you can provide a list of strings to cast multiple - glob.glob() searches. This might be useful if you want to combine - multiple identification criteria (e.g. specific file extensions, - specific file names, etc.) - - Default = '*' - load_function : 'function', optional - A custom EEG file loading function. It will be used instead of the default: - - ``loadmat(ii, simplify_cells=True)['DATA_STRUCT']['data']`` - - which is the default output format for files preprocessed with the - BIDSalign library. The function must take only one required argument, - which is the full path to the EEG file (e.g. the function will be called - in this way: load_function(fullpath, optional_arguments) ). - - Default = None - optional_load_fun_args: list or dict, optional - Optional arguments to give to the custom loading function. - Can be a list or a dict. - - Default = None - transform_function : 'function', optional - A custom transformation to be applied after the EEG is loaded. - Might be useful if there are portions of the signal to cut - (usually the initial or the final). The function must take only one - required argument, which is the loaded EEG file to transform - (e.g. the function will be called in this way: - transform_function(EEG, optional_arguments) ). - - Default = None - optional_transform_fun_args: list or dict, optional - Optional arguments to give to the EEG transformation function. - Can be a list or a dict. - - Default = None + A string or list of strings used to filter specific EEG files in the + provided EEGpath. This is used directly in the `glob.glob()` method + and can include shell-style wildcards + (refer to the glob.glob() documentation for details). + This option is useful if there are other file types in the directory. + + Default = '*'. + load_function : function, optional + A custom function for loading EEG files, which will override the default: + + ``loadmat(ii, simplify_cells=True)['DATA_STRUCT']['data']``. + + The function must accept one required argument: + the full path to the EEG file + (e.g., it will be called as: load_function(fullpath, optional_arguments)). + + Default = None. + optional_load_fun_args : list or dict, optional + Additional arguments to pass to the custom loading function. + This can be specified as a list or a dictionary. + + Default = None. + transform_function : function, optional + A custom transformation function to apply after loading the EEG data. + This may be useful for trimming portions of the signal + (usually the beginning or end). The function must accept one required + argument: the loaded EEG file (e.g., it + will be called as: transform_function(EEG, optional_arguments)). + + Default = None. + optional_transform_fun_args : list or dict, optional + Additional arguments to pass to the EEG transformation function. + This can be specified as a list or a dictionary. + Default = None. keep_zero_sample : bool, optional - Whether to preserve Dataframe's rows with calculated - zero number of samples or not. + Specifies whether to retain DataFrame rows with a calculated zero + number of samples. - Default = True + Default = True. save : bool, optional - Whether to save the resulting DataFrame as a .csv file. - - Default = False - save_path: str, optional - A custom path to be used instead of the current working directory. - It is the string given to the ``pandas.DataFrame.to_csv()`` method. - Note that if save is True and no save_path is given, the file will - be saved in the current working directory as `EEGPartitionNumber.csv` or - `EEGPartitionNumber_k.csv` with k integer used to avoid overwriting a file. - - Default = None - verbose: bool, optional - whether to print some information during function excecution. - Useful to keep track of the calculation process. Might be useful for large - datasets. + Indicates whether to save the resulting DataFrame as a .csv file. + Default = False. + save_path : str, optional + A custom path for saving the .csv file instead of using the current + working directory. This string is passed to the `pandas.DataFrame.to_csv()` + method. If save is True and no save_path is provided, the file will + be saved as `EEGPartitionNumber_k.csv`, where k is an integer to + prevent overwriting. + + Default = None. + verbose : bool, optional + Controls whether to print information during function execution, which can + be helpful for tracking progress, especially with large datasets. + + Default = False. Returns ------- lenEEG : DataFrame - Three columns Pandas DataFrame. - The first column has the full path to the EEG file, - the second the file name, the third its number of partitions. - - Note - ---- - freq*window must give an integer with the number of samples. - - Note - ---- - This function can handle array with more than 2 dimensions. In this case a - warning will be generated and calculation will be performed in this way. - First the length of the last dimension will be used to calculate the number - of partitions, then the number will be multiplied by the product of the shape - of all dimensions from the first to the second to last (the last two - dimensions are supposed to be the Channel and Sample dimension of a single - EEG file). + A three-column Pandas DataFrame containing: + - The full path to the EEG files in the first column, + - The file names in the second column, + - The number of partitions in the third column. + + Notes + ----- + - The product of `freq` and `window` must yield an integer representing + the number of samples. + - This function can handle arrays with more than two dimensions. + In such cases, a warning is issued, and the calculation proceeds as follows: + the length of the last dimension is used to determine the number of + partitions, which is then multiplied by the product of the shapes of + all preceding dimensions (the last two dimensions should correspond to + channel and sample dimensions of a single EEG file). Example ------- @@ -186,7 +176,10 @@ def get_eeg_partition_number( References ---------- - .. [bids] https://github.com/MedMaxLab/BIDSAlign + .. [1] Zanola et al "BIDSAlign: a library for automatic merging and + preprocessing of multiple EEG repositories." + doi: https://doi.org/10.1088/1741-2552/ad6a8c. + GitHub repository: https://github.com/MedMaxLab/BIDSAlign """ # Check Inputs diff --git a/selfeeg/models/encoders.py b/selfeeg/models/encoders.py index afe355c..a64b67e 100644 --- a/selfeeg/models/encoders.py +++ b/selfeeg/models/encoders.py @@ -854,9 +854,13 @@ class ResNet1DEncoder(nn.Module): An nn.Module defining the resnet block. Layers: list of 4 int, optional A list of integers indicating the number of times - the resnet block is repeated . + the resnet block is repeated at each stage. + It must be a list of length 4 with positive integers. + Shorter lists are padded to 1 on the right. + Only the first four elements of longer lists are considered. + Zeros are changed to 1. - Default = [2,2,2,2] + Default = [2, 2, 2, 2] inplane: int, optional The number of output filters. kernLength: int, optional @@ -925,8 +929,18 @@ def __init__( self.inplane = inplane self.kernLength = kernLength self.connection = addConnection - - # PRE-RESIDUAL + # Checks on the Layer list + if len(Layers) < 4: + Layers = Layers + [1 for _ in range(4 - len(Layers))] + else: + Layers = Layers[:4] + if any(Layers): + Layers = [i if i > 0 else 1 for i in Layers] + for i in Layers: + if not isinstance(i, int): + raise ValueError("Layers must be a length 4 list of positive integers") + + # PRE-RESIDUAL if preBlock is None: self.preBlocks = nn.Sequential( nn.Conv2d( @@ -959,17 +973,20 @@ def __init__( # POST-RESIDUAL if postBlock is None: - self.postBlocks = nn.Sequential( - nn.Conv2d( - self.inplane, - inplane, - kernel_size=(1, kernLength), - stride=(1, 1), - padding=(0, 0), - bias=False, - ), - nn.AdaptiveAvgPool2d((Chans, 1)), - ) + if self.connection: + self.postBlocks = nn.Sequential( + nn.Conv2d( + self.inplane, + inplane, + kernel_size=(1, kernLength), + stride=(1, 1), + padding=(0, 0), + bias=False, + ), + nn.AdaptiveAvgPool2d((Chans, 1)), + ) + else: + self.postBlocks = nn.AdaptiveAvgPool2d((Chans, 1)) else: self.postBlocks = postBlock diff --git a/selfeeg/models/zoo.py b/selfeeg/models/zoo.py index 7848ed7..63f07aa 100755 --- a/selfeeg/models/zoo.py +++ b/selfeeg/models/zoo.py @@ -809,9 +809,13 @@ class ResNet1D(nn.Module): Default: selfeeg.models.BasicBlock1 Layers: list of 4 int, optional A list of integers indicating the number of times the resnet block - is repeated. + is repeated at each stage. + It must be a list of length 4 with positive integers. + Shorter lists are padded to 1 on the right. + Only the first four elements of longer lists are considered. + Zeros are changed to 1. - Default = [2,2,2,2] + Default = [2, 2, 2, 2] inplane: int, optional The number of output filters. @@ -889,7 +893,7 @@ def __init__( Chans, Samples, block: nn.Module = BasicBlock1, - Layers: "list of 4 int" = [0, 0, 0, 0], + Layers: "list of 4 int" = [2, 2, 2, 2], inplane: int = 16, kernLength: int = 7, addConnection: bool = False, diff --git a/setup.py b/setup.py index b871b90..ce2bc67 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,6 @@ "scipy>=1.10.1", "torch>=2.0.0", "torchaudio>=2.0.2", - "torchvision>=0.15.2", "tqdm", ] diff --git a/test/EEGself/augmentation/compose_test.py b/test/EEGself/augmentation/compose_test.py index 61b29d4..e35b8d6 100644 --- a/test/EEGself/augmentation/compose_test.py +++ b/test/EEGself/augmentation/compose_test.py @@ -81,6 +81,21 @@ def test_SequentialAug(self): self.assertTrue(torch.equal(BatchEEGaug1, BatchEEGaug2 * (-1))) print(" Sequential augmentation OK") + def test_CircularAug(self): + print("Testing Circular augmentation...", end="", flush=True) + Circular = aug.CircularAug(aug.flip_vertical, aug.identity) + BatchEEGaug1 = Circular(self.BatchEEG) + self.assertTrue(torch.equal(BatchEEGaug1, self.BatchEEG * (-1))) + BatchEEGaug1 = Circular(self.BatchEEG) + self.assertTrue(torch.equal(BatchEEGaug1, self.BatchEEG)) + + # repeat to check Circular calls + BatchEEGaug1 = Circular(self.BatchEEG) + self.assertTrue(torch.equal(BatchEEGaug1, self.BatchEEG * (-1))) + BatchEEGaug1 = Circular(self.BatchEEG) + self.assertTrue(torch.equal(BatchEEGaug1, self.BatchEEG)) + print(" Circular augmentation OK") + def test_RandomAug(self): print("Testing Random augmentation...", end="", flush=True) Aug_scal = aug.StaticSingleAug(aug.scaling, {"value": 2, "batch_equal": True}) @@ -97,7 +112,7 @@ def test_RandomAug(self): self.assertTrue(abs(counter[1] - 0.3) < 1e-2) print(" Random augmentation OK") - def test_UltimateAugmentationComposition(self): + def test_AugmentationComposition(self): print( "Testing final augmentation composition based on all previous classes...", end="", diff --git a/test/EEGself/augmentation/functional_test.py b/test/EEGself/augmentation/functional_test.py index 680c252..92daed1 100644 --- a/test/EEGself/augmentation/functional_test.py +++ b/test/EEGself/augmentation/functional_test.py @@ -322,10 +322,10 @@ def test_add_noise_SNR(self): for i in aug_args: xaug = aug.add_noise_SNR(**i) - x = torch.zeros(16, 32, 1024) + torch.sin(torch.linspace(0, 8 * np.pi, 1024)) + x = torch.zeros(16, 32, 512) + torch.sin(torch.linspace(0, 8 * np.pi, 512)) xaug, noise = aug.add_noise_SNR(x, 10, get_noise=True) SNR = 10 * torch.log10(((x**2).sum().mean()) / ((noise**2).sum().mean())) - self.assertTrue(math.isclose(SNR, 10, rel_tol=1e-2)) + self.assertTrue(math.isclose(SNR, 10, rel_tol=1e-1)) print(" noise SNR OK: tested", N + len(aug_args), "combinations of input arguments") def test_add_band_noise(self): From dd0ed67410b1ffe7896c5d135e89e53f2966d73f Mon Sep 17 00:00:00 2001 From: fedepup Date: Mon, 17 Feb 2025 15:45:09 +0100 Subject: [PATCH 2/2] solve resnet linear layer bug 2 --- selfeeg/models/encoders.py | 1 - selfeeg/models/zoo.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/selfeeg/models/encoders.py b/selfeeg/models/encoders.py index a64b67e..0924ffb 100644 --- a/selfeeg/models/encoders.py +++ b/selfeeg/models/encoders.py @@ -1049,7 +1049,6 @@ def forward(self, x): embeddings = torch.cat((out1, out2), dim=-1) else: embeddings = out1 - return embeddings diff --git a/selfeeg/models/zoo.py b/selfeeg/models/zoo.py index 63f07aa..3f79d2b 100755 --- a/selfeeg/models/zoo.py +++ b/selfeeg/models/zoo.py @@ -926,7 +926,9 @@ def __init__( 1 if nb_classes <= 2 else nb_classes, ) else: - self.Dense = nn.Linear(Chans * inplane, 1 if nb_classes <= 2 else nb_classes) + self.Dense = nn.Linear( + Chans * self.encoder.inplane, 1 if nb_classes <= 2 else nb_classes + ) else: self.Dense = classifier