diff --git a/.github/workflows/test-braindecode.yml b/.github/workflows/test-braindecode.yml
index 6660ce6fe..8c1f4aa8f 100644
--- a/.github/workflows/test-braindecode.yml
+++ b/.github/workflows/test-braindecode.yml
@@ -68,9 +68,8 @@ jobs:
       - name: Install and test braindecode
         run: |
           source $VENV
-          poetry run pip install torch
           cd braindecode
-          poetry run pip install -e .[tests]
-          poetry run pytest test/
+          pip install -e .[tests]
+          pytest test/
 # poetry run pip install -U https://api.github.com/repos/braindecode/braindecode/zipball/master
diff --git a/.github/workflows/whats-new.yml b/.github/workflows/whats-new.yml
new file mode 100644
index 000000000..4c121cda4
--- /dev/null
+++ b/.github/workflows/whats-new.yml
@@ -0,0 +1,38 @@
+name: Check What's new update
+
+on:
+  push:
+    branches: [develop]
+  pull_request:
+    branches: [develop]
+
+jobs:
+  check-whats-news:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Check for file changes in PR
+        run: |
+          pr_number=${{ github.event.pull_request.number }}
+          response=$(curl -s -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
+            "https://api.github.com/repos/${{ github.repository }}/pulls/${pr_number}/files")
+
+          file_changed=false
+          file_to_check="docs/source/whats_new.rst"  # Specify the path to your file
+
+          for file in $(echo "${response}" | jq -r '.[] | .filename'); do
+            if [ "$file" == "$file_to_check" ]; then
+              file_changed=true
+              break
+            fi
+          done
+
+          if $file_changed; then
+            echo "File ${file_to_check} has been changed in the PR."
+          else
+            echo "File ${file_to_check} has not been changed in the PR."
+            echo "::error::File ${file_to_check} has not been changed in the PR."
+            exit 1
+          fi
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}  # This token is provided by Actions, you do not need to create your own token
diff --git a/docs/source/install/install_pip.rst b/docs/source/install/install_pip.rst
index 2029738ed..f05561585 100644
--- a/docs/source/install/install_pip.rst
+++ b/docs/source/install/install_pip.rst
@@ -5,6 +5,9 @@ Installing from PyPI
 
 MOABB can be installed via pip from `PyPI `__.
 
+.. warning::
+    MOABB is only compatible with **Python 3.8, 3.9, 3.10 and 3.11**.
+
 .. note::
     We recommend the most updated version of pip to install from PyPI.
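The whats-new.yml job above shells out to the GitHub API with curl and jq. For local debugging, the same check can be sketched in Python. This is a minimal sketch, not part of the workflow: the repository name and PR number are placeholders, requests is an assumed dependency, and very large PRs would additionally need to follow the endpoint's pagination links.

import os

import requests

REPO = "owner/repo"  # placeholder: the repository under test
PR_NUMBER = 1        # placeholder: the pull request to inspect
FILE_TO_CHECK = "docs/source/whats_new.rst"

# Same endpoint as the workflow; GITHUB_TOKEN must be set in the environment.
response = requests.get(
    f"https://api.github.com/repos/{REPO}/pulls/{PR_NUMBER}/files",
    headers={"Authorization": f"token {os.environ['GITHUB_TOKEN']}"},
    timeout=30,
)
response.raise_for_status()
changed_files = {entry["filename"] for entry in response.json()}

if FILE_TO_CHECK not in changed_files:
    raise SystemExit(f"::error::File {FILE_TO_CHECK} has not been changed in the PR.")
print(f"File {FILE_TO_CHECK} has been changed in the PR.")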
diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst
index 301f3a0ce..b361f18c5 100644
--- a/docs/source/whats_new.rst
+++ b/docs/source/whats_new.rst
@@ -19,11 +19,13 @@ Enhancements
 ~~~~~~~~~~~~
 
 - Adding cache option to the evaluation (:gh:`517` by `Bruno Aristimunha`_)
+- Option to interpolate channels in paradigms' `match_all` method (:gh:`480` by `Gregoire Cattan`_)
 
 Bugs
 ~~~~
 
-- None
+- Fix TRCA implementation for different stimulation freqs and for signal filtering (:gh:`522` by `Sylvain Chevallier`_)
+- Fix saving to BIDS runs with a description string in their name (:gh:`530` by `Pierre Guetschel`_)
 
 API changes
 ~~~~~~~~~~~
diff --git a/examples/plot_cross_subject_ssvep.py b/examples/plot_cross_subject_ssvep.py
index e6319b509..20dd35052 100644
--- a/examples/plot_cross_subject_ssvep.py
+++ b/examples/plot_cross_subject_ssvep.py
@@ -104,9 +104,7 @@
 pipelines["CCA"] = make_pipeline(SSVEP_CCA(interval=interval, freqs=freqs, n_harmonics=2))
 
 pipelines_TRCA = {}
-pipelines_TRCA["TRCA"] = make_pipeline(
-    SSVEP_TRCA(interval=interval, freqs=freqs, n_fbands=5)
-)
+pipelines_TRCA["TRCA"] = make_pipeline(SSVEP_TRCA(interval=interval, freqs=freqs))
 
 pipelines_MSET_CCA = {}
 pipelines_MSET_CCA["MSET_CCA"] = make_pipeline(SSVEP_MsetCCA(freqs=freqs))
diff --git a/moabb/datasets/bids_interface.py b/moabb/datasets/bids_interface.py
index 6c9b0a7c3..546132d87 100644
--- a/moabb/datasets/bids_interface.py
+++ b/moabb/datasets/bids_interface.py
@@ -56,6 +56,23 @@ def subject_bids_to_moabb(subject: str):
     return int(subject)
 
 
+def run_moabb_to_bids(run: str):
+    """Convert a MOABB run name into a BIDS run index plus an optional description."""
+    p = r"([0-9]+)(|[a-zA-Z]+[a-zA-Z0-9]*)"
+    idx, desc = re.fullmatch(p, run).groups()
+    out = {"run": idx}
+    if desc:
+        out["recording"] = desc
+    return out
+
+
+def run_bids_to_moabb(path: mne_bids.BIDSPath):
+    """Extract the run index plus an optional description from a BIDS path."""
+    if path.recording is None:
+        return path.run
+    return f"{path.run}{path.recording}"
+
+
 @dataclass
 class BIDSInterfaceBase(abc.ABC):
     """Base class for BIDSInterface.
@@ -173,7 +190,7 @@ def load(self, preload=False):
             session_moabb = path.session
             session = sessions_data.setdefault(session_moabb, {})
             run = self._load_file(path, preload=preload)
-            session[path.run] = run
+            session[run_bids_to_moabb(path)] = run
         log.info("Finished reading cache of %s", repr(self))
         return sessions_data
 
@@ -223,12 +240,13 @@ def save(self, sessions_data):
                     )
                     continue
 
+                run_kwargs = run_moabb_to_bids(run)
                 bids_path = mne_bids.BIDSPath(
                     root=self.root,
                     subject=subject_moabb_to_bids(self.subject),
                     session=session,
                     task=self.dataset.paradigm,
-                    run=run,
+                    **run_kwargs,
                     description=self.desc,
                     extension=self._extension,
                     datatype=self._datatype,
diff --git a/moabb/datasets/compound_dataset/__init__.py b/moabb/datasets/compound_dataset/__init__.py
index 6232e2129..1fc660352 100644
--- a/moabb/datasets/compound_dataset/__init__.py
+++ b/moabb/datasets/compound_dataset/__init__.py
@@ -8,7 +8,7 @@
     BI_Il,
     Cattan2019_VR_Il,
 )
-from .utils import _init_compound_dataset_list
+from .utils import _init_compound_dataset_list, compound  # noqa: F401
 
 _init_compound_dataset_list()
diff --git a/moabb/datasets/compound_dataset/base.py b/moabb/datasets/compound_dataset/base.py
index 8d4fcf608..2aa7be427 100644
--- a/moabb/datasets/compound_dataset/base.py
+++ b/moabb/datasets/compound_dataset/base.py
@@ -36,13 +36,12 @@ class CompoundDataset(BaseDataset):
     interval: list with 2 entries
         See `BaseDataset`.
-    paradigm: ['p300','imagery', 'ssvep', 'rstate']
-        Defines what sort of dataset this is
     """
 
-    def __init__(self, subjects_list: list, code: str, interval: list, paradigm: str):
+    def __init__(self, subjects_list: list, code: str, interval: list):
         self._set_subjects_list(subjects_list)
         dataset, _, _, _ = self.subjects_list[0]
+        paradigm = self._get_paradigm()
         super().__init__(
             subjects=list(range(1, self.count + 1)),
             sessions_per_subject=self._get_sessions_per_subject(),
@@ -52,6 +51,17 @@ def __init__(self, subjects_list: list, code: str, interval: list, paradigm: str
             paradigm=paradigm,
         )
 
+    @property
+    def datasets(self):
+        all_datasets = [entry[0] for entry in self.subjects_list]
+        found_flags = set()
+        filtered_dataset = []
+        for dataset in all_datasets:
+            if dataset.code not in found_flags:
+                filtered_dataset.append(dataset)
+                found_flags.add(dataset.code)
+        return filtered_dataset
+
     @property
     def count(self):
         return len(self.subjects_list)
@@ -78,6 +88,16 @@ def _set_subjects_list(self, subjects_list: list):
             for compoundDataset in subjects_list:
                 self.subjects_list.extend(compoundDataset.subjects_list)
 
+    def _get_paradigm(self):
+        dataset, _, _, _ = self.subjects_list[0]
+        paradigm = dataset.paradigm
+        # Check that all datasets share the same paradigm
+        for i in range(1, len(self.subjects_list)):
+            entry = self.subjects_list[i]
+            dataset = entry[0]
+            assert dataset.paradigm == paradigm
+        return paradigm
+
     def _with_data_origin(self, data: dict, shopped_subject):
         data_origin = self.subjects_list[shopped_subject - 1]
diff --git a/moabb/datasets/compound_dataset/bi_illiteracy.py b/moabb/datasets/compound_dataset/bi_illiteracy.py
index 055db3b46..ec6ab350b 100644
--- a/moabb/datasets/compound_dataset/bi_illiteracy.py
+++ b/moabb/datasets/compound_dataset/bi_illiteracy.py
@@ -15,7 +15,6 @@ def __init__(self, subjects_list, dataset=None, code=None):
             subjects_list=subjects_list,
             code=code,
             interval=[0, 1.0],
-            paradigm="p300",
         )
diff --git a/moabb/datasets/compound_dataset/utils.py b/moabb/datasets/compound_dataset/utils.py
index f41e0914c..06099bccf 100644
--- a/moabb/datasets/compound_dataset/utils.py
+++ b/moabb/datasets/compound_dataset/utils.py
@@ -1,6 +1,8 @@
 import inspect
+from typing import List
 
 import moabb.datasets.compound_dataset as db
+from moabb.datasets.base import BaseDataset
 from moabb.datasets.compound_dataset.base import CompoundDataset
 
 
@@ -11,3 +13,16 @@ def _init_compound_dataset_list():
     for ds in inspect.getmembers(db, inspect.isclass):
         if issubclass(ds[1], CompoundDataset) and not ds[0] == "CompoundDataset":
             compound_dataset_list.append(ds[1])
+
+
+def compound(*datasets: List[BaseDataset], interval=[0, 1.0]):
+    subjects_list = [
+        (d, subject, None, None) for d in datasets for subject in d.subject_list
+    ]
+    code = "".join([d.code for d in datasets])
+    ret = CompoundDataset(
+        subjects_list=subjects_list,
+        code=code,
+        interval=interval,
+    )
+    return ret
diff --git a/moabb/datasets/gigadb.py b/moabb/datasets/gigadb.py
index 2a8ecacb3..dc9014280 100644
--- a/moabb/datasets/gigadb.py
+++ b/moabb/datasets/gigadb.py
@@ -13,7 +13,7 @@
 log = logging.getLogger(__name__)
 
-GIGA_URL = "ftp://parrot.genomics.cn/gigadb/pub/10.5524/100001_101000/100295/mat_data/"
+GIGA_URL = "https://ftp.cngb.org/pub/gigadb/pub/10.5524/100001_101000/100295/mat_data/"
 
 
 class Cho2017(BaseDataset):
diff --git a/moabb/datasets/preprocessing.py b/moabb/datasets/preprocessing.py
index bb18363ab..a9aac1e81 100644
--- a/moabb/datasets/preprocessing.py
+++ b/moabb/datasets/preprocessing.py
@@ -2,6 +2,7 @@
 from collections import OrderedDict
 from operator import methodcaller
 from typing import Dict, List, Tuple, Union
+from warnings import warn
 
 import mne
 import numpy as np
@@ -199,6 +200,7 @@ def __init__(
         tmax: float,
         baseline: Tuple[float, float],
         channels: List[str] = None,
+        interpolate_missing_channels: bool = False,
     ):
         assert isinstance(event_id, dict)  # not None
         self.event_id = event_id
@@ -206,6 +208,7 @@ def __init__(
         self.tmax = tmax
         self.baseline = baseline
         self.channels = channels
+        self.interpolate_missing_channels = interpolate_missing_channels
 
     def transform(self, X, y=None):
         raw = X["raw"]
@@ -218,9 +221,40 @@ def transform(self, X, y=None):
         if self.channels is None:
             picks = mne.pick_types(raw.info, eeg=True, stim=False)
         else:
+            available_channels = raw.info["ch_names"]
+            if self.interpolate_missing_channels:
+                missing_channels = list(set(self.channels).difference(available_channels))
+
+                # add the missing channels (they contain only zeros by default)
+                try:
+                    raw.add_reference_channels(missing_channels)
+                except IndexError:
+                    # an IndexError can occur if the channels we add are not part
+                    # of this recording's montage; log a warning in that case
+                    montage = raw.info["dig"]
+                    warn(
+                        f"Montage disabled as one of these channels, {missing_channels}, is not part of the montage {montage}"
+                    )
+                    # and disable the montage
+                    raw.info.pop("dig")
+                    # run again with the montage disabled
+                    raw.add_reference_channels(missing_channels)
+
+                # Trick: mark these channels as bad...
+                raw.info["bads"].extend(missing_channels)
+                # ...and use MNE's bad-channel interpolation to generate values for the missing channels
+                try:
+                    raw.interpolate_bads(origin="auto")
+                except ValueError:
+                    # use a default origin if montage info is not available
+                    raw.interpolate_bads(origin=(0, 0, 0.04))
+                # update the list of available channels
+                available_channels = self.channels
+
             picks = mne.pick_channels(
-                raw.info["ch_names"], include=self.channels, ordered=True
+                available_channels, include=self.channels, ordered=True
             )
+            assert len(picks) == len(self.channels)
 
         epochs = mne.Epochs(
             raw,
@@ -236,6 +270,7 @@ def transform(self, X, y=None):
             event_repeated="drop",
             on_missing="ignore",
         )
+        warn(f"warnEpochs {epochs}")
         return epochs
diff --git a/moabb/evaluations/base.py b/moabb/evaluations/base.py
index 6faefe850..b99eefaaa 100644
--- a/moabb/evaluations/base.py
+++ b/moabb/evaluations/base.py
@@ -1,6 +1,7 @@
 import logging
 from abc import ABC, abstractmethod
 
+import pandas as pd
 from sklearn.base import BaseEstimator
 from sklearn.model_selection import GridSearchCV
 
@@ -172,6 +173,8 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None):
         for _, pipeline in pipelines.items():
             if not (isinstance(pipeline, BaseEstimator)):
                 raise (ValueError("pipelines must only contain Pipeline instances"))
+
+        res_per_db = []
         for dataset in self.datasets:
             log.info("Processing dataset: {}".format(dataset.code))
             process_pipeline = self.paradigm.make_process_pipelines(
@@ -191,10 +194,13 @@ def process(self, pipelines, param_grid=None, postprocess_pipeline=None):
             )
             for res in results:
                 self.push_result(res, pipelines, process_pipeline)
+            res_per_db.append(
+                self.results.to_dataframe(
+                    pipelines=pipelines, process_pipeline=process_pipeline
+                )
+            )
 
-        return self.results.to_dataframe(
-            pipelines=pipelines, process_pipeline=process_pipeline
-        )
+        return pd.concat(res_per_db, ignore_index=True)
 
     def push_result(self, res, pipelines, process_pipeline):
         message = "{} | ".format(res["pipeline"])
diff --git a/moabb/paradigms/base.py b/moabb/paradigms/base.py
index 9b049fd66..ecaee2437 100644
--- a/moabb/paradigms/base.py
+++ b/moabb/paradigms/base.py
@@ -83,6 +83,7 @@ def __init__(
         self.resample = resample
         self.tmin = tmin
         self.tmax = tmax
+        self.interpolate_missing_channels = False
 
     @property
     @abc.abstractmethod
@@ -399,6 +400,7 @@ def _get_epochs_pipeline(self, return_epochs, return_raws, dataset):
                         tmax=bmax,
                         baseline=baseline,
                         channels=self.channels,
+                        interpolate_missing_channels=self.interpolate_missing_channels,
                     ),
                 ),
             )
@@ -429,7 +431,13 @@ def _get_array_pipeline(
             return None
         return Pipeline(steps)
 
-    def match_all(self, datasets: List[BaseDataset], shift=-0.5):
+    def match_all(
+        self,
+        datasets: List[BaseDataset],
+        shift=-0.5,
+        channel_merge_strategy: str = "intersect",
+        ignore=["stim"],
+    ):
         """
         Initialize this paradigm to match all the datasets given as parameters:
         - `self.resample` is set to match the minimum frequency in all datasets, minus `shift`.
@@ -442,29 +450,48 @@
         Parameters
         ----------
         datasets: List[BaseDataset]
             A list of dataset instances.
+        shift: float (default: -0.5)
+            Shift the sampling frequency by this value.
+            E.g.: if sampling=128 and shift=-0.5, the paradigm resamples to 127.5 Hz
+        channel_merge_strategy: str (default: 'intersect')
+            Accepts two values:
+            - 'intersect': keep only channels common to all datasets
+            - 'union': keep all channels from all datasets, removing duplicates
+        ignore: List[str]
+            A list of channels to ignore
+
+        .. versionadded:: 0.6.0
         """
         resample = None
-        channels = None
+        channels: set = None
         for dataset in datasets:
-            X, _, _ = self.get_data(
-                dataset, subjects=[dataset.subject_list[0]], return_epochs=True
-            )
+            first_subject = dataset.subject_list[0]
+            data = dataset.get_data(subjects=[first_subject])[first_subject]
+            first_session = list(data.keys())[0]
+            session = data[first_session]
+            first_run = list(session.keys())[0]
+            X = session[first_run]
             info = X.info
             sfreq = info["sfreq"]
             ch_names = info["ch_names"]
             # get the minimum sampling frequency between all datasets
             resample = sfreq if resample is None else min(resample, sfreq)
             # get the channels common to all datasets
-            channels = (
-                set(ch_names)
-                if channels is None
-                else set(channels).intersection(ch_names)
-            )
+            if channels is None:
+                channels = set(ch_names)
+            elif channel_merge_strategy == "intersect":
+                channels = channels.intersection(ch_names)
+                self.interpolate_missing_channels = False
+            else:
+                channels = channels.union(ch_names)
+                self.interpolate_missing_channels = True
         # If resample=128 for example, then MNE can return 128 or 129 samples
         # depending on the dataset, even if the length of the epochs is 1s
         # `shift=-0.5` solves this particular issue.
         self.resample = resample + shift
-        self.channels = list(channels)
+
+        # exclude ignored channels
+        self.channels = list(channels.difference(ignore))
 
     @abc.abstractmethod
     def _get_events_pipeline(self, dataset):
diff --git a/moabb/pipelines/classification.py b/moabb/pipelines/classification.py
index 07fba8d4c..d9afff3fa 100644
--- a/moabb/pipelines/classification.py
+++ b/moabb/pipelines/classification.py
@@ -114,9 +114,6 @@ class SSVEP_TRCA(BaseEstimator, ClassifierMixin):
         Frequencies corresponding to the SSVEP components. These are
         necessary to design the filterbank bands.
 
-    n_fbands : int, default=5
-        Number of sub-bands considered for filterbank analysis.
-
     downsample: int, default=1
         Factor by which to downsample the data.
         A downsample value of N will result in a sampling frequency of
         (sfreq // N) by taking one sample every N of
@@ -188,7 +185,6 @@ def __init__(
         self,
         interval,
         freqs,
-        n_fbands=5,
         downsample=1,
         is_ensemble=True,
         method="original",
@@ -196,7 +192,7 @@
     ):
         self.freqs = freqs
         self.peaks = np.array([float(f) for f in freqs.keys()])
-        self.n_fbands = n_fbands
+        self.n_fbands = len(self.peaks)
         self.downsample = downsample
         self.interval = interval
         self.slen = interval[1] - interval[0]
diff --git a/moabb/pipelines/utils.py b/moabb/pipelines/utils.py
index f836bb01a..105b2a8b4 100644
--- a/moabb/pipelines/utils.py
+++ b/moabb/pipelines/utils.py
@@ -289,6 +289,10 @@ def filterbank(X, sfreq, idx_fb, peaks):
     Code based on the Matlab implementation from authors of [1]_
     (https://github.com/mnakanishi/TRCA-SSVEP).
     """
+    if idx_fb >= len(peaks):
+        raise (
+            ValueError("idx_fb should be less than the number of SSVEP stimulus frequencies")
+        )
 
     # Calibration data comes in batches of trials
     if X.ndim == 3:
@@ -299,39 +303,33 @@ def filterbank(X, sfreq, idx_fb, peaks):
     elif X.ndim == 2:
         num_chans = X.shape[0]
         num_trials = 1
+    else:
+        raise ValueError("X must be a 2D or 3D array")
 
     sfreq = sfreq / 2
 
-    min_freq = np.min(peaks)
+    peaks = np.sort(peaks)
     max_freq = np.max(peaks)
 
     if max_freq < 40:
-        top = 100
+        top = 40
     else:
-        top = 115
+        top = 60
     # Check for Nyquist
     if top >= sfreq:
         top = sfreq - 10
 
-    diff = max_freq - min_freq
     # Lowcut frequencies for the pass band (depends on the frequencies of SSVEP)
     # No more than 3dB loss in the passband
-
-    passband = [min_freq - 2 + x * diff for x in range(7)]
+    passband = [peaks[i] - 1 for i in range(len(peaks))]
 
     # At least 15 dB attenuation in the stopband
-    if min_freq - 4 > 0:
-        stopband = [
-            min_freq - 4 + x * (diff - 2) if x < 3 else min_freq - 4 + x * diff
-            for x in range(7)
-        ]
-    else:
-        stopband = [2 + x * (diff - 2) if x < 3 else 2 + x * diff for x in range(7)]
+    stopband = [peaks[i] - 2 for i in range(len(peaks))]
 
     Wp = [passband[idx_fb] / sfreq, top / sfreq]
-    Ws = [stopband[idx_fb] / sfreq, (top + 7) / sfreq]
+    Ws = [stopband[idx_fb] / sfreq, (top + 20) / sfreq]
 
-    N, Wn = scp.cheb1ord(Wp, Ws, 3, 40)  # Chebyshev type I filter order selection.
+    N, Wn = scp.cheb1ord(Wp, Ws, 3, 15)  # Chebyshev type I filter order selection.
 
     B, A = scp.cheby1(N, 0.5, Wn, btype="bandpass")  # Chebyshev type I filter design
diff --git a/moabb/tests/datasets.py b/moabb/tests/datasets.py
index 5842cc35b..3ae91cd79 100644
--- a/moabb/tests/datasets.py
+++ b/moabb/tests/datasets.py
@@ -351,7 +351,6 @@ def test_fake_dataset(self):
             subjects_list,
             code="CompoundDataset-test",
             interval=[0, 1],
-            paradigm=self.paradigm,
         )
 
         data = compound_data.get_data()
@@ -385,7 +384,6 @@ def test_compound_dataset_composition(self):
             subjects_list,
             code="CompoundDataset-test",
             interval=[0, 1],
-            paradigm=self.paradigm,
         )
 
         # Add it twice to a subjects_list
@@ -394,9 +392,11 @@ def test_compound_dataset_composition(self):
             subjects_list,
             code="CompoundDataset-test",
             interval=[0, 1],
-            paradigm=self.paradigm,
         )
 
+        # Assert there is only one source dataset in the compound dataset
+        self.assertEqual(len(compound_data.datasets), 1)
+
         # Assert that the compound dataset has twice as many subjects as the original one.
         data = compound_data.get_data()
         self.assertEqual(len(data), 2)
 
@@ -408,7 +408,7 @@ def test_get_sessions_per_subject(self):
             n_runs=self.n_runs,
             n_subjects=self.n_subjects,
             event_list=["Target", "NonTarget"],
-            paradigm=self.paradigm,
+            paradigm=self.ds.paradigm,
         )
 
         # Add the two datasets to a CompoundDataset
@@ -417,9 +417,11 @@ def test_get_sessions_per_subject(self):
             subjects_list,
             code="CompoundDataset",
             interval=[0, 1],
-            paradigm=self.paradigm,
         )
 
+        # Assert there are two source datasets (ds and ds2) in the compound dataset
+        self.assertEqual(len(compound_dataset.datasets), 2)
+
         # Test that the private method _get_sessions_per_subject returns the minimum number of sessions per subject
         self.assertEqual(compound_dataset._get_sessions_per_subject(), self.n_sessions)
 
@@ -430,7 +432,7 @@ def test_event_id_correctly_updated(self):
             n_runs=self.n_runs,
             n_subjects=self.n_subjects,
             event_list=["Target2", "NonTarget2"],
-            paradigm=self.paradigm,
+            paradigm=self.ds.paradigm,
         )
 
         # Add the two datasets to a CompoundDataset
@@ -440,7 +442,6 @@ def test_event_id_correctly_updated(self):
             subjects_list,
             code="CompoundDataset",
             interval=[0, 1],
-            paradigm=self.paradigm,
         )
 
         # Check that the event_id of the compound_dataset is the same as the first dataset's
diff --git a/moabb/tests/evaluations.py b/moabb/tests/evaluations.py
index b0c5e0be8..b698ff223 100644
--- a/moabb/tests/evaluations.py
+++ b/moabb/tests/evaluations.py
@@ -14,6 +14,7 @@
 from sklearn.pipeline import FunctionTransformer, Pipeline, make_pipeline
 
 from moabb.analysis.results import get_string_rep
+from moabb.datasets.compound_dataset import compound
 from moabb.datasets.fake import FakeDataset
 from moabb.evaluations import evaluations as ev
 from moabb.evaluations.utils import create_save_path, save_model_cv, save_model_list
@@ -82,6 +83,41 @@ def test_eval_results(self):
         # We should have 9 columns in the results data frame
         self.assertEqual(len(results[0].keys()), 9 if _carbonfootprint else 8)
 
+    def test_compound_dataset(self):
+        ch1 = ["C3", "Cz", "Fz"]
+        dataset1 = FakeDataset(
+            paradigm="imagery",
+            event_list=["left_hand", "right_hand"],
+            channels=ch1,
+            sfreq=128,
+        )
+        ch2 = ["C3", "C4", "Cz"]
+        dataset2 = FakeDataset(
+            paradigm="imagery",
+            event_list=["left_hand", "right_hand"],
+            channels=ch2,
+            sfreq=256,
+        )
+        merged_dataset = compound(dataset1, dataset2)
+
+        # We want to interpolate the channels the two datasets do not have in common
+        self.eval.paradigm.match_all(
+            merged_dataset.datasets, channel_merge_strategy="union"
+        )
+
+        process_pipeline = self.eval.paradigm.make_process_pipelines(dataset)[0]
+        results = [
+            r
+            for r in self.eval.evaluate(
+                dataset, pipelines, param_grid=None, process_pipeline=process_pipeline
+            )
+        ]
+
+        # We should get 4 results: 2 sessions x 2 subjects
+        self.assertEqual(len(results), 4)
+        # We should have 9 columns in the results data frame
+        self.assertEqual(len(results[0].keys()), 9 if _carbonfootprint else 8)
+
     def test_eval_grid_search(self):
         # Test grid search
         param_grid = {"C": {"csp__metric": ["euclid", "riemann"]}}
diff --git a/moabb/tests/paradigms.py b/moabb/tests/paradigms.py
index 756e2011c..a7528dfb2 100644
--- a/moabb/tests/paradigms.py
+++ b/moabb/tests/paradigms.py
@@ -352,7 +352,7 @@ def test_match_all(self):
         dataset1 = FakeDataset(
             paradigm="p300",
             event_list=["Target", "NonTarget"],
-            channels=("C3", "Cz", "Fz"),
+            channels=["C3", "Cz", "Fz"],
             sfreq=64,
         )
         dataset2 = FakeDataset(
@@ -368,11 +368,30 @@ def test_match_all(self):
             sfreq=512,
         )
         shift = -0.5
-        paradigm.match_all([dataset1, dataset2, dataset3], shift=shift)
+
+        paradigm.match_all(
+            [dataset1, dataset2, dataset3], shift=shift, channel_merge_strategy="union"
+        )
         # match_all should return the smallest frequency minus 0.5.
         # See comment inside the match_all method
         self.assertEqual(paradigm.resample, 64 + shift)
+        self.assertEqual(sorted(paradigm.channels), sorted(["C3", "Cz", "Fz", "C4"]))
+        self.assertEqual(paradigm.interpolate_missing_channels, True)
+        X, _, _ = paradigm.get_data(dataset1, subjects=[1])
+        n_channels, _ = X[0].shape
+        self.assertEqual(n_channels, 4)
+
+        paradigm.match_all(
+            [dataset1, dataset2, dataset3],
+            shift=shift,
+            channel_merge_strategy="intersect",
+        )
+        self.assertEqual(paradigm.resample, 64 + shift)
         self.assertEqual(sorted(paradigm.channels), sorted(["C3", "Cz"]))
+        self.assertEqual(paradigm.interpolate_missing_channels, False)
+        X, _, _ = paradigm.get_data(dataset1, subjects=[1])
+        n_channels, _ = X[0].shape
+        self.assertEqual(n_channels, 2)
 
     def test_BaseP300_paradigm(self):
         paradigm = SimpleP300()
diff --git a/tutorials/noplot_tutorial_5_build_a_custom_dataset.py b/tutorials/noplot_tutorial_5_build_a_custom_dataset.py
index 75c87210f..7e5fcd98b 100644
--- a/tutorials/noplot_tutorial_5_build_a_custom_dataset.py
+++ b/tutorials/noplot_tutorial_5_build_a_custom_dataset.py
@@ -65,7 +65,6 @@ def __init__(self):
             subjects_list=subjects_list,
             code="CustomDataset1",
             interval=[0, 1.0],
-            paradigm="p300",
         )
 
@@ -81,7 +80,6 @@ def __init__(self):
             subjects_list=subjects_list,
             code="CustomDataset2",
             interval=[0, 1.0],
-            paradigm="p300",
         )
 
@@ -103,7 +101,6 @@ def __init__(self):
             subjects_list=subjects_list,
             code="CustomDataset3",
             interval=[0, 1.0],
-            paradigm="p300",
         )
 
diff --git a/tutorials/plot_Getting_Started.py b/tutorials/plot_Getting_Started.py
index e7e297120..97539f393 100644
--- a/tutorials/plot_Getting_Started.py
+++ b/tutorials/plot_Getting_Started.py
@@ -1,6 +1,7 @@
-"""=========================
-Getting Started
-=========================
+"""
+============================
+Tutorial 0: Getting Started
+============================
 
 This tutorial takes you through a basic working example of how to use
 this codebase, including all the different components, up to the results
diff --git a/tutorials/tutorial_3_benchmarking_multiple_pipelines.py b/tutorials/tutorial_3_benchmarking_multiple_pipelines.py
index ee131b7b7..966feef50 100644
--- a/tutorials/tutorial_3_benchmarking_multiple_pipelines.py
+++ b/tutorials/tutorial_3_benchmarking_multiple_pipelines.py
@@ -1,4 +1,4 @@
-""" Basic tutorial on how to use MOABB.
+"""
 ===========================================
 Tutorial 3: Benchmarking multiple pipelines
 ===========================================
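A note on the new run-name helpers in moabb/datasets/bids_interface.py: a MOABB run name is a numeric index optionally followed by an alphanumeric description, and run_moabb_to_bids and run_bids_to_moabb are inverses of each other. A minimal sketch of the implied round trip (the run names here are made up):

import re

p = r"([0-9]+)(|[a-zA-Z]+[a-zA-Z0-9]*)"

idx, desc = re.fullmatch(p, "0").groups()       # ("0", "")   -> {"run": "0"}
idx, desc = re.fullmatch(p, "0train").groups()  # ("0", "train")
# -> {"run": "0", "recording": "train"}; run_bids_to_moabb rebuilds
#    "0train" from path.run and path.recording.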
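The new compound() helper builds a CompoundDataset from several datasets and now infers the paradigm from the first entry instead of taking it as an argument. A usage sketch mirroring the new test in moabb/tests/evaluations.py, with illustrative channel sets and sampling rates:

from moabb.datasets.compound_dataset import compound
from moabb.datasets.fake import FakeDataset

dataset1 = FakeDataset(
    paradigm="imagery",
    event_list=["left_hand", "right_hand"],
    channels=["C3", "Cz", "Fz"],
    sfreq=128,
)
dataset2 = FakeDataset(
    paradigm="imagery",
    event_list=["left_hand", "right_hand"],
    channels=["C3", "C4", "Cz"],
    sfreq=256,
)

merged = compound(dataset1, dataset2)
print(merged.code)           # concatenation of the two source dataset codes
print(len(merged.datasets))  # 2: the new property lists each source dataset once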
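With channel_merge_strategy="union", match_all keeps the union of all channel sets and switches on interpolation of whatever channels a given dataset lacks. Continuing the sketch above; the paradigm choice is an assumption, any imagery paradigm would do:

from moabb.paradigms import MotorImagery

paradigm = MotorImagery()
# `merged` and `dataset1` come from the previous sketch.
paradigm.match_all(merged.datasets, channel_merge_strategy="union")

# resample is now min(128, 256) - 0.5 = 127.5 Hz and the channel list is the
# union {C3, C4, Cz, Fz}; missing channels are interpolated when epoching.
assert paradigm.interpolate_missing_channels
X, y, metadata = paradigm.get_data(dataset1, subjects=[1])
print(X.shape)  # (n_epochs, 4, n_times): C4 was interpolated for dataset1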
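The interpolation inside RawToEpochs.transform rests on an MNE idiom: add each missing channel as a zero-filled reference channel, mark it bad, and let interpolate_bads reconstruct it from its neighbours. A standalone sketch with random data and the standard_1005 montage (both assumptions, not MOABB code):

import mne
import numpy as np

rng = np.random.default_rng(42)
info = mne.create_info(["C3", "Cz", "Fz"], sfreq=128.0, ch_types="eeg")
raw = mne.io.RawArray(rng.standard_normal((3, 256)), info)
raw.set_montage("standard_1005")

missing = ["C4"]
raw.add_reference_channels(missing)  # new channel, filled with zeros
raw.set_montage("standard_1005")     # give C4 a position so it can be interpolated
raw.info["bads"].extend(missing)     # mark it bad...
raw.interpolate_bads(origin="auto")  # ...and rebuild it from neighbouring sensors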
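Finally, the reworked filterbank derives one sub-band per stimulation frequency (pass band from peak - 1 Hz, stop band from peak - 2 Hz) instead of seven bands spread over the peak range. The band-edge arithmetic, with illustrative peaks and sampling rate:

import numpy as np
from scipy import signal as scp

sfreq = 256
peaks = np.sort(np.array([13.0, 17.0, 21.0]))  # illustrative SSVEP frequencies

nyq = sfreq / 2
top = 40 if np.max(peaks) < 40 else 60  # upper band edge, as in the new code
passband = peaks - 1  # no more than 3 dB loss above peak - 1 Hz
stopband = peaks - 2  # at least 15 dB attenuation below peak - 2 Hz

idx_fb = 0  # design the first sub-band
Wp = [passband[idx_fb] / nyq, top / nyq]
Ws = [stopband[idx_fb] / nyq, (top + 20) / nyq]
N, Wn = scp.cheb1ord(Wp, Ws, 3, 15)  # Chebyshev type I order selection
B, A = scp.cheby1(N, 0.5, Wn, btype="bandpass")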