Commits (46)
fa8faed
Add Photon model and pipeline support
Oct 8, 2025
64ddfe5
just store the T5Gemma encoder
Oct 9, 2025
2947da0
enhance_vae_properties if vae is provided only
Oct 9, 2025
2575997
remove autocast for text encoder forward
Oct 9, 2025
27421cb
BF16 example
david-PHR Oct 9, 2025
1321ab4
conditioned CFG
Oct 10, 2025
32807a1
remove enhance vae and use vae.config directly when possible
Oct 10, 2025
117e835
move PhotonAttnProcessor2_0 in transformer_photon
Oct 10, 2025
c86aed2
remove einops dependency and now inherits from AttentionMixin
Oct 10, 2025
5f6359f
unify the structure of the forward block
Oct 10, 2025
3396143
update doc
Oct 10, 2025
c78f444
update doc
Oct 10, 2025
3f70395
fix T5Gemma loading from hub
Oct 10, 2025
d09ff3c
fix timestep shift
Oct 13, 2025
91486cf
remove lora support from doc
Oct 13, 2025
23dd181
Rename EmbedND for PhotoEmbedND
DavidBert Oct 13, 2025
7efad33
remove modulation dataclass
DavidBert Oct 13, 2025
ef9c48d
put _attn_forward and _ffn_forward logic in PhotonBlock's forward
DavidBert Oct 13, 2025
178cc6e
rename LastLayer to FinalLayer
DavidBert Oct 13, 2025
924643a
remove lora related code
DavidBert Oct 13, 2025
faa00b9
rename vae_spatial_compression_ratio for vae_scale_factor
DavidBert Oct 13, 2025
804dafd
support prompt_embeds in call
DavidBert Oct 13, 2025
6f90e41
move xattention conditioning computation out of the denoising loop
DavidBert Oct 13, 2025
59f4bda
add negative prompts
DavidBert Oct 13, 2025
9ad5720
Use _import_structure for lazy loading
DavidBert Oct 13, 2025
027dbd5
make quality + style
DavidBert Oct 13, 2025
ff28f65
add pipeline test + corresponding fixes
DavidBert Oct 15, 2025
28b9cf2
utility function that determines the default resolution given the VAE
DavidBert Oct 15, 2025
b596595
Refactor PhotonAttention to match Flux pattern
DavidBert Oct 16, 2025
c522119
built-in RMSNorm
DavidBert Oct 16, 2025
3239f26
Revert accidental .gitignore change
DavidBert Oct 16, 2025
b7bbb04
parameter names match the standard diffusers conventions
DavidBert Oct 16, 2025
83e0396
renaming and remove unnecessary attribute setting
DavidBert Oct 16, 2025
582b64a
Update docs/source/en/api/pipelines/photon.md
DavidBert Oct 16, 2025
33926e0
Update docs/source/en/api/pipelines/photon.md
DavidBert Oct 16, 2025
c9e0a20
Update docs/source/en/api/pipelines/photon.md
DavidBert Oct 16, 2025
2877b60
Update docs/source/en/api/pipelines/photon.md
DavidBert Oct 16, 2025
ed87475
quantization example
DavidBert Oct 16, 2025
8aa65ba
added doc to toctree
DavidBert Oct 16, 2025
fba7b33
Merge branch 'photon' of https://github.com/Photoroom/diffusers into …
DavidBert Oct 16, 2025
bef0845
use dispatch_attention_fn for multiple attention backend support
DavidBert Oct 17, 2025
6ef3091
naming changes
DavidBert Oct 18, 2025
742d0d3
make fix copy
DavidBert Oct 20, 2025
a4f90d4
Update docs/source/en/api/pipelines/photon.md
DavidBert Oct 20, 2025
7fefe09
Add PhotonTransformer2DModel to TYPE_CHECKING imports
DavidBert Oct 20, 2025
4444379
make fix-copies
DavidBert Oct 20, 2025
7 changes: 1 addition & 6 deletions scripts/convert_photon_to_diffusers.py
@@ -13,9 +13,6 @@
import torch
from safetensors.torch import save_file


sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))

from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.pipelines.photon import PhotonPipeline

@@ -74,14 +71,12 @@ def create_parameter_mapping(depth: int) -> dict:
mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"

# QK norm moved to attention module and renamed to match Attention's qk_norm structure
# Old: qk_norm.query_norm / qk_norm.key_norm -> New: norm_q / norm_k
mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"

# K norm for text tokens moved to attention module
# Old: k_norm -> New: norm_added_k
mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"

@@ -306,7 +301,7 @@ def main(args):
parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format")

parser.add_argument(
"--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)"
"--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file )"
)

parser.add_argument(
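For context, the renamed keys above are consumed as a plain key-rename pass over the original state dict. A minimal sketch of that idea (hypothetical helper and sample keys, not the script's exact code) follows:

```python
# Minimal sketch of applying a key-rename table like create_parameter_mapping's
# output. The helper and the sample keys below are illustrative only.
import torch


def convert_state_dict(old_state_dict: dict, mapping: dict) -> dict:
    """Rename checkpoint keys according to `mapping`; unmapped keys pass through unchanged."""
    return {mapping.get(key, key): tensor for key, tensor in old_state_dict.items()}


mapping = {"blocks.0.qk_norm.query_norm.weight": "blocks.0.attention.norm_q.weight"}
old_sd = {
    "blocks.0.qk_norm.query_norm.weight": torch.ones(64),  # gets renamed
    "img_in.weight": torch.zeros(64, 64),                   # passes through unchanged
}
new_sd = convert_state_dict(old_sd, mapping)
assert "blocks.0.attention.norm_q.weight" in new_sd
```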
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -512,6 +512,7 @@
"MusicLDMPipeline",
"OmniGenPipeline",
"PaintByExamplePipeline",
"PhotonPipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
@@ -1172,6 +1173,7 @@
MusicLDMPipeline,
OmniGenPipeline,
PaintByExamplePipeline,
PhotonPipeline,
PIAPipeline,
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py
@@ -189,6 +189,7 @@
LuminaNextDiT2DModel,
MochiTransformer3DModel,
OmniGenTransformer2DModel,
PhotonTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
QwenImageTransformer2DModel,
100 changes: 53 additions & 47 deletions src/diffusers/models/transformers/transformer_photon.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
@@ -21,7 +21,7 @@
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_processor import Attention
from ..attention_dispatch import dispatch_attention_fn
from ..embeddings import get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -35,7 +35,7 @@ def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, dev
r"""
Generates 2D patch coordinate indices for a batch of images.

Parameters:
Args:
batch_size (`int`):
Number of images in the batch.
height (`int`):
@@ -63,7 +63,7 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
r"""
Applies rotary positional embeddings (RoPE) to a query tensor.

Parameters:
Args:
xq (`torch.Tensor`):
Input tensor of shape `(..., dim)` representing the queries.
freqs_cis (`torch.Tensor`):
@@ -82,11 +82,12 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:

class PhotonAttnProcessor2_0:
Member: Could we write it in a fashion similar to … ?

Collaborator: I second this suggestion - in particular, I think it would be more in line with other diffusers models implementations to reuse the layers defined in Attention, such as to_q/to_k/to_v, etc. instead of defining them in PhotonBlock (e.g. PhotonBlock.img_qkv_proj), and to keep the entire attention implementation in the PhotonAttnProcessor2_0 class.

Attention supports stuff like QK norms and fusing projections, so that could potentially be reused as well. If you need some custom logic not found in Attention, you could potentially add it in there or create a new Attention-style class like Flux does:

class FluxAttention(torch.nn.Module, AttentionModuleMixin):

Author:
I made the change and updated both the conversion script and the checkpoints on the hub.
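For readers unfamiliar with the pattern referenced above, here is a rough, self-contained sketch of the module/processor split the reviewers are suggesting. The class names are illustrative only; the actual classes adopted in this PR are PhotonAttention and PhotonAttnProcessor2_0 in the diff below.

```python
# Rough sketch of the Flux-style split: a thin nn.Module owns the projections,
# while a processor class implements the attention math. Illustrative names only.
import torch
from torch import nn


class ExampleAttnProcessor:
    def __call__(self, attn: "ExampleAttention", hidden_states: torch.Tensor) -> torch.Tensor:
        q = attn.to_q(hidden_states)
        k = attn.to_k(hidden_states)
        v = attn.to_v(hidden_states)
        b, seq_len, _ = q.shape
        # [B, L, D] -> [B, H, L, D/H]
        q, k, v = (t.view(b, seq_len, attn.heads, -1).transpose(1, 2) for t in (q, k, v))
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(b, seq_len, -1)
        return attn.to_out(out)


class ExampleAttention(nn.Module):
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.heads = heads
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Linear(dim, dim)
        self.processor = ExampleAttnProcessor()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Delegate the math to the swappable processor, as Attention/FluxAttention do.
        return self.processor(self, hidden_states)
```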

r"""
Processor for implementing Photon-style attention with multi-source tokens and RoPE. Properly integrates with
diffusers Attention module while handling Photon-specific logic.
Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention
backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
"""

_attention_backend = None
_parallel_config = None

def __init__(self):
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
@@ -104,7 +105,7 @@ def __call__(
"""
Apply Photon attention using PhotonAttention module.

Parameters:
Args:
attn: PhotonAttention module containing projection layers
hidden_states: Image tokens [B, L_img, D]
encoder_hidden_states: Text tokens [B, L_txt, D]
@@ -113,9 +114,7 @@
"""

if encoder_hidden_states is None:
raise ValueError(
"PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens."
)
raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")

# Project image tokens to Q, K, V
img_qkv = attn.img_qkv_proj(hidden_states)
@@ -164,14 +163,24 @@ def __call__(
joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1)

# Apply scaled dot-product attention
attn_output = torch.nn.functional.scaled_dot_product_attention(
img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask_tensor
# Apply attention using dispatch_attention_fn for backend support
# Reshape to match dispatch_attention_fn expectations: [B, L, H, D]
query = img_q.transpose(1, 2) # [B, L_img, H, D]
key = k.transpose(1, 2) # [B, L_txt + L_img, H, D]
value = v.transpose(1, 2) # [B, L_txt + L_img, H, D]

attn_output = dispatch_attention_fn(
query,
key,
value,
attn_mask=attn_mask_tensor,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)

# Reshape from [B, H, L_img, D] to [B, L_img, H*D]
batch_size, num_heads, seq_len, head_dim = attn_output.shape
attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim)
# Reshape from [B, L_img, H, D] to [B, L_img, H*D]
batch_size, seq_len, num_heads, head_dim = attn_output.shape
attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim)

# Apply output projection
attn_output = attn.to_out[0](attn_output)
@@ -183,8 +192,8 @@

class PhotonAttention(nn.Module, AttentionModuleMixin):
r"""
Photon-style attention module that handles multi-source tokens and RoPE.
Similar to FluxAttention but adapted for Photon's architecture.
Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
Photon's architecture.
"""

_default_processor_cls = PhotonAttnProcessor2_0
@@ -242,14 +251,14 @@ def forward(


# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class PhotoEmbedND(nn.Module):
class PhotonEmbedND(nn.Module):
r"""
N-dimensional rotary positional embedding.

This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding
dimension. The embeddings are combined and returned as a single tensor

Parameters:
Args:
dim (int):
Base embedding dimension (must be even).
theta (int):
@@ -258,7 +267,7 @@ class PhotoEmbedND(nn.Module):
List of embedding dimensions for each axis (each must be even).
"""

def __init__(self, dim: int, theta: int, axes_dim: list[int]):
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
@@ -288,7 +297,7 @@ class MLPEmbedder(nn.Module):
r"""
A simple 2-layer MLP used for embedding inputs.

Parameters:
Args:
in_dim (`int`):
Dimensionality of the input features.
hidden_dim (`int`):
@@ -316,7 +325,7 @@ class Modulation(nn.Module):
Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
two tuples `(shift, scale, gate)`.

Parameters:
Args:
dim (`int`):
Dimensionality of the input vector. The output will have `6 * dim` features internally.

@@ -340,7 +349,7 @@ class PhotonBlock(nn.Module):
r"""
Multimodal transformer block with text–image cross-attention, modulation, and MLP.

Parameters:
Args:
hidden_size (`int`):
Dimension of the hidden representations.
num_heads (`int`):
@@ -421,7 +430,7 @@ def forward(
r"""
Runs modulation-gated cross-attention and MLP, with residual connections.

Parameters:
Args:
hidden_states (`torch.Tensor`):
Image tokens of shape `(B, L_img, hidden_size)`.
encoder_hidden_states (`torch.Tensor`):
@@ -468,7 +477,7 @@ class FinalLayer(nn.Module):
This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level
outputs.

Parameters:
Args:
hidden_size (`int`):
Dimensionality of the input tokens.
patch_size (`int`):
@@ -505,7 +514,7 @@ def img2seq(img: Tensor, patch_size: int) -> Tensor:
r"""
Flattens an image tensor into a sequence of non-overlapping patches.

Parameters:
Args:
img (`torch.Tensor`):
Input image tensor of shape `(B, C, H, W)`.
patch_size (`int`):
@@ -523,7 +532,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor:
r"""
Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).

Parameters:
Args:
seq (`torch.Tensor`):
Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W //
patch_size)`.
@@ -550,7 +559,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
r"""
Transformer-based 2D model for text to image generation.

Parameters:
Args:
in_channels (`int`, *optional*, defaults to 16):
Number of input channels in the latent image.
patch_size (`int`, *optional*, defaults to 2):
@@ -650,7 +659,7 @@ def __init__(

self.hidden_size = hidden_size
self.num_heads = num_heads
self.pe_embedder = PhotoEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
@@ -683,11 +692,10 @@ def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> T

def forward(
self,
image_latent: Tensor,
hidden_states: Tensor,
timestep: Tensor,
cross_attn_conditioning: Tensor,
micro_conditioning: Tensor,
cross_attn_mask: None | Tensor = None,
encoder_hidden_states: Tensor,
attention_mask: Optional[Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
@@ -697,16 +705,14 @@
The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.

Parameters:
image_latent (`torch.Tensor`):
Args:
hidden_states (`torch.Tensor`):
Input latent image tensor of shape `(B, C, H, W)`.
timestep (`torch.Tensor`):
Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
cross_attn_conditioning (`torch.Tensor`):
encoder_hidden_states (`torch.Tensor`):
Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
micro_conditioning (`torch.Tensor`):
Extra conditioning vector (currently unused, reserved for future use).
Collaborator (@dg845, Oct 17, 2025): Was removing micro_conditioning here (in bef0845) intentional? I think it would be fine to retain it, and the transformer tests (specifically PhotonTransformerTests.prepare_dummy_input) also use this argument.

Author:
Yes it was intentional, I removed it in the tests too.

cross_attn_mask (`torch.Tensor`, *optional*):
attention_mask (`torch.Tensor`, *optional*):
Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
attention_kwargs (`dict`, *optional*):
Additional arguments passed to attention layers.
@@ -719,15 +725,15 @@ def forward(
- `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
"""
# Process text conditioning
txt = self.txt_in(cross_attn_conditioning)
txt = self.txt_in(encoder_hidden_states)

# Convert image to sequence and embed
img = img2seq(image_latent, self.patch_size)
img = img2seq(hidden_states, self.patch_size)
img = self.img_in(img)

# Generate positional embeddings
bs, _, h, w = image_latent.shape
img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
bs, _, h, w = hidden_states.shape
img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device)
pe = self.pe_embedder(img_ids)

# Compute time embedding
@@ -742,20 +748,20 @@
txt,
vec,
pe,
cross_attn_mask,
attention_mask,
)
else:
img = block(
hidden_states=img,
encoder_hidden_states=txt,
temb=vec,
image_rotary_emb=pe,
attention_mask=cross_attn_mask,
attention_mask=attention_mask,
)

# Final layer and convert back to image
img = self.final_layer(img, vec)
output = seq2img(img, self.patch_size, image_latent.shape)
output = seq2img(img, self.patch_size, hidden_states.shape)

if not return_dict:
return (output,)
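With the transformer inputs renamed to the standard diffusers names (hidden_states, encoder_hidden_states, attention_mask), a standalone call looks roughly like the sketch below. The repo id is the one used in the pipeline docstring later in this diff; the spatial size, sequence length, and config attribute lookups are illustrative assumptions rather than checked values.

```python
# Sketch of calling PhotonTransformer2DModel.forward with the renamed arguments.
# Shapes and config lookups are assumptions for illustration only.
import torch
from diffusers import PhotonPipeline

pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft")
transformer = pipe.transformer

cfg = transformer.config
latents = torch.randn(1, cfg.in_channels, 64, 64)    # (B, C, H, W) latent image
timestep = torch.tensor([0.5])                        # (B,)
text_embeds = torch.randn(1, 8, cfg.context_in_dim)  # (B, L_txt, context_in_dim)
text_mask = torch.ones(1, 8, dtype=torch.bool)        # (B, L_txt); 0 marks padding

sample = transformer(
    hidden_states=latents,
    timestep=timestep,
    encoder_hidden_states=text_embeds,
    attention_mask=text_mask,
    return_dict=False,
)[0]
print(sample.shape)  # same shape as `latents`
```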
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
@@ -716,6 +716,7 @@
StableDiffusionXLPAGPipeline,
)
from .paint_by_example import PaintByExamplePipeline
from .photon import PhotonPipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .qwenimage import (
11 changes: 5 additions & 6 deletions src/diffusers/pipelines/photon/pipeline_photon.py
@@ -206,11 +206,11 @@ def clean_text(self, text: str) -> str:
>>> from diffusers import PhotonPipeline

>>> # Load pipeline with from_pretrained
>>> pipe = PhotonPipeline.from_pretrained("path/to/photon_checkpoint")
>>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft")
>>> pipe.to("cuda")

>>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
>>> image.save("photon_output.png")
```
"""
@@ -717,11 +717,10 @@ def __call__(

# Forward through transformer
noise_pred = self.transformer(
image_latent=latents_in,
hidden_states=latents_in,
timestep=t_cont,
cross_attn_conditioning=ca_embed,
micro_conditioning=None,
cross_attn_mask=ca_mask,
encoder_hidden_states=ca_embed,
attention_mask=ca_mask,
return_dict=False,
)[0]

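The commit history also mentions a BF16 example. A minimal sketch of what that usually looks like for a diffusers pipeline follows; the repo id is taken from the docstring above, and BF16 being an appropriate precision for this checkpoint is assumed rather than confirmed here.

```python
# Sketch: running PhotonPipeline in bfloat16 (assumed to be supported, as with
# other diffusers pipelines). Repo id from the docstring example above.
import torch
from diffusers import PhotonPipeline

pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("photon_bf16_output.png")
```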