ProtBERT 与 ESM-2 的融合模型,用于肽/蛋白位点预测。本仓库实现了融合架构与 LoRA 参数高效微调,支持多种特征融合策略,并与 Hugging Face Trainer 无缝集成。
- ProtBERT + ESM-2 双编码器,包含投影与融合模块
- 融合策略:attention、concat、weighted_sum、add
- 在注意力模块上应用 LoRA,参数高效微调
- 基于 token 的分类头(按氨基酸位点标注)
- 健壮评估指标(F1、AUC),自动对特殊/填充符号做掩码
- 集成早停与 TensorBoard 可视化的 Trainer 流程
- 原论文聚焦于细胞穿膜肽(CPP)预测。本实现默认提供按位点(token-level)的标注流程。如需序列级 CPP 分类,请参见下文“改为序列级 CPP 任务”。
- 文档中的示例数据路径使用
saisdata/
下的 LIP 风格演示 pkl 文件,仅用于格式示例,请替换为你的实际数据集。
- Python >= 3.8(推荐 3.9–3.11)
- CUDA >= 11.7(若使用 GPU 训练)
pip install -r requirements.txt
若需使用 PyTorch 提供的 CUDA 轮子(wheels):
# CUDA 11.8
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
python test_imports.py
每个 .pkl
文件包含一个字典列表:
[
{"sequence": "ACDE...", "label": [0, 1, 0, ...]}, # 按位点标签
...
]
python train.py \
--train_pkl saisdata/LIP_data_public.pkl \
--eval_pkl saisdata/LIP_data_test_A.pkl \
--output_dir ./models/fusion_model \
--epochs 15 \
--per_device_batch_size 4 \
--use_focal_loss --gradient_checkpointing --bf16
尝试其他融合策略:
--fusion_strategy attention|concat|weighted_sum|add
tensorboard --logdir ./models/fusion_model
若你的任务是序列级(整段序列是否为 CPP):
- 将按位点标签替换为每条序列的单一标量标签。
- 对特征做池化(如 CLS 向量或均值池化),并将分类器输出改为
[batch, num_labels]
。 - 更新
compute_metrics
,以处理二维预测与标签。
本仓库的模块化设计(dataset.py
、collator.py
、train.py
)使上述改动相对直接。
如使用本仓库,请引用原论文及本实现。
- FusPB-ESM2: Fusion model of ProtBERT and ESM-2 for cell-penetrating peptide prediction. Computational Biology and Chemistry (2024). PubMed
MIT License(见 LICENSE
)。
- Hugging Face Transformers、PEFT
- Rostlab/ProtBERT、Meta FAIR ESM-2