Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
91de4b3
Add RewardModel and PrefGRPORewardModel classes for reward computation
LouisRouss Sep 10, 2025
db14e8a
Refactor RewardModel and PrefGRPORewardModel to enhance image handlin…
LouisRouss Sep 11, 2025
02e1ca8
Add return_latents option to Diffuser's denoise method for latent rep…
LouisRouss Sep 13, 2025
2ff6348
Add attribute delegation and enhanced dir() support to Diffuser class
LouisRouss Sep 13, 2025
b54d0c7
Fix dtype argument in model initialization
LouisRouss Sep 13, 2025
7e60ff7
Add one_step_denoise_grpo method for GRPO training in Flow class
LouisRouss Sep 13, 2025
b0240f0
Refactor training classes to use a common trainer and reorganize impo…
LouisRouss Sep 16, 2025
f60fd65
Add GRPO support to Diffuser and Flow classes with new methods and ut…
LouisRouss Sep 16, 2025
2a3ffc8
Enhance RewardModel and PrefGRPORewardModel with n_image_per_prompt s…
LouisRouss Sep 16, 2025
43f6cd3
Add GRPO support with new BatchData structures and update training cl…
LouisRouss Sep 19, 2025
6b014f7
fix typing
LouisRouss Sep 19, 2025
fd4252a
fix loss calculation grpo flow
LouisRouss Sep 20, 2025
ba8f074
Refactor loss computation in Flow class to use a list for step-wise l…
LouisRouss Sep 20, 2025
10a823d
Refactor trainer imports and implement validation step in GRPOTrainer
LouisRouss Sep 20, 2025
7519c8c
Finish GRPO training loop and fix epoch level scheduler logic
LouisRouss Sep 20, 2025
2f1940e
Add clip in reward model
LouisRouss Sep 21, 2025
d442082
Implement StepResult and Sampler classes for diffusion process; add E…
LouisRouss Sep 23, 2025
693e403
adapt to abstraction sampler and clean GRPO logic
LouisRouss Sep 23, 2025
342982a
Refactor ContextEmbedder to implement properties for n_output and out…
LouisRouss Sep 23, 2025
f7c4200
Refactor PrefGRPORewardModel to standardize clip model ID usage and i…
LouisRouss Sep 23, 2025
dc033ab
Refactor sampler classes to standardize set_steps method for improved…
LouisRouss Sep 25, 2025
8f17901
Add DDIM and DDPM sampler implementations with step and parameter set…
LouisRouss Sep 25, 2025
2a9847a
Refactor Flow and EulerMaruyama classes for improved parameter handli…
LouisRouss Sep 25, 2025
c64db67
improve tensor handling and device compatibility in flow and euler me…
LouisRouss Sep 25, 2025
2d476d4
- Refactor model input handling in Diffuser, Flow, and GRPOTrainer cl…
LouisRouss Sep 27, 2025
d2425fb
Add a generic abstract sampler class over modelization specific sampl…
LouisRouss Sep 27, 2025
0b9e159
Refactor diffusion model classes to standardize sampler initializatio…
LouisRouss Sep 27, 2025
3c81ce9
update docstring
LouisRouss Sep 27, 2025
854d82a
Refactor denoise method signatures in Diffuser, Flow, and GaussianDif…
LouisRouss Sep 27, 2025
89d8c90
Allow MMDiT to use a context embedder without pooled embedding
LouisRouss Sep 28, 2025
d079027
- Add loguru dependency
LouisRouss Sep 28, 2025
4ae13d2
Refactor preprocess method in DinoV2
LouisRouss Sep 28, 2025
8838fba
Enhance input validation in encode method of DCAE class to support ad…
LouisRouss Sep 28, 2025
ee3d906
Implement DDT architecture and refactor modulation classes for enhanc…
LouisRouss Sep 29, 2025
562f1f3
add dinoV3 and precompute functions
LouisRouss Oct 1, 2025
2b8a642
Refactor SD3TextEmbedder to improve type casting and add attention ma…
LouisRouss Oct 5, 2025
5c24b79
Update step method docstring in Euler and EulerMaruyama classes to re…
LouisRouss Oct 22, 2025
20945ec
add dependencies
LouisRouss Oct 22, 2025
9960d67
improve attn unet
LouisRouss Oct 22, 2025
94c9c44
Refactor ContextEmbedder to use ContextEmbedderOutput for forward method
LouisRouss Oct 22, 2025
d333dc5
Enhance MMDiTAttention and MMDiTBlock to support attention masks and …
LouisRouss Oct 26, 2025
da43da7
- rename mask to attn mask in context embedder output
LouisRouss Oct 26, 2025
ee8f7d0
Update DDT to utilize ContextEmbedderOutput for improved context hand…
LouisRouss Oct 26, 2025
5aed411
use transformers instead of open clip
LouisRouss Oct 26, 2025
ba5028a
Add attention mask to U-Net and use torch scaled dot product attn
LouisRouss Oct 26, 2025
59b1d1d
finish forward method of PrefGRPORewardModel
LouisRouss Oct 26, 2025
08f489b
fix torch stack xt_std
LouisRouss Oct 29, 2025
d7ab331
fix unet
LouisRouss Oct 29, 2025
a74f3b0
Add GRPO loss computation and update method signatures in Diffuser, F…
LouisRouss Oct 29, 2025
ce670ba
Remove 'local/' from Pyright include paths in pyproject.toml
LouisRouss Oct 29, 2025
42c1a64
Fix feature appending condition in DDT class to check for None
LouisRouss Oct 30, 2025
5f2e856
Fix default value for attn_mask in PreComputedEmbedder to ensure prop…
LouisRouss Oct 30, 2025
c38be13
fix reward model
LouisRouss Oct 30, 2025
f084799
Refactor encoding input range checks in DCAE class for clarity and ac…
LouisRouss Oct 30, 2025
f2840de
clean code - update docstring
LouisRouss Nov 2, 2025
f58555b
Merge branch 'main' into feature/PrefGRPO
LouisRouss Nov 2, 2025
cd1a40f
- remove pre computed embedder #TODO cleaner
LouisRouss Nov 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 1 addition & 34 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ share/python-wheels/
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

