Skip to content

octal-zhihao/Semantic_Seg_Library

Repository files navigation

Pytorch-Lightning 语义分割库 🚀

一个基于 PyTorch Lightning 的轻量级、模块化语义分割框架,使用 VOC2012 数据集作为示例,同时也易于扩展到其他数据集。


这是一个功能全面、代码结构清晰的语义分割项目。它利用 PyTorch Lightning 简化了训练流程,并集成了多种主流的分割模型和数据增强策略,让你能快速上手并进行实验。

目录

✨ 项目特性

  • 高性能框架:基于 PyTorch Lightning 构建,代码简洁高效,易于维护和扩展。
  • 丰富的模型库
    • 经典模型:DeepLabV3, FCN (ResNet50/101 backbone).
    • U-Net 家族:支持多种 Encoder (ResNet, MobileNet, EfficientNet).
    • 前沿模型:Segformer, DeepLabV3+, Unet++.
  • 灵活的数据增强:内置 default, light, strong 三种级别的数据增强策略,满足不同实验需求.
  • 全面的训练流程:支持学习率动态调整 (CosineAnnealingLR)、早停 (EarlyStopping)、权重衰减 和模型断点续训。
  • 完善的评估与日志
    • 使用 torchmetrics 计算 mIoU,支持整体和按类别评估.
    • 集成 WandbLogger,轻松记录和可视化训练过程.
  • 易用的推理脚本:提供独立的脚本,用于在测试集上进行推理并生成可视化的分割结果掩码图.
  • 辅助工具:包含数据集类别分布分析工具,帮助你更好地理解数据.

🛠️ 安装与环境配置

  1. 克隆项目

    git clone https://github.com/octal-zhihao/semantic_seg_library.git
    cd semantic_seg_library
  2. 创建 Conda 环境并安装依赖

    conda create -n seg python=3.8
    conda activate seg
    pip install -r requirements.txt

    主要的依赖项包括 pytorch, pytorch-lightning, torchvision, torchmetrics, segmentation-models-pytorch, wandb, matplotlib, numpy 等。

⚡️ 快速开始

数据准备

  1. 下载 PASCAL VOC2012 数据集。
  2. 将数据集解压后放置于项目根目录下的 data/ 文件夹中。最终的目录结构应如下所示:
    root/
    ├── data/
    │   └── VOCdevkit/
    │       └── VOC2012/
    │           ├── JPEGImages/
    │           ├── SegmentationClass/
    │           └── ...
    ├── datasets/
    ├── model/
    └── main.py
    

模型训练

你可以通过 main.py 脚本启动训练。所有参数都可以通过命令行进行配置。

示例命令: 使用 unet_efficientnet-b0 作为骨干网络进行训练。

python main.py \
    --backbone unet_efficientnet-b0 \
    --batch_size 16 \
    --lr 1e-3 \
    --max_epochs 100 \
    --img_size 224 \
    --augment strong \
    --accelerator gpu \
    --devices 1

也可在main.py中设置默认参数后直接运行python main.py运行。

项目还提供了一个 train_all.sh 脚本,可以一键训练所有预设的模型,或自定义一次性运行多个模型。

bash train_all.sh

训练过程中, checkpoints 会保存在 mycheckpoints/ 目录下,日志会同步至 Wandb。

模型推理

使用 predict_test.py 脚本对 VOC2012 测试集进行推理。请确保已有一个训练好的模型 checkpoint。

示例命令:

python predict_test.py \
    --checkpoint "path/to/your/checkpoints" \
    --voc_root "data/VOCdevkit/VOC2012" \
    --out_dir "outputpath"

生成的彩色分割掩码图将保存在 --out_dir 指定的目录中。

⚙️ 配置项说明

以下是 main.py 脚本中常用的配置参数。

数据参数

  • --data_dir: 数据集根目录 (默认: ./data)。
  • --batch_size: 批处理大小 (默认: 16)。
  • --num_workers: 数据加载器的工作线程数 (默认: 8)。
  • --img_size: 图像缩放尺寸 (默认: 224)。
  • --augment: 数据增强策略,可选 light, strong, default (默认: default)。

模型参数

  • --backbone: 选择模型骨干网络。
    • deeplabv3_resnet50, deeplabv3_resnet101
    • fcn_resnet50, fcn_resnet101
    • lraspp
    • unet_resnet50, unet_resnet101, unet_mobilenet_v2, unet_efficientnet-b0, unet_efficientnet-b3
    • Segformer, DeepLabV3Plus, UnetPlusPlus (默认: unet_efficientnet-b0)
  • --num_classes: 类别数 (默认: 21)。

训练与优化器参数

  • --lr: 学习率 (默认: 1e-3)。
  • --weight_decay: 权重衰减 (默认: 1e-4)。
  • --T_max: CosineAnnealingLR 调度器的周期 (默认: 100)。
  • --eta_min: CosineAnnealingLR 的最小学习率 (默认: 1e-6)。
  • --max_epochs: 最大训练轮数 (默认: 100)。
  • --precision: 浮点数精度 (16 或 32, 默认: 32)。
  • --accelerator: 训练设备 cpu, gpu, mps (默认: gpu)。
  • --devices: 使用的设备数量 (默认: 1)。
  • --early_stop_patience: 早停的耐心值 (默认: 10)。

📄 许可协议

该项目采用 MIT License 许可协议。

📧 联系方式

如果你有任何问题或建议,欢迎通过以下方式联系我:

About

A benchmark algorithm library for semantic segmentation tasks.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published