[AAAI-26 Oral] Official PyTorch implementation of 'Ambiguity-aware Truncated Flow Matching for Ambiguous Medical Image Segmentation'
Simultaneously enhancing the accuracy and diversity of predictions remains a challenge in ambiguous medical image segmentation (AMIS) due to the inherent trade-offs. While truncated diffusion probabilistic models (TDPMs) hold strong potential through paradigm optimization, existing TDPMs suffer from entangled accuracy and diversity of predictions, with insufficient fidelity and plausibility. To address these challenges, we propose Ambiguity-aware Truncated Flow Matching (ATFM), which introduces a novel inference paradigm and dedicated model components. First, we propose Data-Hierarchical Inference, a redefinition of the AMIS-specific inference paradigm, which enhances accuracy and diversity at the data-distribution and data-sample levels, respectively, for effective disentanglement. Second, Gaussian Truncation Representation (GTR) is introduced to enhance both the fidelity of predictions and the reliability of the truncation distribution by explicitly modeling it as a Gaussian distribution at
You can install the dependencies by executing the following commands:
conda create -n FM python=3.9
source activate FM
pip install -r requirement.txt
Two public datasets, LIDC and ISIC Subset, are used in this work. You can download the datasets from the following links:
- LIDC Dataset as preprocessed by @Stefan Knegt
- ISIC Subset as preprocessed by @killanzepf
Please modify the dataset paths accordingly in metadata_manager.py.
- Step 1: Train Gaussian Truncation Representation for both datasets
python train_GTR.py --what LIDC --epochs 1000 --batchsize 256 --save_model True --save_model_step 50
python train_GTR.py --what isic3_style_concat --epochs 400 --batchsize 8 --save_model True --save_model_step 50
- Step 2: Train Segmentation Flow Matching based on the frozen GTR
python train_prior_LIDC.py
python train_prior_ISIC.py
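For readers unfamiliar with flow matching, the following is a minimal, generic sketch of a single training step. It is not the code in train_prior_LIDC.py or train_prior_ISIC.py. The network `TinyVelocityNet`, the helper `flow_matching_loss`, the linear interpolation path, the toy dimensions, and the fixed Gaussian `(mu, sigma)` standing in for a frozen truncation distribution are all assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyVelocityNet(nn.Module):
    """Hypothetical velocity field v(x_t, t); a stand-in, not the repository's network."""

    def __init__(self, dim: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, 32), nn.SiLU(), nn.Linear(32, dim))

    def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Concatenate the time step so the network can condition on it.
        return self.net(torch.cat([x_t, t], dim=-1))


def flow_matching_loss(velocity_net, x0, x1):
    """Generic flow matching loss on a straight path from x0 to x1."""
    t = torch.rand(x0.shape[0], 1)   # one time step per sample, in [0, 1)
    x_t = (1.0 - t) * x0 + t * x1    # point on the linear interpolation path
    target_velocity = x1 - x0        # constant velocity along that path
    return ((velocity_net(x_t, t) - target_velocity) ** 2).mean()


dim, batch = 4, 16
# Assumption: a fixed Gaussian (mean, std) stands in for a frozen, learned distribution.
mu, sigma = torch.zeros(dim), torch.ones(dim)
x0 = mu + sigma * torch.randn(batch, dim)  # source samples drawn from the Gaussian
x1 = torch.randn(batch, dim)               # placeholder targets (toy data)

velocity_net = TinyVelocityNet(dim)
optimizer = torch.optim.Adam(velocity_net.parameters(), lr=1e-3)

# One optimization step on the velocity network only; mu and sigma stay fixed.
loss = flow_matching_loss(velocity_net, x0, x1)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

In this sketch only the velocity network receives gradients, which mirrors the idea of training on top of a frozen source distribution; how the repository conditions on the image and parameterizes the Gaussian is not shown here.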
Note: The GPU memory consumption for ISIC Subset is 24198MiB (23.63GiB) even with batchsize=1. CUDA out-of-memory errors may occur when too much memory is reserved. You may use torch.utils.checkpoint to reduce memory usage:
pred = model(image)  # Original forward
pred = torch.utils.checkpoint.checkpoint(model, image, use_reentrant=False)  # Forward with lower memory consumption
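As a self-contained illustration of activation checkpointing, the sketch below runs a small hypothetical network with and without `torch.utils.checkpoint.checkpoint`. The toy `model` and the input shape are assumptions, not the repository's segmentation network, and `use_reentrant=False` assumes a PyTorch version that supports that argument.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical stand-in for the segmentation network.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=3, padding=1),
)
image = torch.randn(2, 3, 16, 16)

# Regular forward: intermediate activations are kept for the backward pass.
pred_plain = model(image)

# Checkpointed forward: activations are recomputed during backward instead of stored.
pred_ckpt = checkpoint(model, image, use_reentrant=False)

# Gradients still reach the parameters through the checkpointed call.
loss = pred_ckpt.mean()
loss.backward()
```

Checkpointing trades extra computation in the backward pass for lower peak memory, so training steps may become slower. Wrapping only the most memory-intensive sub-modules is a common compromise.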
python test_prior_LIDC.py
python test_prior_ISIC.py
Visualization of the predictions will show that ATFM produces a series of results with both high accuracy and high diversity, while fidelity and plausibility are simultaneously improved.

- We thank @killanzepf for the preprocessed dataset and the GTR baseline.
- We thank @aleksandrinvictor for the Flow Matching baseline.
- We thank @Stefan Knegt for the preprocessed dataset.
If you find this work helpful, please cite:
@article{li2025ambiguity,
title={Ambiguity-aware Truncated Flow Matching for Ambiguous Medical Image Segmentation},
author={Li, Fanding and Li, Xiangyu and Su, Xianghe and Qiu, Xingyu and Dong, Suyu and Wang, Wei and Wang, Kuanquan and Luo, Gongning and Li, Shuo},
journal={arXiv preprint arXiv:2511.06857},
year={2025}
}
The AAAI proceedings citation will be provided once available.