Expand Down Expand Up @@ -83,36 +81,12 @@ notebooks/
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
# PEP 582
__pypackages__/

# Celery stuff
Expand Down Expand Up @@ -155,13 +129,6 @@ dmypy.json
# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# running logs
examples/wandb
outputs/
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ Here is a To-Do list, feel welcome to help to any point along this list. The alr
- [ ] add some more context embedders
- [ ] add reflow algorithm
- [ ] add EDM
- [ ] think about how to add a sampler abstraction and use it in the different Diffusion classes (generic class with Euler, Heun, etc.)
- [x] think about how to add a sampler abstraction and use it in the different Diffusion classes (generic class with Euler, Heun, etc.)
- [ ] Train our models on toy datasets for different tasks (conditional generation, Image to Image ...)
- [ ] Add possibility to train LORA/DORA
- [ ] add different samplers
- [x] add different samplers
- [ ] Try out Differential Transformers
- [ ] Check to add https://arxiv.org/pdf/2406.02507
- [ ] inject lessons learned from nvidia https://developer.nvidia.com/blog/rethinking-how-to-train-diffusion-models/
9 changes: 6 additions & 3 deletions examples/train_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.utils.data import DataLoader

from diffulab.diffuse import Diffuser
from diffulab.training import Trainer
from diffulab.training import BaseTrainer


@hydra.main(version_base=None, config_path="../configs", config_name="train_mnist_flow_matching")
Expand Down Expand Up @@ -52,8 +52,7 @@ def count_parameters(model: torch.nn.Module) -> int:
params=denoiser.parameters(),
)

