MedSegDiff - Pytorch
MedSegDiff is a PyTorch implementation of the medical image segmentation method described in the paper "MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model" (Wu et al., 2022). The method comes out of Baidu. It uses a denoising diffusion probabilistic model (DDPM) with enhanced conditioning at the feature level, and it filters features in Fourier space.
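To make "filtering of features in Fourier space" concrete, here is a minimal sketch of a learnable frequency-domain filter for a (batch, channels, height, width) feature map. It applies a 2D FFT, multiplies every frequency component by a learned complex weight, and transforms back. The class name `FourierFeatureFilter` and the identity initialisation are hypothetical choices for illustration. This is not the code used by med_seg_diff_pytorch or the paper.

```python
import torch
from torch import nn


class FourierFeatureFilter(nn.Module):
    """Hypothetical learnable frequency-domain filter for (B, C, H, W) features.

    Illustrative sketch only; not the med_seg_diff_pytorch implementation.
    """

    def __init__(self, channels: int, height: int, width: int):
        super().__init__()
        # rfft2 keeps only width // 2 + 1 frequency columns for real-valued input
        freq_width = width // 2 + 1
        # One complex weight per channel and frequency, stored as (real, imag) pairs.
        # Initialised to 1 + 0j, so the filter starts out as the identity.
        weight = torch.zeros(channels, height, freq_width, 2)
        weight[..., 0] = 1.0
        self.weight = nn.Parameter(weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h, w = x.shape[-2:]
        # Move the features into Fourier space
        freq = torch.fft.rfft2(x, norm="ortho")
        # Attenuate or amplify each frequency component with the learned weights
        freq = freq * torch.view_as_complex(self.weight)
        # Return to the spatial domain at the original resolution
        return torch.fft.irfft2(freq, s=(h, w), norm="ortho")
```

Because the weights start at 1 + 0j, the module initially returns its input unchanged, so it can be inserted into a network without disturbing it before training.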
Appreciation
The project thanks StabilityAI for its generous sponsorship, and Isamu and Daniel for adding a training script for a skin lesion dataset.
Installation
MedSegDiff can be installed with pip:
$ pip install med-seg-diff-pytorch
Usage
Here's a brief overview of how to use the MedSegDiff package:
import torch
from med_seg_diff_pytorch import Unet, MedSegDiff
# Define the conditional Unet
model = Unet(
    dim = 64,
    image_size = 128,
    mask_channels = 1,          # segmentation masks have 1 channel
    input_img_channels = 3,     # input images have 3 channels
    dim_mults = (1, 2, 4, 8)
)
# Wrap the Unet in the MedSegDiff diffusion process and move it to the GPU
diffusion = MedSegDiff(
    model,
    timesteps = 1000
).cuda()
# Example masks and input images, placed on the same device as the model
segmented_imgs = torch.rand(8, 1, 128, 128).cuda()  # mask values normalized to [0, 1]
input_imgs = torch.rand(8, 3, 128, 128).cuda()
# Compute the training loss and backpropagate
loss = diffusion(segmented_imgs, input_imgs)
loss.backward()
# After training, predict masks for unsegmented input images
pred = diffusion.sample(input_imgs)
pred.shape  # (8, 1, 128, 128): one channel, matching mask_channels
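Calling `diffusion(segmented_imgs, input_imgs)` returns a training loss. Conceptually, a conditional DDPM trains by noising the ground-truth mask to a random timestep and asking the network to recover the noise given the input image. The sketch below shows a generic version of that objective. The function names, the linear schedule, the [0, 1] to [-1, 1] rescaling, and the noise-prediction target are assumptions for illustration and may differ from what MedSegDiff does internally (its feature-level conditioning happens inside the Unet).

```python
import torch
import torch.nn.functional as F


def linear_beta_schedule(timesteps: int) -> torch.Tensor:
    # Linear variance schedule from the original DDPM paper (beta from 1e-4 to 0.02)
    return torch.linspace(1e-4, 0.02, timesteps)


# alpha_bar_t = product over s <= t of (1 - beta_s), for 1000 timesteps
alphas_cumprod = torch.cumprod(1 - linear_beta_schedule(1000), dim=0)


def conditional_ddpm_loss(denoiser, mask, cond_img, alphas_cumprod):
    """Generic noise-prediction loss for a mask conditioned on an image (sketch)."""
    b = mask.shape[0]
    # Map masks from [0, 1] to [-1, 1], the range DDPMs usually operate in
    x0 = mask * 2 - 1
    # Sample one random timestep per example
    t = torch.randint(0, alphas_cumprod.shape[0], (b,), device=mask.device)
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod.to(mask.device)[t].view(b, 1, 1, 1)
    # Closed-form forward process: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * noise
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise
    # The denoiser sees the noisy mask, the timestep, and the conditioning image
    pred_noise = denoiser(x_t, t, cond_img)
    return F.mse_loss(pred_noise, noise)


# Hypothetical stand-in denoiser: one conv over [noisy mask, image]; it ignores t
toy_net = torch.nn.Conv2d(1 + 3, 1, kernel_size=3, padding=1)


def toy_denoiser(x_t, t, cond_img):
    return toy_net(torch.cat([x_t, cond_img], dim=1))


loss = conditional_ddpm_loss(
    toy_denoiser, torch.rand(8, 1, 32, 32), torch.rand(8, 3, 32, 32), alphas_cumprod
)
loss.backward()
```

Sampling runs this in reverse: starting from pure noise, the trained network removes noise step by step while conditioned on the same input image.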
Training
To train a model using MedSegDiff, run the following command:
accelerate launch driver.py --mask_channels=1 --input_img_channels=3 --image_size=64 --data_path='./data' --dim=64 --epochs=100 --batch_size=1 --scale_lr --gradient_accumulation_steps=4
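With `--batch_size=1 --gradient_accumulation_steps=4`, gradients from four batches of one example are accumulated before each optimizer step. Assuming `--scale_lr` follows the linear-scaling convention used in Hugging Face diffusers example scripts (an assumption; check driver.py), the learning rate is multiplied by the effective batch size. The helper names and the base learning rate below are hypothetical.

```python
def effective_batch_size(batch_size: int, grad_accum_steps: int, num_processes: int = 1) -> int:
    # Number of examples contributing to each optimizer step, across all processes
    return batch_size * grad_accum_steps * num_processes


def scaled_learning_rate(
    base_lr: float, batch_size: int, grad_accum_steps: int, num_processes: int = 1
) -> float:
    # Linear scaling rule: grow the learning rate with the effective batch size
    return base_lr * effective_batch_size(batch_size, grad_accum_steps, num_processes)


# Flag values from the command, on a single GPU; base_lr = 1e-4 is a hypothetical value
print(effective_batch_size(1, 4))        # 4
print(scaled_learning_rate(1e-4, 1, 4))  # 0.0004
```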
To enable self-conditioning, where the model is additionally conditioned on its own current estimate of the mask, add the --self_condition flag.
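A generic form of self-conditioning can be sketched as follows: with some probability, the model first predicts the mask without a previous estimate, and that prediction (without gradients) is fed back as an extra input to a second pass. The toy network, the 50% probability, and the all-zero initial estimate are assumptions for illustration, not necessarily what MedSegDiff does.

```python
import random

import torch
from torch import nn


class TinySelfCondDenoiser(nn.Module):
    """Hypothetical toy model: input is [noisy mask, previous mask estimate, image]."""

    def __init__(self, mask_channels: int = 1, img_channels: int = 3):
        super().__init__()
        self.net = nn.Conv2d(2 * mask_channels + img_channels, mask_channels, 3, padding=1)

    def forward(self, x_t, self_cond, cond_img):
        # Output is treated as the model's estimate of the clean mask
        return self.net(torch.cat([x_t, self_cond, cond_img], dim=1))


def self_conditioned_prediction(model, x_t, cond_img, p_self_cond: float = 0.5):
    # Without a previous estimate, the self-conditioning input is all zeros
    self_cond = torch.zeros_like(x_t)
    if random.random() < p_self_cond:
        # First pass: estimate the mask, without tracking gradients
        with torch.no_grad():
            self_cond = model(x_t, self_cond, cond_img)
    # Second pass: condition on the detached estimate from the first pass
    return model(x_t, self_cond, cond_img)
```

During sampling, the estimate from the previous denoising step is typically reused as `self_cond` instead of running an extra pass.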
Additions
- Basic training code has been added, including a Trainer that accepts custom datasets, thanks to @isamu-isozaki.
- A full transformer of any depth can be placed in the middle (bottleneck) of the Unet, as done in the simple diffusion paper (Hoogeboom et al., 2023).
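Placing a transformer in the middle of a U-Net typically means flattening the lowest-resolution feature map into a sequence of tokens, one per spatial position, and running standard transformer layers over it. The sketch below uses PyTorch's built-in `nn.TransformerEncoder`; the class name `BottleneckTransformer` and its defaults are hypothetical, and the actual Unet in med_seg_diff_pytorch may build its middle transformer differently.

```python
import torch
from torch import nn


class BottleneckTransformer(nn.Module):
    """Hypothetical sketch: self-attention over all spatial positions of a bottleneck."""

    def __init__(self, dim: int, depth: int, heads: int = 8):
        super().__init__()
        # dim must be divisible by heads; pre-norm layers with a 4x feed-forward width
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=dim * 4, batch_first=True, norm_first=True
        )
        # Stack `depth` copies of the layer
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        # (B, C, H, W) -> (B, H*W, C): each spatial position becomes one token
        tokens = x.flatten(2).transpose(1, 2)
        tokens = self.encoder(tokens)
        # (B, H*W, C) -> (B, C, H, W): back to a feature map for the decoder
        return tokens.transpose(1, 2).reshape(b, c, h, w)
```

Attention cost grows with the square of the number of tokens, which is why the transformer goes at the lowest resolution rather than at full image size.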
Citations
The implementation builds on the following papers:
@article{Wu2022MedSegDiffMI,
title = {MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model},
author = {Junde Wu and Huihui Fang and Yu Zhang and Yehui Yang and Yanwu Xu},
journal = {ArXiv},
year = {2022},
volume = {abs/2211.00611}
}
@inproceedings{Hoogeboom2023simpleDE,
title = {simple diffusion: End-to-end diffusion for high resolution images},
author = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
year = {2023}
}
MedSegDiff treats medical image segmentation as conditional generation: a diffusion model iteratively denoises a segmentation mask while conditioned on the input image.