一个基于 PyTorch Lightning 的轻量级、模块化语义分割框架,使用 VOC2012 数据集作为示例,同时也易于扩展到其他数据集。
这是一个功能全面、代码结构清晰的语义分割项目。它利用 PyTorch Lightning 简化了训练流程,并集成了多种主流的分割模型和数据增强策略,让你能快速上手并进行实验。
- 高性能框架:基于 PyTorch Lightning 构建,代码简洁高效,易于维护和扩展。
- 丰富的模型库:
- 经典模型:DeepLabV3, FCN (ResNet50/101 backbone).
- U-Net 家族:支持多种 Encoder (ResNet, MobileNet, EfficientNet).
- 前沿模型:Segformer, DeepLabV3+, Unet++.
- 灵活的数据增强:内置
default,light,strong三种级别的数据增强策略,满足不同实验需求. - 全面的训练流程:支持学习率动态调整 (CosineAnnealingLR)、早停 (EarlyStopping)、权重衰减 和模型断点续训。
- 完善的评估与日志:
- 使用
torchmetrics计算 mIoU,支持整体和按类别评估. - 集成 WandbLogger,轻松记录和可视化训练过程.
- 使用
- 易用的推理脚本:提供独立的脚本,用于在测试集上进行推理并生成可视化的分割结果掩码图.
- 辅助工具:包含数据集类别分布分析工具,帮助你更好地理解数据.
-
克隆项目
git clone https://github.com/octal-zhihao/semantic_seg_library.git cd semantic_seg_library -
创建 Conda 环境并安装依赖
conda create -n seg python=3.8 conda activate seg pip install -r requirements.txt
主要的依赖项包括
pytorch,pytorch-lightning,torchvision,torchmetrics,segmentation-models-pytorch,wandb,matplotlib,numpy等。
- 下载 PASCAL VOC2012 数据集。
- 将数据集解压后放置于项目根目录下的
data/文件夹中。最终的目录结构应如下所示:root/ ├── data/ │ └── VOCdevkit/ │ └── VOC2012/ │ ├── JPEGImages/ │ ├── SegmentationClass/ │ └── ... ├── datasets/ ├── model/ └── main.py
你可以通过 main.py 脚本启动训练。所有参数都可以通过命令行进行配置。
示例命令:
使用 unet_efficientnet-b0 作为骨干网络进行训练。
python main.py \
--backbone unet_efficientnet-b0 \
--batch_size 16 \
--lr 1e-3 \
--max_epochs 100 \
--img_size 224 \
--augment strong \
--accelerator gpu \
--devices 1也可在main.py中设置默认参数后直接运行
python main.py运行。
项目还提供了一个 train_all.sh 脚本,可以一键训练所有预设的模型,或自定义一次性运行多个模型。
bash train_all.sh训练过程中, checkpoints 会保存在 mycheckpoints/ 目录下,日志会同步至 Wandb。
使用 predict_test.py 脚本对 VOC2012 测试集进行推理。请确保已有一个训练好的模型 checkpoint。
示例命令:
python predict_test.py \
--checkpoint "path/to/your/checkpoints" \
--voc_root "data/VOCdevkit/VOC2012" \
--out_dir "outputpath"生成的彩色分割掩码图将保存在 --out_dir 指定的目录中。
以下是 main.py 脚本中常用的配置参数。
--data_dir: 数据集根目录 (默认:./data)。--batch_size: 批处理大小 (默认: 16)。--num_workers: 数据加载器的工作线程数 (默认: 8)。--img_size: 图像缩放尺寸 (默认: 224)。--augment: 数据增强策略,可选light,strong,default(默认:default)。
--backbone: 选择模型骨干网络。deeplabv3_resnet50,deeplabv3_resnet101fcn_resnet50,fcn_resnet101lrasppunet_resnet50,unet_resnet101,unet_mobilenet_v2,unet_efficientnet-b0,unet_efficientnet-b3Segformer,DeepLabV3Plus,UnetPlusPlus(默认:unet_efficientnet-b0)
--num_classes: 类别数 (默认: 21)。
--lr: 学习率 (默认: 1e-3)。--weight_decay: 权重衰减 (默认: 1e-4)。--T_max: CosineAnnealingLR 调度器的周期 (默认: 100)。--eta_min: CosineAnnealingLR 的最小学习率 (默认: 1e-6)。--max_epochs: 最大训练轮数 (默认: 100)。--precision: 浮点数精度 (16 或 32, 默认: 32)。--accelerator: 训练设备cpu,gpu,mps(默认:gpu)。--devices: 使用的设备数量 (默认: 1)。--early_stop_patience: 早停的耐心值 (默认: 10)。
该项目采用 MIT License 许可协议。
如果你有任何问题或建议,欢迎通过以下方式联系我:
- GitHub Issues: https://github.com/octal-zhihao/semantic_seg_library/issues
- 邮箱: zhouzhihao9529@gmail.com (请替换为你的邮箱)