Skip to content

Commit

Permalink
Merge branch 'develop' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
sylvchev authored Jan 7, 2024
2 parents 66cb632 + 77d070a commit 89d5cb3
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Enhancements

- Adding cache option to the evaluation (:gh:`517` by `Bruno Aristimunha`_)
- Option to interpolate channel in paradigms' `match_all` method (:gh:`480` by `Gregoire Cattan`_)
- Adding leave k-Subjects out evaluations (:gh:`470` by `Bruno Aristimunha`_)

Bugs
~~~~
Expand Down
2 changes: 2 additions & 0 deletions moabb/evaluations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
return_epochs=False,
return_raws=False,
mne_labels=False,
n_splits=None,
save_model=False,
cache_config=None,
):
Expand All @@ -77,6 +78,7 @@ def __init__(
self.return_epochs = return_epochs
self.return_raws = return_raws
self.mne_labels = mne_labels
self.n_splits = n_splits
self.save_model = save_model
self.cache_config = cache_config
# check paradigm
Expand Down
11 changes: 10 additions & 1 deletion moabb/evaluations/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sklearn.base import clone
from sklearn.metrics import get_scorer
from sklearn.model_selection import (
GroupKFold,
LeaveOneGroupOut,
StratifiedKFold,
StratifiedShuffleSplit,
Expand Down Expand Up @@ -632,6 +633,9 @@ class CrossSubjectEvaluation(BaseEvaluation):
use MNE raw to train pipelines.
mne_labels: bool, default=False
if returning MNE epoch, use original dataset label if True
n_splits: int, default=None
Number of splits for cross-validation. If None, the number of splits
is equal to the number of subjects.
"""

# flake8: noqa: C901
Expand Down Expand Up @@ -675,7 +679,12 @@ def evaluate(
scorer = get_scorer(self.paradigm.scoring)

# perform leave one subject out CV
cv = LeaveOneGroupOut()
if self.n_splits is None:
cv = LeaveOneGroupOut()
else:
cv = GroupKFold(n_splits=self.n_splits)
n_subjects = self.n_splits

inner_cv = StratifiedKFold(3, shuffle=True, random_state=self.random_state)

# Implement Grid Search
Expand Down

0 comments on commit 89d5cb3

Please sign in to comment.