From 77d070a26a23318e8ff1fd89bfc735bf109668b5 Mon Sep 17 00:00:00 2001 From: Bru Date: Fri, 5 Jan 2024 07:23:02 -0300 Subject: [PATCH] Creating K-Subjects out evaluations (#470) * Creating K-Subjects out * Fixing the tqdm * [pre-commit.ci] auto fixes from pre-commit.com hooks * Removing code not use * Updating whats_new.rst * Fixing whats new * fixing again --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/source/whats_new.rst | 1 + moabb/evaluations/base.py | 2 ++ moabb/evaluations/evaluations.py | 11 ++++++++++- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index b361f18c5..6a322032b 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -20,6 +20,7 @@ Enhancements - Adding cache option to the evaluation (:gh:`517` by `Bruno Aristimunha`_) - Option to interpolate channel in paradigms' `match_all` method (:gh:`480` by `Gregoire Cattan`_) +- Adding leave k-Subjects out evaluations (:gh:`470` by `Bruno Aristimunha`_) Bugs ~~~~ diff --git a/moabb/evaluations/base.py b/moabb/evaluations/base.py index b9aafc5f4..b99eefaaa 100644 --- a/moabb/evaluations/base.py +++ b/moabb/evaluations/base.py @@ -66,6 +66,7 @@ def __init__( return_epochs=False, return_raws=False, mne_labels=False, + n_splits=None, save_model=False, cache_config=None, ): @@ -77,6 +78,7 @@ def __init__( self.return_epochs = return_epochs self.return_raws = return_raws self.mne_labels = mne_labels + self.n_splits = n_splits self.save_model = save_model self.cache_config = cache_config # check paradigm diff --git a/moabb/evaluations/evaluations.py b/moabb/evaluations/evaluations.py index b3fce6b64..fc11fb7b0 100644 --- a/moabb/evaluations/evaluations.py +++ b/moabb/evaluations/evaluations.py @@ -8,6 +8,7 @@ from sklearn.base import clone from sklearn.metrics import get_scorer from sklearn.model_selection import ( + GroupKFold, LeaveOneGroupOut, StratifiedKFold, StratifiedShuffleSplit, @@ -632,6 +633,9 @@ class CrossSubjectEvaluation(BaseEvaluation): use MNE raw to train pipelines. mne_labels: bool, default=False if returning MNE epoch, use original dataset label if True + n_splits: int, default=None + Number of splits for cross-validation. If None, the number of splits + is equal to the number of subjects. """ # flake8: noqa: C901 @@ -675,7 +679,12 @@ def evaluate( scorer = get_scorer(self.paradigm.scoring) # perform leave one subject out CV - cv = LeaveOneGroupOut() + if self.n_splits is None: + cv = LeaveOneGroupOut() + else: + cv = GroupKFold(n_splits=self.n_splits) + n_subjects = self.n_splits + inner_cv = StratifiedKFold(3, shuffle=True, random_state=self.random_state) # Implement Grid Search