Skip to content

Commit

Permalink
[FIX] shuffle everything
Browse files Browse the repository at this point in the history
  • Loading branch information
bruAristimunha committed Oct 23, 2024
1 parent 590edb1 commit b151d61
Showing 1 changed file with 42 additions and 32 deletions.
74 changes: 42 additions & 32 deletions moabb/evaluations/splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ class WithinSessionSplitter(BaseCrossValidator):
random_state: int, RandomState instance or None, default=None
Important when `shuffle` is True. Controls the randomness of splits.
Pass an int for reproducible output across multiple function calls.
shuffle : bool, default=True
shuffle_session : bool, default=True
Whether to shuffle each class's samples before splitting into batches.
Note that the samples within each split will not be shuffled.
shuffle_subjects : bool, default=False
Whether to shuffle the order in which subjects are visited when
generating the splits. This parameter allows sampling different
iteration orders of the splitter.
Examples
-----------
Expand Down Expand Up @@ -57,43 +60,50 @@ class WithinSessionSplitter(BaseCrossValidator):
Test: index=[4 5], group=[1 1], sessions=['T' 'T']
"""

def __init__(
    self,
    n_folds: int = 5,
    random_state: int = 42,
    shuffle_subjects: bool = False,
    shuffle_session: bool = True,
):
    """Store the splitter configuration.

    Parameters
    ----------
    n_folds : int, default=5
        Number of folds; stored as ``self.n_folds``.
    random_state : int, RandomState instance or None, default=42
        Seed or generator passed through ``check_random_state`` and
        stored as ``self.random_state``.
    shuffle_subjects : bool, default=False
        Stored as ``self.shuffle_subjects``.
    shuffle_session : bool, default=True
        Stored as ``self.shuffle_session``.
    """
    self.n_folds = n_folds
    self.shuffle_subjects = shuffle_subjects
    self.shuffle_session = shuffle_session
    # The former ``shuffle`` argument was renamed to ``shuffle_session``;
    # assigning from the old name would raise NameError. Keep the old
    # attribute as an alias so code still reading ``self.shuffle`` works.
    self.shuffle = shuffle_session
    # Setting random state
    self.random_state = check_random_state(random_state)

def get_n_splits(self, metadata):
    """Return the total number of splits for ``metadata``.

    The result is ``self.n_folds`` multiplied by the number of distinct
    (subject, session) groups found in ``metadata``.
    """
    per_session_groups = metadata.groupby(["subject", "session"])
    return self.n_folds * per_session_groups.ngroups

def split(self, y, metadata, **kwargs):
    """Yield stratified train/test indices within each (subject, session).

    Parameters
    ----------
    y : array-like
        Labels, aligned row by row with ``metadata``. Must support
        boolean-mask indexing.
    metadata : pandas.DataFrame
        Must provide ``subject`` and ``session`` columns.
    **kwargs
        Accepted and ignored.

    Yields
    ------
    train : ndarray
        Values of ``metadata.index`` selected for training in this fold.
    test : ndarray
        Values of ``metadata.index`` selected for testing in this fold.
    """
    all_index = metadata.index.values
    subject_values = metadata.subject.values
    session_values = metadata.session.values

    subjects = metadata.subject.unique()

    # Shuffle subjects if required. ``permutation`` returns a shuffled copy,
    # so the array returned by ``unique()`` is never mutated in place.
    if self.shuffle_subjects:
        subjects = self.random_state.permutation(subjects)

    # StratifiedKFold raises ValueError when a random_state is supplied
    # together with shuffle=False, so only forward it when shuffling.
    fold_random_state = self.random_state if self.shuffle_session else None

    for subject in subjects:
        subject_mask = subject_values == subject
        sessions = metadata.session[subject_mask].unique()

        # Shuffle sessions if required
        if self.shuffle_session:
            sessions = self.random_state.permutation(sessions)

        for session in sessions:
            group_mask = subject_mask & (session_values == session)
            group_indices = all_index[group_mask]
            # Select labels with the positional mask rather than with the
            # index labels, which are only valid positions for a default
            # 0..n-1 index.
            group_y = y[group_mask]

            cv = StratifiedKFold(
                n_splits=self.n_folds,
                shuffle=self.shuffle_session,
                random_state=fold_random_state,
            )
            for ix_train, ix_test in cv.split(group_indices, group_y):
                yield group_indices[ix_train], group_indices[ix_test]

0 comments on commit b151d61

Please sign in to comment.