# TODO: add a run name for wandb
trainer = Trainer(
trainer = BaseTrainer(
n_epoch=cfg.trainer.n_epoch,
gradient_accumulation_step=cfg.trainer.gradient_accumulation_step,
precision_type=cfg.trainer.precision_type,
Expand All @@ -62,6 +61,10 @@ def count_parameters(model: torch.nn.Module) -> int:
ema_update_after_step=cfg.trainer.get("ema_update_after_step", 0),
ema_update_every=cfg.trainer.get("ema_update_every", 10),
run_config=OmegaConf.to_container(cfg, resolve=True), # type: ignore[reportArgumentType]
compile=cfg.trainer.get("compile", False),
init_kwargs={
"wandb": cfg.trainer.get("wandb", {}),
},
)

trainer.train(
Expand Down
4 changes: 2 additions & 2 deletions examples/train_repa.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.utils.data import DataLoader

from diffulab.diffuse import Diffuser
from diffulab.training import Trainer
from diffulab.training import BaseTrainer
from diffulab.training.losses.repa import RepaLoss


Expand Down Expand Up @@ -77,7 +77,7 @@ def count_parameters(model: torch.nn.Module) -> int:
+ list(repa_loss.resampler.parameters() if repa_loss.resampler else []),
)

trainer = Trainer(
trainer = BaseTrainer(
n_epoch=cfg.trainer.n_epoch,
gradient_accumulation_step=cfg.trainer.gradient_accumulation_step,
precision_type=cfg.trainer.precision_type,
Expand Down
10 changes: 8 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
[project]
name = "diffulab"
version = "0.1.0"
description = "Add your description here"
description = "DiffuLab is designed to provide a simple and flexible way to train diffusion models while allowing full customization of its core components"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"accelerate>=1.4.0",
"blobfile>=3.1.0",
"diffusers>=0.33.1",
"einops>=0.8.1",
"ema-pytorch>=0.7.7",
"hydra-core>=1.3.2",
"jaxtyping>=0.3.0",
"loguru>=0.7.3",
"mosaicml-streaming>=0.12.0",
"omegaconf>=2.3.0",
"open-clip-torch>=2.30.0",
"pyopenssl==23.2.0",
"sentencepiece>=0.2.1",
"tiktoken>=0.11.0",
"torch>=2.6.0",
"transformers>=4.49.0",
"wandb>=0.19.6",
Expand All @@ -31,6 +34,9 @@ dev = [
repa = [
"timm>=1.0.15",
]
prefgrpo = [
"qwen-vl-utils>=0.0.11",
]

[tool.uv.sources]
diffulab = {workspace = true}
Expand Down
4 changes: 3 additions & 1 deletion src/diffulab/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .datasets import BaseDataset, CIFAR10Dataset, ImageNetLatentREPA, MNISTDataset
from .diffuse import Diffuser, Flow, GaussianDiffusion
from .networks import DCAE, REPA, Denoiser, DinoV2, MMDiT, PerceiverResampler, SD3TextEmbedder, UNetModel, VisionTower
from .training import LossFunction, RepaLoss, Trainer
from .training import BaseTrainer, GRPOTrainer, LossFunction, RepaLoss, Trainer

__all__ = [
"BaseDataset",
Expand All @@ -22,5 +22,7 @@
"VisionTower",
"LossFunction",
"RepaLoss",
"BaseTrainer",
"GRPOTrainer",
"Trainer",
]
9 changes: 7 additions & 2 deletions src/diffulab/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@
from torch import Tensor
from torch.utils.data import Dataset

from diffulab.networks.denoisers.common import ModelInput
from diffulab.networks.denoisers.common import ExtraInputGRPO, ModelInput, ModelInputGRPO


class BatchData(TypedDict, total=False):
model_inputs: Required[ModelInput]
extra: NotRequired[dict[str, Tensor | None]]
extra: NotRequired[dict[str, Tensor | list[str] | None]]


class BatchDataGRPO(TypedDict, total=False):
model_inputs: Required[ModelInputGRPO]
extra: Required[ExtraInputGRPO]


class BaseDataset(Dataset[BatchData], ABC):
Expand Down
31 changes: 27 additions & 4 deletions src/diffulab/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ def __init__(self, data_path: str, train: bool = True):
self.images, self.labels = self.load_data()

def load_data(self) -> tuple[NDArray[np.float32], NDArray[np.int64]]:
"""Load MNIST data from files."""
"""
Load MNIST data from files.

Returns:
A tuple of images and labels arrays.
"""
if self.train:
images_file = self.data_path / "train-images-idx3-ubyte"
labels_file = self.data_path / "train-labels-idx1-ubyte"
Expand All @@ -38,7 +43,13 @@ def load_data(self) -> tuple[NDArray[np.float32], NDArray[np.int64]]:
return images, labels

def _load_images(self, file: Path) -> NDArray[np.float32]:
"""Load and preprocess MNIST images."""
"""
Load and preprocess MNIST images.
Args:
file: Path to the MNIST images file.
Returns:
A numpy array of shape (num_images, 1, 32, 32) containing the resized images.
"""
with open(file, "rb") as f:
_, num_images, rows, cols = struct.unpack(">IIII", f.read(16))
images = np.frombuffer(f.read(), dtype=np.uint8).reshape(num_images, 1, rows, cols)
Expand All @@ -52,12 +63,24 @@ def _load_images(self, file: Path) -> NDArray[np.float32]:
return resized_images

def _load_labels(self, file: Path) -> NDArray[np.int64]:
"""Load MNIST labels."""
"""
Load MNIST labels.
Args:
file: Path to the MNIST labels file.
Returns:
A numpy array of shape (num_labels,) containing the labels.
"""
with open(file, "rb") as f:
_, _ = struct.unpack(">II", f.read(8))
labels = np.frombuffer(f.read(), dtype=np.uint8)
return labels.astype(np.int64)

def preprocess_image(self, image: NDArray[Any]) -> NDArray[np.float32]:
"""Normalize the image to [-1, 1] range."""
"""
Normalize the image to [-1, 1] range.
Args:
image: A numpy array representing the image.
Returns:
A normalized numpy array.
"""
return ((image.astype(np.float32) / 255.0) - 0.5) / 0.5
Loading
Loading