Merge branch 'develop' into n_subjects
bruAristimunha authored Jan 5, 2024
2 parents 35af427 + 4b297b9 commit 4e19bd2
Showing 22 changed files with 274 additions and 66 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/test-braindecode.yml
@@ -68,9 +68,8 @@ jobs:
      - name: Install and test braindecode
        run: |
          source $VENV
          poetry run pip install torch
          cd braindecode
          poetry run pip install -e .[tests]
          poetry run pytest test/
          pip install -e .[tests]
          pytest test/
          # poetry run pip install -U https://api.github.com/repos/braindecode/braindecode/zipball/master
38 changes: 38 additions & 0 deletions .github/workflows/whats-new.yml
@@ -0,0 +1,38 @@
name: Check What's new update

on:
  push:
    branches: [develop]
  pull_request:
    branches: [develop]

jobs:
  check-whats-news:
    runs-on: ubuntu-latest

    steps:
      - name: Check for file changes in PR
        run: |
          pr_number=${{ github.event.pull_request.number }}
          response=$(curl -s -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
            "https://api.github.com/repos/${{ github.repository }}/pulls/${pr_number}/files")
          file_changed=false
          file_to_check="docs/source/whats_new.rst" # Specify the path to your file
          for file in $(echo "${response}" | jq -r '.[] | .filename'); do
            if [ "$file" == "$file_to_check" ]; then
              file_changed=true
              break
            fi
          done
          if $file_changed; then
            echo "File ${file_to_check} has been changed in the PR."
          else
            echo "File ${file_to_check} has not been changed in the PR."
            echo "::error::File ${file_to_check} has not been changed in the PR."
            exit 1
          fi
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # This token is provided by Actions; you do not need to create your own
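
For debugging outside Actions, the same check can be reproduced locally. A minimal sketch, assuming the `requests` package, a GITHUB_TOKEN environment variable with read access, and an illustrative repository slug:

import os

import requests

FILE_TO_CHECK = "docs/source/whats_new.rst"
REPO = "NeuroTechX/moabb"  # assumed slug; substitute your repository


def whats_new_updated(pr_number: int) -> bool:
    """Return True if the pull request modifies the changelog file."""
    url = f"https://api.github.com/repos/{REPO}/pulls/{pr_number}/files"
    headers = {"Authorization": f"token {os.environ['GITHUB_TOKEN']}"}
    response = requests.get(url, headers=headers, timeout=10)
    response.raise_for_status()
    return any(f["filename"] == FILE_TO_CHECK for f in response.json())

Like the curl call in the workflow, this reads only the first page of results, so pull requests touching more than 30 files would need pagination.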
3 changes: 3 additions & 0 deletions docs/source/install/install_pip.rst
@@ -5,6 +5,9 @@ Installing from PyPI

MOABB can be installed via pip from `PyPI <https://pypi.org/project/moabb>`__.

.. warning::
   MOABB is only compatible with **Python 3.8, 3.9, 3.10 and 3.11**.

.. note::
   We recommend using the latest version of pip to install from PyPI.

4 changes: 3 additions & 1 deletion docs/source/whats_new.rst
@@ -19,11 +19,13 @@ Enhancements
~~~~~~~~~~~~

- Adding cache option to the evaluation (:gh:`517` by `Bruno Aristimunha`_)
- Option to interpolate channels in paradigms' `match_all` method (:gh:`480` by `Gregoire Cattan`_)

Bugs
~~~~

- None
- Fix TRCA implementation for different stimulation frequencies and for signal filtering (:gh:`522` by `Sylvain Chevallier`_)
- Fix saving to BIDS runs with a description string in their name (:gh:`530` by `Pierre Guetschel`_)

API changes
~~~~~~~~~~~
4 changes: 1 addition & 3 deletions examples/plot_cross_subject_ssvep.py
@@ -104,9 +104,7 @@
pipelines["CCA"] = make_pipeline(SSVEP_CCA(interval=interval, freqs=freqs, n_harmonics=2))

pipelines_TRCA = {}
pipelines_TRCA["TRCA"] = make_pipeline(
SSVEP_TRCA(interval=interval, freqs=freqs, n_fbands=5)
)
pipelines_TRCA["TRCA"] = make_pipeline(SSVEP_TRCA(interval=interval, freqs=freqs))

pipelines_MSET_CCA = {}
pipelines_MSET_CCA["MSET_CCA"] = make_pipeline(SSVEP_MsetCCA(freqs=freqs))
22 changes: 20 additions & 2 deletions moabb/datasets/bids_interface.py
@@ -56,6 +56,23 @@ def subject_bids_to_moabb(subject: str):
    return int(subject)


def run_moabb_to_bids(run: str):
    """Convert a MOABB run name to a BIDS run index plus, optionally, a recording description."""
    p = r"([0-9]+)(|[a-zA-Z]+[a-zA-Z0-9]*)"
    idx, desc = re.fullmatch(p, run).groups()
    out = {"run": idx}
    if desc:
        out["recording"] = desc
    return out


def run_bids_to_moabb(path: mne_bids.BIDSPath):
    """Extract the run index and, if present, the recording description from a BIDS path."""
    if path.recording is None:
        return path.run
    return f"{path.run}{path.recording}"
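
A quick round-trip illustration of these two helpers (a sketch; the values follow directly from the regular expression above):

run_moabb_to_bids("3")   # -> {"run": "3"}
run_moabb_to_bids("3a")  # -> {"run": "3", "recording": "a"}
# run_bids_to_moabb inverts this: a BIDSPath with run="3" and
# recording="a" maps back to the MOABB run name "3a".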


@dataclass
class BIDSInterfaceBase(abc.ABC):
"""Base class for BIDSInterface.
@@ -173,7 +190,7 @@ def load(self, preload=False):
            session_moabb = path.session
            session = sessions_data.setdefault(session_moabb, {})
            run = self._load_file(path, preload=preload)
            session[path.run] = run
            session[run_bids_to_moabb(path)] = run
        log.info("Finished reading cache of %s", repr(self))
        return sessions_data

@@ -223,12 +240,13 @@ def save(self, sessions_data):
                )
                continue

            run_kwargs = run_moabb_to_bids(run)
            bids_path = mne_bids.BIDSPath(
                root=self.root,
                subject=subject_moabb_to_bids(self.subject),
                session=session,
                task=self.dataset.paradigm,
                run=run,
                **run_kwargs,
                description=self.desc,
                extension=self._extension,
                datatype=self._datatype,
2 changes: 1 addition & 1 deletion moabb/datasets/compound_dataset/__init__.py
@@ -8,7 +8,7 @@
    BI_Il,
    Cattan2019_VR_Il,
)
from .utils import _init_compound_dataset_list
from .utils import _init_compound_dataset_list, compound  # noqa: F401


_init_compound_dataset_list()
26 changes: 23 additions & 3 deletions moabb/datasets/compound_dataset/base.py
@@ -36,13 +36,12 @@ class CompoundDataset(BaseDataset):
    interval: list with 2 entries
        See `BaseDataset`.
    paradigm: ['p300','imagery', 'ssvep', 'rstate']
        Defines what sort of dataset this is.
    """

    def __init__(self, subjects_list: list, code: str, interval: list, paradigm: str):
    def __init__(self, subjects_list: list, code: str, interval: list):
        self._set_subjects_list(subjects_list)
        dataset, _, _, _ = self.subjects_list[0]
        paradigm = self._get_paradigm()
        super().__init__(
            subjects=list(range(1, self.count + 1)),
            sessions_per_subject=self._get_sessions_per_subject(),
@@ -52,6 +51,17 @@ def __init__(self, subjects_list: list, code: str, interval: list, paradigm: str
            paradigm=paradigm,
        )

    @property
    def datasets(self):
        all_datasets = [entry[0] for entry in self.subjects_list]
        found_flags = set()
        filtered_dataset = []
        for dataset in all_datasets:
            if dataset.code not in found_flags:
                filtered_dataset.append(dataset)
                found_flags.add(dataset.code)
        return filtered_dataset

    @property
    def count(self):
        return len(self.subjects_list)
@@ -78,6 +88,16 @@ def _set_subjects_list(self, subjects_list: list):
        for compoundDataset in subjects_list:
            self.subjects_list.extend(compoundDataset.subjects_list)

    def _get_paradigm(self):
        dataset, _, _, _ = self.subjects_list[0]
        paradigm = dataset.paradigm
        # Check that all datasets share the same paradigm
        for i in range(1, len(self.subjects_list)):
            entry = self.subjects_list[i]
            dataset = entry[0]
            assert dataset.paradigm == paradigm
        return paradigm

    def _with_data_origin(self, data: dict, shopped_subject):
        data_origin = self.subjects_list[shopped_subject - 1]

1 change: 0 additions & 1 deletion moabb/datasets/compound_dataset/bi_illiteracy.py
@@ -15,7 +15,6 @@ def __init__(self, subjects_list, dataset=None, code=None):
            subjects_list=subjects_list,
            code=code,
            interval=[0, 1.0],
            paradigm="p300",
        )


15 changes: 15 additions & 0 deletions moabb/datasets/compound_dataset/utils.py
@@ -1,6 +1,8 @@
import inspect
from typing import List

import moabb.datasets.compound_dataset as db
from moabb.datasets.base import BaseDataset
from moabb.datasets.compound_dataset.base import CompoundDataset


@@ -11,3 +13,16 @@ def _init_compound_dataset_list():
    for ds in inspect.getmembers(db, inspect.isclass):
        if issubclass(ds[1], CompoundDataset) and not ds[0] == "CompoundDataset":
            compound_dataset_list.append(ds[1])


def compound(*datasets: List[BaseDataset], interval=[0, 1.0]):
    subjects_list = [
        (d, subject, None, None) for d in datasets for subject in d.subject_list
    ]
    code = "".join([d.code for d in datasets])
    ret = CompoundDataset(
        subjects_list=subjects_list,
        code=code,
        interval=interval,
    )
    return ret
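
A hedged usage sketch for `compound` (the dataset names are illustrative; any datasets sharing a paradigm should work, since `CompoundDataset` asserts paradigm consistency):

from moabb.datasets import BNCI2014_008, BNCI2014_009  # two P300 datasets, illustrative choice
from moabb.datasets.compound_dataset import compound

merged = compound(BNCI2014_008(), BNCI2014_009())
print(merged.code)               # concatenated codes of the source datasets
print(len(merged.subject_list))  # subjects of both datasets, renumbered from 1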
2 changes: 1 addition & 1 deletion moabb/datasets/gigadb.py
@@ -13,7 +13,7 @@


log = logging.getLogger(__name__)
GIGA_URL = "ftp://parrot.genomics.cn/gigadb/pub/10.5524/100001_101000/100295/mat_data/"
GIGA_URL = "https://ftp.cngb.org/pub/gigadb/pub/10.5524/100001_101000/100295/mat_data/"


class Cho2017(BaseDataset):
37 changes: 36 additions & 1 deletion moabb/datasets/preprocessing.py
@@ -2,6 +2,7 @@
from collections import OrderedDict
from operator import methodcaller
from typing import Dict, List, Tuple, Union
from warnings import warn

import mne
import numpy as np
@@ -199,13 +200,15 @@ def __init__(
        tmax: float,
        baseline: Tuple[float, float],
        channels: List[str] = None,
        interpolate_missing_channels: bool = False,
    ):
        assert isinstance(event_id, dict)  # not None
        self.event_id = event_id
        self.tmin = tmin
        self.tmax = tmax
        self.baseline = baseline
        self.channels = channels
        self.interpolate_missing_channels = interpolate_missing_channels

    def transform(self, X, y=None):
        raw = X["raw"]
@@ -218,9 +221,40 @@ def transform(self, X, y=None):
        if self.channels is None:
            picks = mne.pick_types(raw.info, eeg=True, stim=False)
        else:
            available_channels = raw.info["ch_names"]
            if self.interpolate_missing_channels:
                missing_channels = list(set(self.channels).difference(available_channels))

                # add missing channels (they contain only zeros by default)
                try:
                    raw.add_reference_channels(missing_channels)
                except IndexError:
                    # An IndexError can occur if the added channels are not part of
                    # this epoch's montage; log a warning...
                    montage = raw.info["dig"]
                    warn(
                        f"Montage disabled as one of these channels, {missing_channels}, is not part of the montage {montage}"
                    )
                    # ...disable the montage
                    raw.info.pop("dig")
                    # and run again with the montage disabled
                    raw.add_reference_channels(missing_channels)

                # Trick: mark these channels as bad...
                raw.info["bads"].extend(missing_channels)
                # ...and use MNE's bad-channel interpolation to generate values for them
                try:
                    raw.interpolate_bads(origin="auto")
                except ValueError:
                    # use a default origin if montage info is not available
                    raw.interpolate_bads(origin=(0, 0, 0.04))
                # update the list of available channels
                available_channels = self.channels

            picks = mne.pick_channels(
                raw.info["ch_names"], include=self.channels, ordered=True
                available_channels, include=self.channels, ordered=True
            )
            assert len(picks) == len(self.channels)

        epochs = mne.Epochs(
            raw,
@@ -236,6 +270,7 @@
event_repeated="drop",
on_missing="ignore",
)
warn(f"warnEpochs {epochs}")
return epochs
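
The missing-channel trick used in `transform` above (add a zero-filled channel, mark it bad, interpolate) can be reproduced in isolation. A minimal sketch on synthetic data, assuming a standard 10-20 montage; the channel names are illustrative:

import mne
import numpy as np

# Synthetic raw with six positioned channels; "Cz" plays the missing channel.
info = mne.create_info(["Fz", "F3", "F4", "C3", "C4", "Pz"], sfreq=128.0, ch_types="eeg")
raw = mne.io.RawArray(np.random.randn(6, 256), info)
raw.set_montage("standard_1020")

raw.add_reference_channels(["Cz"])   # the added channel contains only zeros
raw.set_montage("standard_1020")     # re-apply so "Cz" gets a position
raw.info["bads"].append("Cz")
raw.interpolate_bads(origin="auto")  # spherical-spline estimate replaces the zeros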


12 changes: 9 additions & 3 deletions moabb/evaluations/base.py
@@ -1,6 +1,7 @@
import logging
from abc import ABC, abstractmethod

import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV

@@ -172,6 +173,8 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None):
        for _, pipeline in pipelines.items():
            if not (isinstance(pipeline, BaseEstimator)):
                raise (ValueError("pipelines must only contains Pipelines " "instance"))

        res_per_db = []
        for dataset in self.datasets:
            log.info("Processing dataset: {}".format(dataset.code))
            process_pipeline = self.paradigm.make_process_pipelines(
@@ -191,10 +194,13 @@
            )
            for res in results:
                self.push_result(res, pipelines, process_pipeline)
            res_per_db.append(
                self.results.to_dataframe(
                    pipelines=pipelines, process_pipeline=process_pipeline
                )
            )

        return self.results.to_dataframe(
            pipelines=pipelines, process_pipeline=process_pipeline
        )
        return pd.concat(res_per_db, ignore_index=True)

    def push_result(self, res, pipelines, process_pipeline):
        message = "{} | ".format(res["pipeline"])
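
With this change, `process` returns a single dataframe covering every dataset in the evaluation rather than only the last one processed. A hedged sketch (paradigm, datasets, and pipeline are illustrative choices):

from mne.decoding import CSP
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import make_pipeline

from moabb.datasets import BNCI2014_001, Zhou2016
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import LeftRightImagery

pipelines = {"CSP+LDA": make_pipeline(CSP(n_components=8), LinearDiscriminantAnalysis())}
evaluation = WithinSessionEvaluation(
    paradigm=LeftRightImagery(),
    datasets=[BNCI2014_001(), Zhou2016()],
)
results = evaluation.process(pipelines)
print(results["dataset"].unique())  # both dataset codes appear in one dataframe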