diff --git a/docs/source/_static/p3_avo_sample_trials.png b/docs/source/_static/p3_avo_sample_trials.png new file mode 100644 index 00000000..68c307c2 Binary files /dev/null and b/docs/source/_static/p3_avo_sample_trials.png differ diff --git a/examples/core/p300_transfer_learning.py b/examples/core/p300_transfer_learning.py new file mode 100644 index 00000000..dcef5d14 --- /dev/null +++ b/examples/core/p300_transfer_learning.py @@ -0,0 +1,695 @@ +""".. _tutorial-p3-transfer-learning: + +EEG P3 Transfer Learning with AS-MMD +==================================== + +This tutorial demonstrates how to train a domain-adaptive deep learning model for +EEG P3 component classification across two different datasets using Adaptive +Symmetric Maximum Mean Discrepancy (AS-MMD). + +**Paper:** Chen, W., Delorme, A. (2025). Adaptive Split-MMD Training for Small-Sample +Cross-Dataset P300 EEG Classification. arXiv: `2510.21969 `_ + +Key Concepts +============ + +This tutorial covers: + +- **Domain Adaptation**: Training on multiple datasets with different recording setups +- **Deep Learning**: Using EEGConformer, a transformer-based model for EEG +- **AS-MMD**: A technique that aligns feature distributions across datasets +- **Cross-Validation**: Robust evaluation using nested stratified folds + +By the end, you'll understand how to: + +1. Load and preprocess multi-dataset EEG recordings +2. Build a domain-adaptive classifier +3. Evaluate performance across domains +4. Apply the method to your own datasets +""" + +# %% +# Part 1: Loading and Preprocessing Data +# ======================================== +# +# First, we load the datasets using EEGDashDataset. We'll use two public oddball +# datasets: +# +# 1. **ERP CORE P3**: 40 participants with active visual oddball paradigm +# (Download: https://osf.io/etdkz/files → "P3 Raw Data BIDS-Compatible") +# +# 2. **AVO (ds005863)**: 127 participants, available on OpenNeuro +# (Download: https://openneuro.org/datasets/ds005863) +# +# These datasets differ in equipment, recording sites, and participant demographics, +# making them ideal for testing domain adaptation. +from pathlib import Path +from eegdash.dataset import DS005863 +from braindecode.datasets import MOABBDataset +# Here, we are using an dataset that it in osf and other in openneuro. +# We are conveniently using EEGDashDataset and MOABBDataset to load them. +# but you can directly download from osf and use only EEGDashDataset if you prefer. + +cache_folder = Path.home() / "eegdash" +cache_folder.mkdir(parents=True, exist_ok=True) +cache_config = dict( + use=True, + save_raw=True, + path=cache_folder, +) +# Load datasets +ds_p3 = MOABBDataset( + dataset_name="ErpCore2021_P3", + subject_ids=[i for i in range(1, 3)], # all 5 subjects + dataset_load_kwargs={"cache_config": cache_config}, +) + + +ds_avo = DS005863( + cache_dir=cache_folder, + task="visualoddball", + subject=[f"{i:03d}" for i in range(1, 3)], + download=True, +) + +print(f"P3: {len(ds_p3)} recordings") +print(f"AVO: {len(ds_avo)} recordings") + +# %% +# Data Preprocessing Pipeline +# ---------------------------- +# +# Before training, we apply standard EEG preprocessing: +# +# - **Event labeling**: Identify oddball vs. standard stimuli +# - **Filtering**: 0.5-30 Hz bandpass to focus on relevant oscillations +# - **Resampling**: Downsample to 128 Hz to reduce computation +# - **Channel selection**: Keep Fz, Pz, P3, P4, Oz (standard P3 locations) +# - **Windowing**: Extract 1.2 sec epochs (-0.1s before to 1.1s after stimulus) +# - **Normalization**: Z-score normalization per trial + +import numpy as np +import torch +import mne +from braindecode.preprocessing import ( + preprocess, + Preprocessor, + create_windows_from_events, +) + +mne.set_log_level("ERROR") + +# Preprocessing parameters +LOW_FREQ = 0.5 +HIGH_FREQ = 30 +RESAMPLE_FREQ = 128 +TRIAL_START_OFFSET = -0.1 # 100 ms before stimulus +TRIAL_DURATION = 1.1 # Total window 1.1 seconds +COMMON_CHANNELS = ["Fz", "Pz", "P3", "P4", "Oz"] + + +def preprocess_dataset(dataset, channels, dataset_type="P3"): + """Apply preprocessing pipeline to an EEG dataset. + + Returns numpy arrays: (n_trials, n_channels, n_times) + """ + print(f"\nPreprocessing {dataset_type} dataset...") + + # Define preprocessing steps + preprocessors = [ + Preprocessor("set_eeg_reference", ref_channels="average", projection=True), + Preprocessor("resample", sfreq=RESAMPLE_FREQ), + Preprocessor("filter", l_freq=LOW_FREQ, h_freq=HIGH_FREQ), + Preprocessor( + "pick_channels", ch_names=[ch.lower() for ch in channels], ordered=False + ), + ] + + # Apply preprocessing + preprocess(dataset, preprocessors) + + # Extract windowed trials around stimulus onset + trial_start = int(TRIAL_START_OFFSET * RESAMPLE_FREQ) + trial_stop = int((TRIAL_START_OFFSET + TRIAL_DURATION) * RESAMPLE_FREQ) + + windows_ds = create_windows_from_events( + dataset, + trial_start_offset_samples=trial_start, + trial_stop_offset_samples=trial_stop, + preload=True, + drop_bad_windows=True, + ) + + X, y = [], [] + for i in range(len(windows_ds)): + data, label, *_ = windows_ds[i] + X.append(data) + y.append(label) + + print(f"Extracted {len(X)} trials from {dataset_type}") + return np.array(X), np.array(y) + + +# Preprocess both datasets +X_p3, y_p3 = preprocess_dataset(ds_p3, COMMON_CHANNELS, "P3") +X_avo, y_avo = preprocess_dataset(ds_avo, COMMON_CHANNELS, "AVO") + +# Combine datasets for training +X_all = np.vstack([X_p3, X_avo]) +y_all = np.hstack([y_p3, y_avo]) +src_all = np.array(["P3"] * len(X_p3) + ["AVO"] * len(X_avo)) + +print(f"\nCombined dataset: {len(X_all)} trials ({X_all.shape})") +print(f" P3: {np.sum(src_all == 'P3')} trials") +print(f" AVO: {np.sum(src_all == 'AVO')} trials") + + +# %% +# Part 2: Model Architecture and Training +# ======================================== +# +# Building the Domain-Adaptive Model +# ----------------------------------- +# +# We use **EEGConformer**, a transformer-based architecture designed for EEG signals. +# The key idea in AS-MMD is to combine: +# +# 1. **Classification loss**: Standard cross-entropy on both domains +# 2. **Domain alignment**: MMD loss to match feature distributions +# 3. **Prototype alignment**: Align class centers across domains +# 4. **Data augmentation**: Mixup + Gaussian noise for regularization + +from braindecode.models import EEGConformer +import torch.nn.functional as F + + +def normalize_data(x, eps=1e-7): + """Normalize each trial independently.""" + mean = x.mean(dim=2, keepdim=True) + std = x.std(dim=2, keepdim=True) + std = torch.clamp(std, min=eps) + return (x - mean) / std + + +# %% +# Domain Adaptation Techniques +# ---------------------------- +# +# **Mixup**: Interpolates between sample pairs +def mixup_data(x, y, alpha=0.4): + """Mix samples from the same batch.""" + if alpha > 0: + lam = np.random.beta(alpha, alpha) + else: + lam = 1.0 + + batch_size = x.size(0) + index = torch.randperm(batch_size, device=x.device) + mixed_x = lam * x + (1 - lam) * x[index] + return mixed_x, y, y[index], lam + + +# **Focal Loss**: Down-weights easy examples +def compute_focal_loss(scores, targets, gamma=2.0, alpha=0.25): + """Focal loss for class imbalance.""" + ce_loss = F.cross_entropy(scores, targets, reduction="none") + pt = torch.exp(-ce_loss) + focal_loss = alpha * (1 - pt) ** gamma * ce_loss + return focal_loss.mean() + + +# **Maximum Mean Discrepancy**: Measures domain distribution mismatch +def compute_mmd_rbf(x, y, eps=1e-8): + """RBF-kernel MMD for distribution alignment.""" + if x.dim() > 2: + x = x.view(x.size(0), -1) + if y.dim() > 2: + y = y.view(y.size(0), -1) + + z = torch.cat([x, y], dim=0) + if z.size(0) > 1: + dists = torch.cdist(z, z, p=2.0) + sigma = torch.median(dists) + sigma = torch.clamp(sigma, min=eps) + else: + sigma = torch.tensor(1.0, device=z.device) + + gamma = 1.0 / (2.0 * (sigma**2) + eps) + k_xx = torch.exp(-gamma * torch.cdist(x, x, p=2.0) ** 2) + k_yy = torch.exp(-gamma * torch.cdist(y, y, p=2.0) ** 2) + k_xy = torch.exp(-gamma * torch.cdist(x, y, p=2.0) ** 2) + + m, n = x.size(0), y.size(0) + if m <= 1 or n <= 1: + return torch.tensor(0.0, device=x.device) + + mmd = (k_xx.sum() - torch.trace(k_xx)) / (m * (m - 1) + eps) + mmd += (k_yy.sum() - torch.trace(k_yy)) / (n * (n - 1) + eps) + mmd -= 2.0 * k_xy.mean() + return mmd + + +# **Prototype Alignment**: Align class centers across domains +def compute_prototypes(features, labels, n_classes=2): + """Compute mean feature vector per class.""" + if features.dim() > 2: + features = features.view(features.size(0), -1) + + prototypes = [] + for c in range(n_classes): + mask = labels == c + if mask.sum() > 0: + proto = features[mask].mean(dim=0) + else: + proto = torch.zeros(features.size(1), device=features.device) + prototypes.append(proto) + return torch.stack(prototypes) + + +def compute_prototype_loss(features, labels, prototypes): + """Align features to their class prototypes.""" + if features.dim() > 2: + features = features.view(features.size(0), -1) + + loss = 0.0 + for i, label in enumerate(labels): + proto = prototypes[label] + loss += F.mse_loss(features[i], proto) + return loss / max(1, len(labels)) + + +# %% +# Training Configuration +# ---------------------- +# +# Define hyperparameters for stable cross-domain training + +BATCH_SIZE = 22 +LEARNING_RATE = 0.001 +WEIGHT_DECAY = 2.5e-4 +MAX_EPOCHS = 100 +EARLY_STOPPING_PATIENCE = 10 +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# %% +# Part 3: Training and Evaluation +# =============================== +# +# The Training Loop +# ----------------- +# +# For each batch, we compute four loss components: +# +# 1. **Classification loss** (source + target): Standard cross-entropy +# 2. **Mixup loss** (target domain): Interpolated samples for regularization +# 3. **MMD loss**: Aligns logit-space feature distributions +# 4. **Prototype loss**: Pulls small-domain features to large-domain class centers +# +# All losses are combined with domain-adaptive weights that increase during training. + +from torch.utils.data import TensorDataset, DataLoader +from sklearn.metrics import roc_auc_score + + +def evaluate_model(model, data_loader, device): + """Evaluate model on a dataset and compute metrics.""" + model.eval() + all_preds = [] + all_targets = [] + all_probs = [] + + with torch.no_grad(): + for x, y in data_loader: + x = normalize_data(x).to(device) + y = y.to(device) + scores = model(x) + all_preds.append(scores.argmax(1).cpu().numpy()) + all_targets.append(y.cpu().numpy()) + all_probs.append(torch.softmax(scores, dim=1)[:, 1].cpu().numpy()) + + preds = np.concatenate(all_preds) + targets = np.concatenate(all_targets) + probs = np.concatenate(all_probs) + + accuracy = (preds == targets).mean() + auc = roc_auc_score(targets, probs) if len(np.unique(targets)) > 1 else 0.5 + + return {"accuracy": float(accuracy), "auc": float(auc)} + + +def make_loader(X, y, shuffle=False): + dataset = TensorDataset(torch.FloatTensor(X), torch.LongTensor(y)) + return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle) + + +def train_asmmd_model( + Xtr_p3, + ytr_p3, + Xva_p3, + yva_p3, + Xtr_avo, + ytr_avo, + Xva_avo, + yva_avo, + n_channels, + n_times, + seed=42, +): + """Train a single AS-MMD model. + + Parameters + ---------- + Xtr_*, ytr_* : numpy arrays + Training data and labels for each domain + Xva_*, yva_* : numpy arrays + Validation data and labels for each domain + + """ + torch.manual_seed(seed) + np.random.seed(seed) + + # Create data loaders + + train_p3 = make_loader(Xtr_p3, ytr_p3, shuffle=True) + val_p3 = make_loader(Xva_p3, yva_p3, shuffle=False) + train_avo = make_loader(Xtr_avo, ytr_avo, shuffle=True) + val_avo = make_loader(Xva_avo, yva_avo, shuffle=False) + + # Initialize model + model = EEGConformer( + n_chans=n_channels, + n_outputs=2, # Binary: oddball vs. standard + n_times=n_times, + n_filters_time=40, + filter_time_length=25, + pool_time_length=75, + pool_time_stride=15, + drop_prob=0.5, + att_depth=3, + att_heads=4, + att_drop_prob=0.5, + ).to(DEVICE) + + optimizer = torch.optim.Adamax( + model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY + ) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=MAX_EPOCHS) + + # Compute domain-specific weights + n_p3, n_avo = len(Xtr_p3), len(Xtr_avo) + small_domain = "P3" if n_p3 < n_avo else "AVO" + large_domain = "AVO" if small_domain == "P3" else "P3" + + # Training loop + best_score = 0.0 + best_state = None + patience = 0 + + for epoch in range(1, MAX_EPOCHS + 1): + model.train() + + # Warmup: gradually increase domain adaptation strength + warmup_epoch = min(1.0, epoch / 20) + + loaders = {"P3": train_p3, "AVO": train_avo} + itr_small = iter(loaders[small_domain]) + + for xb_large, yb_large in loaders[large_domain]: + # Large domain batch + x_large = normalize_data(xb_large).to(DEVICE) + y_large = yb_large.to(DEVICE) + scores_large = model(x_large) + loss_cls = F.cross_entropy(scores_large, y_large) + + # Small domain batch + try: + xb_small, yb_small = next(itr_small) + except StopIteration: + itr_small = iter(loaders[small_domain]) + xb_small, yb_small = next(itr_small) + + x_small = normalize_data(xb_small).to(DEVICE) + y_small = yb_small.to(DEVICE) + + # Mixup on small domain + x_mixed, y_a, y_b, lam = mixup_data(x_small, y_small) + scores_mixed = model(x_mixed) + loss_mixup = lam * compute_focal_loss(scores_mixed, y_a) + ( + 1 - lam + ) * compute_focal_loss(scores_mixed, y_b) + + # MMD alignment + scores_orig = model(x_small) + loss_mmd = warmup_epoch * compute_mmd_rbf( + scores_large.detach(), scores_orig.detach() + ) + + # Prototype alignment + with torch.no_grad(): + proto_large = compute_prototypes( + scores_large.detach(), y_large, n_classes=2 + ) + loss_proto = warmup_epoch * compute_prototype_loss( + scores_orig, y_small, proto_large + ) + + # Combined loss + loss = loss_cls + loss_mixup + 0.3 * loss_mmd + 0.5 * loss_proto + + optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0) + optimizer.step() + + scheduler.step() + + # Validation + val_p3_metrics = evaluate_model(model, val_p3, DEVICE) + val_avo_metrics = evaluate_model(model, val_avo, DEVICE) + + # Track best model on small domain + small_val = ( + val_p3_metrics["accuracy"] + if small_domain == "P3" + else val_avo_metrics["accuracy"] + ) + + if small_val > best_score: + best_score = small_val + best_state = model.state_dict() + patience = 0 + else: + patience += 1 + + if (epoch % 10 == 0) or (epoch == 1): + print( + f"Epoch {epoch:3d} | P3 val: {val_p3_metrics['accuracy']:.3f} | " + f"AVO val: {val_avo_metrics['accuracy']:.3f} | Score: {small_val:.3f}" + ) + + if patience >= EARLY_STOPPING_PATIENCE: + print(f"Early stopping at epoch {epoch}") + break + + # Load best model + if best_state is not None: + model.load_state_dict(best_state) + + return model + + +# %% +# Nested Cross-Validation +# ----------------------- +# +# We use nested CV to robustly estimate model performance: +# +# - **Outer folds (5)**: For test set evaluation +# - **Inner split**: Train/val split for hyperparameter tuning +# - **Repeats (5)**: Multiple random seeds for stability + +from sklearn.model_selection import StratifiedKFold, train_test_split +import pandas as pd +import warnings + +warnings.filterwarnings("ignore") + + +def run_nested_cv(X_all, y_all, src_all, channels): + """Run nested cross-validation with AS-MMD.""" + n_channels = X_all.shape[1] + n_times = X_all.shape[2] + + results = [] + SEEDS = [42, 123, 456, 789, 321] + + for repeat in range(2): # 2 repeats for quick demo (use 5 for final results) + print(f"\n{'=' * 60}") + print(f"Repeat {repeat + 1}/2") + print("=" * 60) + + cv = StratifiedKFold( + n_splits=3, shuffle=True, random_state=SEEDS[repeat] + ) # 3 folds for demo + + for fold_idx, (train_idx, test_idx) in enumerate(cv.split(X_all, y_all)): + print(f"\nFold {fold_idx + 1}/3") + + X_tr, y_tr, src_tr = X_all[train_idx], y_all[train_idx], src_all[train_idx] + X_te, y_te, src_te = X_all[test_idx], y_all[test_idx], src_all[test_idx] + + # Split train into train/val + tr_idx, va_idx = train_test_split( + np.arange(len(X_tr)), test_size=0.15, stratify=y_tr, random_state=42 + ) + + # Extract per-domain data + def get_domain(X, y, src, idx, domain): + mask = src == domain + indices = np.intersect1d(np.where(mask)[0], idx) + return X[indices], y[indices] + + Xtr_p3, ytr_p3 = get_domain(X_tr, y_tr, src_tr, tr_idx, "P3") + Xtr_avo, ytr_avo = get_domain(X_tr, y_tr, src_tr, tr_idx, "AVO") + Xva_p3, yva_p3 = get_domain(X_tr, y_tr, src_tr, va_idx, "P3") + Xva_avo, yva_avo = get_domain(X_tr, y_tr, src_tr, va_idx, "AVO") + + if len(Xtr_p3) == 0 or len(Xtr_avo) == 0: + print(" Skipping: insufficient training samples") + continue + + print(f" Train: P3={len(Xtr_p3)}, AVO={len(Xtr_avo)}") + print(f" Val: P3={len(Xva_p3)}, AVO={len(Xva_avo)}") + + # Train model + model = train_asmmd_model( + Xtr_p3, + ytr_p3, + Xva_p3, + yva_p3, + Xtr_avo, + ytr_avo, + Xva_avo, + yva_avo, + n_channels, + n_times, + seed=SEEDS[repeat], + ) + + # Evaluate on test set + def test_domain(domain_label): + mask = src_te == domain_label + if not np.any(mask): + return {"accuracy": 0.0, "auc": 0.5}, 0 + loader = make_loader(X_te[mask], y_te[mask]) + metrics = evaluate_model(model, loader, DEVICE) + return metrics, np.sum(mask) + + def make_loader(X, y): + return DataLoader( + TensorDataset(torch.FloatTensor(X), torch.LongTensor(y)), + batch_size=BATCH_SIZE, + shuffle=False, + ) + + m_p3, n_p3 = test_domain("P3") + m_avo, n_avo = test_domain("AVO") + + overall_acc = (m_p3["accuracy"] * n_p3 + m_avo["accuracy"] * n_avo) / ( + n_p3 + n_avo + 1e-8 + ) + + print( + f" Test: P3={m_p3['accuracy']:.3f} (n={n_p3}), AVO={m_avo['accuracy']:.3f} (n={n_avo})" + ) + + results.append( + { + "repeat": repeat + 1, + "fold": fold_idx + 1, + "p3_acc": m_p3["accuracy"], + "p3_auc": m_p3["auc"], + "avo_acc": m_avo["accuracy"], + "avo_auc": m_avo["auc"], + "overall_acc": overall_acc, + } + ) + + return pd.DataFrame(results) + + +# %% +# Execute Training +# ---------------- + +print("\nStarting AS-MMD Training with Nested Cross-Validation...") +print("=" * 60) + +results_df = run_nested_cv(X_all, y_all, src_all, COMMON_CHANNELS) + +# Print summary +print("\n" + "=" * 60) +print("RESULTS SUMMARY") +print("=" * 60) +print( + f"\nOverall Accuracy: {results_df['overall_acc'].mean():.4f} ± {results_df['overall_acc'].std():.4f}" +) +print("\nP3 Dataset:") +print( + f" Accuracy: {results_df['p3_acc'].mean():.4f} ± {results_df['p3_acc'].std():.4f}" +) +print(f" AUC: {results_df['p3_auc'].mean():.4f} ± {results_df['p3_auc'].std():.4f}") +print("\nAVO Dataset:") +print( + f" Accuracy: {results_df['avo_acc'].mean():.4f} ± {results_df['avo_acc'].std():.4f}" +) +print(f" AUC: {results_df['avo_auc'].mean():.4f} ± {results_df['avo_auc'].std():.4f}") +print("=" * 60) + +# Save results +results_df.to_csv("asmmd_results.csv", index=False) +print("\nResults saved to: asmmd_results.csv") + +# %% +# Key Takeaways +# ============= +# +# **Main Components of AS-MMD:** +# +# 1. **Classification Loss**: Standard cross-entropy on both datasets +# 2. **Mixup Regularization**: Interpolate between samples for better generalization +# 3. **MMD Alignment**: Match feature distributions across domains +# 4. **Prototype Alignment**: Pull small-domain features toward large-domain class centers +# 5. **Warmup Schedule**: Gradually introduce domain adaptation during training +# +# **When to Use This Method:** +# +# - You have limited data from your target domain +# - You have access to a related source domain (different equipment/site) +# - You want a single model that performs well on both domains +# - You need robust cross-dataset performance +# +# **Tips for Your Own Data:** +# +# - Verify channel names match between datasets (case-insensitive lowercasing helps) +# - Adjust BATCH_SIZE if memory is limited (try 16 or 32) +# - Increase MAX_EPOCHS if curves haven't plateaued +# - Tune MMD weight (0.2-0.5) and prototype weight (0.5-0.8) based on domain similarity +# - Use more CV folds (5-10) for final results +# +# **References:** +# +# - Chen, W., Delorme, A. (2025). Adaptive Split-MMD Training for Small-Sample Cross-Dataset P300 Classification. +# - Song et al. (2019). "EEGConformer: Convolutional Transformer for EEG Decoding" +# - Long et al. (2015). "Learning Transferable Features with Deep Adaptation Networks" + +# %% +# Next Steps +# ========== +# +# - Try different EEG components (e.g., N1, P2, N2 instead of P3) +# - Extend to multi-class classification (e.g., oddball paradigm variants) +# - Apply to other tasks (motor imagery, sleep staging, seizure detection) +# - Experiment with other backbones (ResNet, LSTM) instead of EEGConformer +# - Implement subject-independent vs. subject-specific models diff --git a/examples/eeg2025/tutorial_challenge_1.py b/examples/eeg2025/tutorial_challenge_1.py index b03cb4ef..4980155e 100644 --- a/examples/eeg2025/tutorial_challenge_1.py +++ b/examples/eeg2025/tutorial_challenge_1.py @@ -1,108 +1,78 @@ -""".. _challenge-1: +"""Challenge 1: Cross-Task Transfer Learning! +========================================== +.. _challenge-1: .. meta:: :html_theme.sidebar_secondary.remove: true - -Challenge 1: Cross-Task Transfer Learning! -========================================== - .. contents:: This example covers: :local: :depth: 2 - """ ###################################################################### - # .. image:: https://colab.research.google.com/assets/colab-badge.svg # :target: https://colab.research.google.com/github/eeg2025/startkit/blob/main/challenge_1.ipynb # :alt: Open In Colab - ###################################################################### - # Preliminary notes # ----------------- # Before we begin, I just want to make a deal with you, ok? # This is a community competition with a strong open-source foundation. # When I say open-source, I mean volunteer work. - # - # So, if you see something that does not work or could be improved, first, **please be kind**, and # we will fix it together on GitHub, okay? - # - # The entire decoding community will only go further when we stop # solving the same problems over and over again, and it starts working together. - ###################################################################### - # How can we use the knowledge from one EEG Decoding task into another? # --------------------------------------------------------------------- # Transfer learning is a widespread technique used in deep learning. It # uses knowledge learned from one source task/domain in another target # task/domain. It has been studied in depth in computer vision, natural # language processing, and speech, but what about EEG brain decoding? - # - # The cross-task transfer learning scenario in EEG decoding is remarkably # underexplored compared to the development of new models, # `Aristimunha et al. (2023) `__, even # though it can be much more useful for real applications, see # `Wimpff et al. (2025) `__, # `Wu et al. (2025) `__. - # - # Our Challenge 1 addresses a key goal in neurotechnology: decoding # cognitive function from EEG using the pre-trained knowledge from another. # In other words, developing models that can effectively # transfer/adapt/adjust/fine-tune knowledge from passive EEG tasks to # active tasks. - # - # The ability to generalize and transfer is something critical that we # believe should be focused on. To go beyond just comparing metrics numbers # that are often not comparable, given the specificities of EEG, such as # pre-processing, inter-subject variability, and many other unique # components of this type of data. - # - # This means your submitted model might be trained on a subset of tasks # and fine-tuned on data from another condition, evaluating its capacity to # generalize with task-specific fine-tuning. - ###################################################################### - # __________ - # - # Note: For simplicity purposes, we will only show how to do the decoding # directly in our target task, and it is up to the teams to think about # how to use the passive task to perform the pre-training. - ####################################################################### - # Install dependencies # -------------------- # For the challenge, we will need two significant dependencies: # `braindecode` and `eegdash`. The libraries will install PyTorch, # Pytorch Audio, Scikit-learn, MNE, MNE-BIDS, and many other packages # necessary for the many functions. - # - # Install dependencies on colab or your local machine, as eegdash # have braindecode as a dependency. # you can just run ``pip install eegdash``. - ###################################################################### - # Imports and setup # ----------------- from pathlib import Path @@ -125,12 +95,9 @@ from joblib import Parallel, delayed ###################################################################### - # Check GPU availability # ---------------------- - # - # Identify whether a CUDA-enabled GPU is available # and set the device accordingly. # If using Google Colab, ensure that the runtime is set to use a GPU. @@ -147,153 +114,98 @@ "selecting 'T4 GPU'\nunder 'Hardware accelerator'." ) print(msg) - ###################################################################### - # What are we decoding? # --------------------- - # - # To start to talk about what we want to analyse, the important thing # is to understand some basic concepts. - # - ###################################################################### - # The brain decodes the problem # ----------------------------- - # - # Broadly speaking, here *brain decoding* is the following problem: # given brain time-series signals :math:`X \in \mathbb{R}^{C \times T}` with # labels :math:`y \in \mathcal{Y}`, we implement a neural network :math:`f` that # **decodes/translates** brain activity into the target label. - # - # We aim to translate recorded brain activity into its originating # stimulus, behavior, or mental state, `King, J-R. et al. (2020) `__. - # - # The neural network :math:`f` applies a series of transformation layers # (e.g., ``torch.nn.Conv2d``, ``torch.nn.Linear``, ``torch.nn.ELU``, ``torch.nn.BatchNorm2d``) # to the data to filter, extract features, and learn embeddings # relevant to the optimization objective—in other words: - # - # .. math:: - # - # f_{\theta}: X \to y, - # - # where :math:`C` (``n_chans``) is the number of channels/electrodes and :math:`T` (``n_times``) # is the temporal window length/epoch size over the interval of interest. # Here, :math:`\theta` denotes the parameters learned by the neural network. - # - # Input/Output definition # --------------------------- # For the competition, the HBN-EEG (Healthy Brain Network EEG Datasets) # dataset has ``n_chans = 129`` with the last channels as a `reference channel `_, # and we define the window length as ``n_times = 200``, corresponding to 2-second windows. - # - # Your model should follow this definition exactly; any specific selection of channels, # filtering, or domain-adaptation technique must be performed **within the layers of the neural network model**. - # - # In this tutorial, we will use the ``EEGNeX`` model from ``braindecode`` as an example. # You can use any model you want, as long as it follows the input/output # definitions above. - ###################################################################### - # Understand the task: Contrast Change Detection (CCD) # -------------------------------------------------------- # If you are interested to get more neuroscience insight, we recommend these two references, `HBN-EEG `__ and `Langer, N et al. (2017) `__. # Your task (**label**) is to predict the response time for the subject during this windows. - # - # In the Video, we have an example of recording cognitive activity: - # - # The Contrast Change Detection (CCD) task relates to # `Steady-State Visual Evoked Potentials (SSVEP) `__ # and `Event-Related Potentials (ERP) `__. - # - # Algorithmically, what the subject sees during recording is: - # - # * Two flickering striped discs: one tilted left, one tilted right. # * After a variable delay, **one disc's contrast gradually increases** **while the other decreases**. # * They **press left or right** to indicate which disc got stronger. # * They receive **feedback** (🙂 correct / 🙁 incorrect). - # - # **The task parallels SSVEP and ERP:** - # - # * The continuous flicker **tags the EEG at fixed frequencies (and harmonics)** → SSVEP-like signals. # * The **ramp onset**, the **button press**, and the **feedback** are **time-locked events** that yield ERP-like components. - # - # Your task (**label**) is to predict the response time for the subject during this windows. - # - ####################################################################### - # In the figure below, we have the timeline representation of the cognitive task: - # - # .. image:: https://eeg2025.github.io/assets/img/image-2.jpg - ###################################################################### - # Stimulus demonstration # ---------------------- # .. raw:: html - # - #
# #
- # - ###################################################################### - # PyTorch Dataset for the competition # ----------------------------------- # Now, we have a Pytorch Dataset object that contains the set of recordings for the task # `contrastChangeDetection`. - # - from eegdash.dataset import EEGChallengeDataset from eegdash.hbn.windows import ( annotate_trials_with_target, @@ -305,7 +217,6 @@ # Match tests' cache layout under ~/eegdash_cache/eeg_challenge_cache DATA_DIR = (Path.home() / "eegdash_cache" / "eeg_challenge_cache").resolve() DATA_DIR.mkdir(parents=True, exist_ok=True) - dataset_ccd = EEGChallengeDataset( task="contrastChangeDetection", release="R5", cache_dir=DATA_DIR, mini=True ) @@ -315,56 +226,38 @@ print( f"Number of unique subjects in the dataset: {dataset_ccd.description['subject'].nunique()}" ) - # - # This dataset object have very rich Raw object details that can help you to # understand better the data. The framework behind this is braindecode, # and if you want to understand in depth what is happening, we recommend the # braindecode github itself. - # - # We can also access the Raw object for visualization purposes, we will see just one object. raw = dataset_ccd.datasets[0].raw # get the Raw object of the first recording # And to download all the data all data directly, you can do: raws = Parallel(n_jobs=-1)(delayed(lambda d: d.raw)(d) for d in dataset_ccd.datasets) - ###################################################################### - # Alternatives for Downloading the data # ------------------------------------- - # - # You can also perform this operation with wget or the aws cli. # These options will probably be faster! # Please check more details in the `HBN` data webpage `HBN-EEG `__. # You need to download the 100Hz preprocessed data in BDF format. - # - # Example of wget for release R1 # wget https://sccn.ucsd.edu/download/eeg2025/R1_L100_bdf.zip -O R1_L100_bdf.zip - # - # Example of AWS CLI for release R1 - # - # aws s3 sync s3://nmdatasets/NeurIPS25/R1_L100_bdf data/R1_L100_bdf --no-sign-request - ###################################################################### - # Create windows of interest # ----------------------------- # So we epoch after the stimulus moment with a beginning shift of 500 ms. - EPOCH_LEN_S = 2.0 SFREQ = 100 # by definition here - transformation_offline = [ Preprocessor( annotate_trials_with_target, @@ -377,14 +270,11 @@ Preprocessor(add_aux_anchors, apply_on_array=False), ] preprocess(dataset_ccd, transformation_offline, n_jobs=1) - ANCHOR = "stimulus_anchor" SHIFT_AFTER_STIM = 0.5 WINDOW_LEN = 2.0 - # Keep only recordings that actually contain stimulus anchors dataset = keep_only_recordings_with(ANCHOR, dataset_ccd) - # Create single-interval windows (stim-locked, long enough to include the response) single_windows = create_windows_from_events( dataset, @@ -395,7 +285,6 @@ window_stride_samples=SFREQ, preload=True, ) - # Injecting metadata into the extra mne annotation. single_windows = add_extras_columns( single_windows, @@ -411,16 +300,13 @@ "response_type", ), ) - ###################################################################### - # Inspect the label distribution # ------------------------------- import numpy as np from skorch.helper import SliceDataset y_label = np.array(list(SliceDataset(single_windows, 1))) - # Plot histogram of the response times with matplotlib import matplotlib.pyplot as plt @@ -432,43 +318,34 @@ plt.tight_layout() plt.show() - ###################################################################### - # Split the data # --------------- # Extract meta information meta_information = single_windows.get_metadata() - valid_frac = 0.1 test_frac = 0.1 seed = 2025 - subjects = meta_information["subject"].unique() - train_subj, valid_test_subject = train_test_split( subjects, test_size=(valid_frac + test_frac), random_state=check_random_state(seed), shuffle=True, ) - valid_subj, test_subj = train_test_split( valid_test_subject, test_size=test_frac, random_state=check_random_state(seed + 1), shuffle=True, ) - # Sanity check assert (set(valid_subj) | set(test_subj) | set(train_subj)) == set(subjects) - # Create train/valid/test splits for the windows subject_split = single_windows.split("subject") train_set = [] valid_set = [] test_set = [] - for s in subject_split: if s in train_subj: train_set.append(subject_split[s]) @@ -476,24 +353,19 @@ valid_set.append(subject_split[s]) elif s in test_subj: test_set.append(subject_split[s]) - train_set = BaseConcatDataset(train_set) valid_set = BaseConcatDataset(valid_set) test_set = BaseConcatDataset(test_set) - print("Number of examples in each split in the minirelease") print(f"Train:\t{len(train_set)}") print(f"Valid:\t{len(valid_set)}") print(f"Test:\t{len(test_set)}") - ###################################################################### - # Create dataloaders # ------------------- batch_size = 128 # Set num_workers to 0 to avoid multiprocessing issues in notebooks/tutorials num_workers = 0 - train_loader = DataLoader( train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers ) @@ -503,9 +375,7 @@ test_loader = DataLoader( test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers ) - ###################################################################### - # Build the model # ----------------- # For neural network models, **to start**, we suggest using `braindecode models `__ zoo. @@ -520,19 +390,15 @@ n_times=200, # 2 seconds sfreq=100, # sample frequency 100 Hz ) - print(model) model.to(device) - ###################################################################### - # Define training and validation functions # ------------------------------------------- # The rest is our classic PyTorch/torch lighting/skorch training pipeline, # you can use any training framework you want. # We provide a simple training and validation loop below. - # @@ -547,44 +413,35 @@ def train_one_epoch( print_batch_stats: bool = True, ): model.train() - total_loss = 0.0 sum_sq_err = 0.0 n_samples = 0 - progress_bar = tqdm( enumerate(dataloader), total=len(dataloader), disable=not print_batch_stats ) - for batch_idx, batch in progress_bar: # Support datasets that may return (X, y) or (X, y, ...) X, y = batch[0], batch[1] X, y = X.to(device).float(), y.to(device).float() - optimizer.zero_grad(set_to_none=True) preds = model(X) loss = loss_fn(preds, y) loss.backward() optimizer.step() - total_loss += loss.item() - # Flatten to 1D for regression metrics and accumulate squared error preds_flat = preds.detach().view(-1) y_flat = y.detach().view(-1) sum_sq_err += torch.sum((preds_flat - y_flat) ** 2).item() n_samples += y_flat.numel() - if print_batch_stats: running_rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5 progress_bar.set_description( f"Epoch {epoch}, Batch {batch_idx + 1}/{len(dataloader)}, " f"Loss: {loss.item():.6f}, RMSE: {running_rmse:.6f}" ) - if scheduler is not None: scheduler.step() - avg_loss = total_loss / len(dataloader) rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5 return avg_loss, rmse @@ -599,46 +456,37 @@ def valid_model( print_batch_stats: bool = True, ): model.eval() - total_loss = 0.0 sum_sq_err = 0.0 n_batches = len(dataloader) n_samples = 0 - iterator = tqdm( enumerate(dataloader), total=n_batches, disable=not print_batch_stats ) - for batch_idx, batch in iterator: # Supports (X, y) or (X, y, ...) X, y = batch[0], batch[1] X, y = X.to(device).float(), y.to(device).float() - preds = model(X) batch_loss = loss_fn(preds, y).item() total_loss += batch_loss - preds_flat = preds.detach().view(-1) y_flat = y.detach().view(-1) sum_sq_err += torch.sum((preds_flat - y_flat) ** 2).item() n_samples += y_flat.numel() - if print_batch_stats: running_rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5 iterator.set_description( f"Val Batch {batch_idx + 1}/{n_batches}, " f"Loss: {batch_loss:.6f}, RMSE: {running_rmse:.6f}" ) - avg_loss = total_loss / n_batches if n_batches else float("nan") rmse = (sum_sq_err / max(n_samples, 1)) ** 0.5 - print(f"Val RMSE: {rmse:.6f}, Val Loss: {avg_loss:.6f}\n") return avg_loss, rmse ###################################################################### - # Train the model # ------------------ lr = 1e-3 @@ -647,32 +495,26 @@ def valid_model( 5 # For demonstration purposes, we use just 5 epochs here. You can increase this. ) early_stopping_patience = 50 - optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs - 1) loss_fn = torch.nn.MSELoss() - patience = 5 min_delta = 1e-4 best_rmse = float("inf") epochs_no_improve = 0 best_state, best_epoch = None, None - for epoch in range(1, n_epochs + 1): print(f"Epoch {epoch}/{n_epochs}: ", end="") - train_loss, train_rmse = train_one_epoch( train_loader, model, loss_fn, optimizer, scheduler, epoch, device ) val_loss, val_rmse = valid_model(test_loader, model, loss_fn, device) - print( f"Train RMSE: {train_rmse:.6f}, " f"Average Train Loss: {train_loss:.6f}, " f"Val RMSE: {val_rmse:.6f}, " f"Average Val Loss: {val_loss:.6f}" ) - if val_rmse < best_rmse - min_delta: best_rmse = val_rmse best_state = copy.deepcopy(model.state_dict()) @@ -685,12 +527,9 @@ def valid_model( f"Early stopping at epoch {epoch}. Best Val RMSE: {best_rmse:.6f} (epoch {best_epoch})" ) break - if best_state is not None: model.load_state_dict(best_state) - ###################################################################### - # Save the model # ----------------- torch.save(model.state_dict(), "weights_challenge_1.pt") diff --git a/examples/eeg2025/tutorial_challenge_2.py b/examples/eeg2025/tutorial_challenge_2.py index 5fb72993..114c22fe 100644 --- a/examples/eeg2025/tutorial_challenge_2.py +++ b/examples/eeg2025/tutorial_challenge_2.py @@ -1,65 +1,46 @@ -""".. _challenge-2: +"""Challenge 2: Predicting the p-factor from EEG +============================================= +.. _challenge-2: .. meta:: :html_theme.sidebar_secondary.remove: true - -Challenge 2: Predicting the p-factor from EEG -============================================= - .. contents:: This example covers: :local: :depth: 2 - """ - ###################################################################### - # .. image:: https://colab.research.google.com/assets/colab-badge.svg # :target: https://colab.research.google.com/github/eeg2025/startkit/blob/main/challenge_2.ipynb # :alt: Open In Colab - ###################################################################### - # Preliminary notes # ----------------- # Before we begin, I just want to make a deal with you, ok? # This is a community competition with a strong open-source foundation. # When I say open-source, I mean volunteer work. - # - # So, if you see something that does not work or could be improved, first, **please be kind**, and # we will fix it together on GitHub, okay? - # - # The entire decoding community will only go further when we stop # solving the same problems over and over again, and it starts working together. - ###################################################################### - # Overview # -------- # The psychopathology factor (P-factor) is a widely recognized construct in mental health research, representing a common underlying dimension of psychopathology across various disorders. # Currently, the P-factor is often assessed using self-report questionnaires or clinician ratings, which can be subjective, prone to bias, and time-consuming. # **The Challenge 2** consists of developing a model to predict the P-factor from EEG recordings. - # - # The challenge encourages learning physiologically meaningful signal representations and discovery of reproducible biomarkers. # Models of any size should emphasize robust, interpretable features that generalize across subjects, # sessions, and acquisition sites. - # - # Unlike a standard in-distribution classification task, this regression problem stresses out-of-distribution robustness # and extrapolation. The goal is not only to minimize error on seen subjects, but also to transfer effectively to unseen data. # Ensure the dataset is available locally. If not, see the # `dataset download guide `__. - ###################################################################### - # Contents of this start kit # -------------------------- # .. note:: If you need additional explanations on the @@ -68,50 +49,32 @@ # `braindecode `__'s # deep learning models, or brain decoding in general, please refer to the # start-kit of challenge 1 which delves deeper into these topics. - # - # More contents will be released during the competition inside the # :mod:`eegdash` `examples webpage `__. - # - # .. admonition:: Prerequisites # :class: important - # - # The tutorial assumes prior knowledge of: - # - # - Standard neural network architectures (e.g., CNNs) # - Optimization by batch gradient descent and backpropagation # - Overfitting, early stopping, and regularization # - Some knowledge of PyTorch # - Basic familiarity with EEG and preprocessing # - An appreciation for open-source work :) - ###################################################################### - # Install dependencies on Colab # ----------------------------- - # - # .. note:: These installs are optional; skip on local environments # where you already have the dependencies installed. - # - # .. code-block:: bash - # - # pip install eegdash - ###################################################################### - # Imports # ------- from pathlib import Path @@ -119,7 +82,6 @@ import os import random from joblib import Parallel, delayed - import torch from torch.utils.data import DataLoader from torch import optim @@ -130,15 +92,12 @@ from eegdash import EEGChallengeDataset ###################################################################### - # .. warning:: # In case of Colab, before starting, make sure you're on a GPU instance # for faster training! If running on Google Colab, please request a GPU runtime # by clicking `Runtime/Change runtime type` in the top bar menu, then selecting # 'T4 GPU' under 'Hardware accelerator'. - ###################################################################### - # Identify whether a CUDA-enabled GPU is available # ------------------------------------------------ device = "cuda" if torch.cuda.is_available() else "cpu" @@ -152,49 +111,36 @@ "selecting 'T4 GPU'\nunder 'Hardware accelerator'." ) print(msg) - ###################################################################### - # Understanding the P-factor regression task. # ------------------------------------------- - # - # The psychopathology factor (P-factor) is a widely recognized construct in mental health research, representing a common underlying dimension of psychopathology across various disorders. # The P-factor is thought to reflect the shared variance among different psychiatric conditions, suggesting that individuals with higher P-factor scores may be more vulnerable to a range of mental health issues. # Currently, the P-factor is often assessed using self-report questionnaires or clinician ratings, which can be subjective, prone to bias, and time-consuming. # In the dataset of this challenge, the P-factor was assessed using the Child # Behavior Checklist (CBCL) `McElroy et al., (2017) `__. - # - # The goal of Challenge 2 is to develop a model to predict the P-factor from EEG recordings. # **The feasibility of using EEG data for this purpose is still an open question**. # The solution may involve finding meaningful representations of the EEG data that correlate with the P-factor scores. # The challenge encourages learning physiologically meaningful signal representations and discovery of reproducible biomarkers. # If contestants are successful in this task, it could pave the way for more objective and efficient assessments of the P-factor in clinical settings. - ###################################################################### - # Define local path and (down)load the data # ----------------------------------------- # In this challenge 2 example, we load the EEG 2025 release using # :doc:`EEGChallengeDataset `. # **Note:** in this example notebook, we load the contrast change detection task from one mini release only as an example. Naturally, you are encouraged to train your models on all complete releases, using data from all the tasks you deem relevant. - ###################################################################### - # The first step is to define the cache folder! # Match tests' cache layout under ~/eegdash_cache/eeg_challenge_cache DATA_DIR = (Path.home() / "eegdash_cache" / "eeg_challenge_cache").resolve() - # Creating the path if it does not exist DATA_DIR.mkdir(parents=True, exist_ok=True) - # We define the list of releases to load. # Here, only release 5 is loaded. release_list = ["R5"] - all_datasets_list = [ EEGChallengeDataset( release=release, @@ -216,49 +162,33 @@ ] print("Datasets loaded") sub_rm = ["NDARWV769JM7"] - ###################################################################### - # Combine the datasets into a single one # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # Here, we combine the datasets from the different releases into a single # ``BaseConcatDataset`` object. - # %% all_datasets = BaseConcatDataset(all_datasets_list) print(all_datasets.description) - raws = Parallel(n_jobs=os.cpu_count())( delayed(lambda d: d.raw)(d) for d in all_datasets.datasets ) - ###################################################################### - # Inspect your data # ----------------- # We can check what is inside the dataset consuming the # MNE-object inside the Braindecode dataset. - # - # The following snippet, if uncommented, will show the first 10 seconds of the raw EEG signal. # We can also inspect the data further by looking at the events and annotations. # We strongly recommend you to take a look into the details and check how the events are structured. - ###################################################################### - raw = all_datasets.datasets[0].raw # mne.io.Raw object - print(raw.info) - raw.plot(duration=10, scalings="auto", show=True) - print(raw.annotations) - SFREQ = 100 - ###################################################################### - # Wrap the data into a PyTorch-compatible dataset # -------------------------------------------------- # The class below defines a dataset wrapper that will extract 2-second windows, @@ -278,11 +208,9 @@ def __len__(self): def __getitem__(self, index): X, _, crop_inds = self.dataset[index] - # P-factor label: p_factor = self.dataset.description["p_factor"] p_factor = float(p_factor) - # Additional information: infos = { "subject": self.dataset.description["subject"], @@ -292,7 +220,6 @@ def __getitem__(self, index): "session": self.dataset.description.get("session", None) or "", "run": self.dataset.description.get("run", None) or "", } - # Randomly crop the signal to the desired length: i_window_in_trial, i_start, i_stop = crop_inds assert i_stop - i_start >= self.crop_size_samples, f"{i_stop=} {i_start=}" @@ -300,12 +227,10 @@ def __getitem__(self, index): i_start = i_start + start_offset i_stop = i_start + self.crop_size_samples X = X[:, start_offset : start_offset + self.crop_size_samples] - return X, p_factor, (i_window_in_trial, i_start, i_stop), infos # We filter out certain recordings, create fixed length windows and finally make use of our `DatasetWrapper`. - # %% # Filter out recordings that are too short all_datasets = BaseConcatDataset( @@ -318,7 +243,6 @@ def __getitem__(self, index): and not math.isnan(ds.description["p_factor"]) ] ) - # Create 4-seconds windows with 2-seconds stride windows_ds = create_fixed_length_windows( all_datasets, @@ -326,25 +250,19 @@ def __getitem__(self, index): window_stride_samples=2 * SFREQ, drop_last_window=True, ) - # Wrap each sub-dataset in the windows_ds windows_ds = BaseConcatDataset( [DatasetWrapper(ds, crop_size_samples=2 * SFREQ) for ds in windows_ds.datasets] ) - ###################################################################### - # Inspect the label distribution # ------------------------------- - # - import numpy as np from skorch.helper import SliceDataset y_label = np.array(list(SliceDataset(windows_ds, 1))) - # Plot histogram of the response times with matplotlib import matplotlib.pyplot as plt @@ -355,70 +273,51 @@ def __getitem__(self, index): ax.set_ylabel("Count") plt.tight_layout() plt.show() - ###################################################################### - # Define, train and save a model # --------------------------------- # Now we have our pytorch dataset necessary for the training! - # - # Below, we define a simple EEGNeX model from Braindecode. # All the braindecode models expect the input to be of shape (batch_size, n_channels, n_times) # and have a test coverage about the behavior of the model. # However, you can use any pytorch model you want. - # - ###################################################################### - # Initialize model # ----------------- - model = EEGNeX(n_chans=129, n_outputs=1, n_times=2 * SFREQ).to(device) - # Specify optimizer optimizer = optim.Adamax(params=model.parameters(), lr=0.002) - print(model) - # Finally, we can train our model. Here we define a simple training loop using pure PyTorch. # In this example, we only train for a single epoch. Feel free to increase the number of epochs. # Create PyTorch Dataloader - num_workers = ( 0 # Set num_workers to 0 to avoid multiprocessing issues in notebooks/tutorials. ) dataloader = DataLoader( windows_ds, batch_size=128, shuffle=True, num_workers=num_workers ) - n_epochs = 1 - # Train model for 1 epoch for epoch in range(n_epochs): for idx, batch in enumerate(dataloader): # Reset gradients optimizer.zero_grad() - # Unpack the batch X, y, crop_inds, infos = batch X = X.to(dtype=torch.float32, device=device) y = y.to(dtype=torch.float32, device=device).unsqueeze(1) - # Forward pass y_pred = model(X) - # Compute loss loss = l1_loss(y_pred, y) print(f"Epoch {0} - step {idx}, loss: {loss.item()}") - # Gradient backpropagation loss.backward() optimizer.step() - # Finally, we can save the model for later use torch.save(model.state_dict(), "weights_challenge_2.pt") print("Model saved as 'weights_challenge_2.pt'") diff --git a/pyproject.toml b/pyproject.toml index 10bb613d..9497bb35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,7 @@ docs = [ "nbformat", "graphviz", "neato", + "moabb", # remove when all the moabb datasets are digested ] digestion = [