Add Photon model and pipeline support #12456
base: main
Changes from 1 commit
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 from torch import Tensor, nn
@@ -21,7 +21,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ..attention import AttentionMixin, AttentionModuleMixin
-from ..attention_processor import Attention
+from ..attention_dispatch import dispatch_attention_fn
 from ..embeddings import get_timestep_embedding
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -35,7 +35,7 @@ def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, dev
 r"""
 Generates 2D patch coordinate indices for a batch of images.

-Parameters:
+Args:
     batch_size (`int`):
         Number of images in the batch.
     height (`int`):
@@ -63,7 +63,7 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
 r"""
 Applies rotary positional embeddings (RoPE) to a query tensor.

-Parameters:
+Args:
     xq (`torch.Tensor`):
         Input tensor of shape `(..., dim)` representing the queries.
     freqs_cis (`torch.Tensor`):
@@ -82,11 +82,12 @@ def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:

 class PhotonAttnProcessor2_0:
 r"""
-Processor for implementing Photon-style attention with multi-source tokens and RoPE. Properly integrates with
-diffusers Attention module while handling Photon-specific logic.
+Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention
+backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
 """

 _attention_backend = None
 _parallel_config = None

 def __init__(self):
     if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
@@ -104,7 +105,7 @@ def __call__(
 """
 Apply Photon attention using PhotonAttention module.

-Parameters:
+Args:
     attn: PhotonAttention module containing projection layers
     hidden_states: Image tokens [B, L_img, D]
     encoder_hidden_states: Text tokens [B, L_txt, D]
@@ -113,9 +114,7 @@ def __call__(
 """

 if encoder_hidden_states is None:
-    raise ValueError(
-        "PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens."
-    )
+    raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")

 # Project image tokens to Q, K, V
 img_qkv = attn.img_qkv_proj(hidden_states)
@@ -164,14 +163,24 @@ def __call__(
 joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
 attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1)

-# Apply scaled dot-product attention
-attn_output = torch.nn.functional.scaled_dot_product_attention(
-    img_q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask_tensor
+# Apply attention using dispatch_attention_fn for backend support
+# Reshape to match dispatch_attention_fn expectations: [B, L, H, D]
+query = img_q.transpose(1, 2)  # [B, L_img, H, D]
+key = k.transpose(1, 2)  # [B, L_txt + L_img, H, D]
+value = v.transpose(1, 2)  # [B, L_txt + L_img, H, D]
+
+attn_output = dispatch_attention_fn(
+    query,
+    key,
+    value,
+    attn_mask=attn_mask_tensor,
+    backend=self._attention_backend,
+    parallel_config=self._parallel_config,
 )

-# Reshape from [B, H, L_img, D] to [B, L_img, H*D]
-batch_size, num_heads, seq_len, head_dim = attn_output.shape
-attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, num_heads * head_dim)
+# Reshape from [B, L_img, H, D] to [B, L_img, H*D]
+batch_size, seq_len, num_heads, head_dim = attn_output.shape
+attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim)

 # Apply output projection
 attn_output = attn.to_out[0](attn_output)
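To make the layout change above concrete, here is a minimal, self-contained sketch (not part of the PR) checking that transposing to `[B, L, H, D]` before attention and reshaping afterwards gives the same result as the previous `[B, H, L, D]` path. Plain `torch.nn.functional.scaled_dot_product_attention` stands in for `dispatch_attention_fn`; backend selection is not exercised here.

```python
# Sketch only: `fake_dispatch` mimics the [B, L, H, D] in/out convention the diff
# assumes for dispatch_attention_fn; it is not the diffusers implementation.
import torch
import torch.nn.functional as F

def fake_dispatch(query, key, value):
    # accepts [B, L, H, D], runs SDPA in [B, H, L, D], returns [B, L, H, D]
    out = F.scaled_dot_product_attention(
        query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
    )
    return out.transpose(1, 2)

B, H, L_img, L_txt, D = 2, 8, 16, 7, 64
img_q = torch.randn(B, H, L_img, D)
k = torch.randn(B, H, L_txt + L_img, D)
v = torch.randn(B, H, L_txt + L_img, D)

# Old path: SDPA on [B, H, L, D], then transpose + reshape to [B, L_img, H * D]
old = F.scaled_dot_product_attention(img_q, k, v).transpose(1, 2).reshape(B, L_img, H * D)

# New path: transpose to [B, L, H, D] first; only a reshape is needed afterwards
new = fake_dispatch(
    img_q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).reshape(B, L_img, H * D)

torch.testing.assert_close(old, new)
```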
@@ -183,8 +192,8 @@ def __call__(

 class PhotonAttention(nn.Module, AttentionModuleMixin):
 r"""
-Photon-style attention module that handles multi-source tokens and RoPE.
-Similar to FluxAttention but adapted for Photon's architecture.
+Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
+Photon's architecture.
 """

 _default_processor_cls = PhotonAttnProcessor2_0
@@ -242,14 +251,14 @@ def forward(

 # inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
-class PhotoEmbedND(nn.Module):
+class PhotonEmbedND(nn.Module):
 r"""
 N-dimensional rotary positional embedding.

 This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding
 dimension. The embeddings are combined and returned as a single tensor

-Parameters:
+Args:
     dim (int):
         Base embedding dimension (must be even).
     theta (int):
@@ -258,7 +267,7 @@ class PhotoEmbedND(nn.Module):
         List of embedding dimensions for each axis (each must be even).
 """

-def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+def __init__(self, dim: int, theta: int, axes_dim: List[int]):
     super().__init__()
     self.dim = dim
     self.theta = theta
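As a hedged illustration of what the `PhotonEmbedND` docstring describes (per-axis rotary frequencies, concatenated into a single tensor), here is a small standalone sketch. It is not the module's actual implementation; the helper name, the `theta` value, and the cos/sin output layout are assumptions.

```python
# Sketch only: one rotary frequency table per axis, concatenated over the feature dim.
import torch

def rope_1d(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    # pos: (L,); returns cos/sin pairs of shape (L, dim // 2, 2)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = pos.float().unsqueeze(-1) * freqs
    return torch.stack([angles.cos(), angles.sin()], dim=-1)

# 2D patch coordinates for a 4x4 grid, one (row, col) pair per token
ids = torch.stack(
    torch.meshgrid(torch.arange(4), torch.arange(4), indexing="ij"), dim=-1
).reshape(-1, 2)

axes_dim = [32, 32]  # one even embedding dim per axis (assumed values)
emb = torch.cat(
    [rope_1d(ids[:, i], axes_dim[i], theta=10_000) for i in range(len(axes_dim))], dim=-2
)
print(emb.shape)  # torch.Size([16, 32, 2]) -> (num_tokens, sum(axes_dim) // 2, 2)
```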
@@ -288,7 +297,7 @@ class MLPEmbedder(nn.Module):
 r"""
 A simple 2-layer MLP used for embedding inputs.

-Parameters:
+Args:
     in_dim (`int`):
         Dimensionality of the input features.
     hidden_dim (`int`):
@@ -316,7 +325,7 @@ class Modulation(nn.Module):
 Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
 two tuples `(shift, scale, gate)`.

-Parameters:
+Args:
     dim (`int`):
         Dimensionality of the input vector. The output will have `6 * dim` features internally.
@@ -340,7 +349,7 @@ class PhotonBlock(nn.Module):
 r"""
 Multimodal transformer block with text–image cross-attention, modulation, and MLP.

-Parameters:
+Args:
     hidden_size (`int`):
         Dimension of the hidden representations.
     num_heads (`int`):
@@ -421,7 +430,7 @@ def forward(
 r"""
 Runs modulation-gated cross-attention and MLP, with residual connections.

-Parameters:
+Args:
     hidden_states (`torch.Tensor`):
         Image tokens of shape `(B, L_img, hidden_size)`.
     encoder_hidden_states (`torch.Tensor`):
@@ -468,7 +477,7 @@ class FinalLayer(nn.Module):
 This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level
 outputs.

-Parameters:
+Args:
     hidden_size (`int`):
         Dimensionality of the input tokens.
     patch_size (`int`):
@@ -505,7 +514,7 @@ def img2seq(img: Tensor, patch_size: int) -> Tensor:
 r"""
 Flattens an image tensor into a sequence of non-overlapping patches.

-Parameters:
+Args:
     img (`torch.Tensor`):
         Input image tensor of shape `(B, C, H, W)`.
     patch_size (`int`):
@@ -523,7 +532,7 @@ def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor:
 r"""
 Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).

-Parameters:
+Args:
     seq (`torch.Tensor`):
         Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W // patch_size)`.
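The `img2seq`/`seq2img` docstrings above describe a patchify round trip. The following standalone sketch (an assumption, not the PR's implementation) reproduces the documented shapes, `(B, C, H, W)` to `(B, L, C * patch_size * patch_size)` and back.

```python
# Sketch only: reshape/permute-based patchify and its inverse, matching the
# shapes documented for img2seq/seq2img.
import torch

def patchify(img: torch.Tensor, p: int) -> torch.Tensor:
    b, c, h, w = img.shape
    x = img.reshape(b, c, h // p, p, w // p, p)                 # split H and W into patches
    x = x.permute(0, 2, 4, 1, 3, 5)                             # (b, h', w', c, p, p)
    return x.reshape(b, (h // p) * (w // p), c * p * p)         # (B, L, C * p * p)

def unpatchify(seq: torch.Tensor, p: int, shape) -> torch.Tensor:
    b, c, h, w = shape
    x = seq.reshape(b, h // p, w // p, c, p, p).permute(0, 3, 1, 4, 2, 5)
    return x.reshape(b, c, h, w)                                # back to (B, C, H, W)

img = torch.randn(2, 16, 8, 8)
assert torch.equal(unpatchify(patchify(img, 2), 2, img.shape), img)
```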
@@ -550,7 +559,7 @@ class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
 r"""
 Transformer-based 2D model for text to image generation.

-Parameters:
+Args:
     in_channels (`int`, *optional*, defaults to 16):
         Number of input channels in the latent image.
     patch_size (`int`, *optional*, defaults to 2):
@@ -650,7 +659,7 @@ def __init__(

 self.hidden_size = hidden_size
 self.num_heads = num_heads
-self.pe_embedder = PhotoEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
+self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
 self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
 self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
 self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
@@ -683,11 +692,10 @@ def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> T

 def forward(
     self,
-    image_latent: Tensor,
+    hidden_states: Tensor,
     timestep: Tensor,
-    cross_attn_conditioning: Tensor,
-    micro_conditioning: Tensor,
-    cross_attn_mask: None | Tensor = None,
+    encoder_hidden_states: Tensor,
+    attention_mask: Optional[Tensor] = None,
     attention_kwargs: Optional[Dict[str, Any]] = None,
     return_dict: bool = True,
 ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
@@ -697,16 +705,14 @@ def forward(
 The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
 transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.

-Parameters:
-    image_latent (`torch.Tensor`):
+Args:
+    hidden_states (`torch.Tensor`):
         Input latent image tensor of shape `(B, C, H, W)`.
     timestep (`torch.Tensor`):
         Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
-    cross_attn_conditioning (`torch.Tensor`):
+    encoder_hidden_states (`torch.Tensor`):
         Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
-    micro_conditioning (`torch.Tensor`):
-        Extra conditioning vector (currently unused, reserved for future use).
-    cross_attn_mask (`torch.Tensor`, *optional*):
+    attention_mask (`torch.Tensor`, *optional*):
         Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
     attention_kwargs (`dict`, *optional*):
         Additional arguments passed to attention layers.
@@ -719,15 +725,15 @@ def forward(
         - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
 """
 # Process text conditioning
-txt = self.txt_in(cross_attn_conditioning)
+txt = self.txt_in(encoder_hidden_states)

 # Convert image to sequence and embed
-img = img2seq(image_latent, self.patch_size)
+img = img2seq(hidden_states, self.patch_size)
 img = self.img_in(img)

 # Generate positional embeddings
-bs, _, h, w = image_latent.shape
-img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=image_latent.device)
+bs, _, h, w = hidden_states.shape
+img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device)
 pe = self.pe_embedder(img_ids)

 # Compute time embedding
@@ -742,20 +748,20 @@ def forward(
         txt,
         vec,
         pe,
-        cross_attn_mask,
+        attention_mask,
     )
 else:
     img = block(
         hidden_states=img,
         encoder_hidden_states=txt,
         temb=vec,
         image_rotary_emb=pe,
-        attention_mask=cross_attn_mask,
+        attention_mask=attention_mask,
     )

 # Final layer and convert back to image
 img = self.final_layer(img, vec)
-output = seq2img(img, self.patch_size, image_latent.shape)
+output = seq2img(img, self.patch_size, hidden_states.shape)

 if not return_dict:
     return (output,)
Could we write it in a fashion similar to diffusers/src/diffusers/models/transformers/transformer_flux.py, line 75 in 8abc7ae?
I second this suggestion - in particular, I think it would be more in line with other diffusers models implementations to reuse the layers defined in `Attention`, such as `to_q`/`to_k`/`to_v`, etc., instead of defining them in `PhotonBlock` (e.g. `PhotonBlock.img_qkv_proj`), and to keep the entire attention implementation in the `PhotonAttnProcessor2_0` class. `Attention` supports stuff like QK norms and fusing projections, so that could potentially be reused as well. If you need some custom logic not found in `Attention`, you could potentially add it in there or create a new `Attention`-style class like Flux does: diffusers/src/diffusers/models/transformers/transformer_flux.py, line 275 in 8abc7ae.
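For reference, a rough sketch of the layout the reviewers are suggesting: the projections (`to_q`/`to_k`/`to_v`/`to_out`) live on an `Attention`-style module, and the processor only implements the Photon-specific logic. This is an illustrative assumption using plain `nn.Linear` layers and SDPA, not diffusers' actual `Attention` class and not the PR's final code.

```python
# Sketch only: module owns the projections, processor owns the attention logic.
import torch
from torch import nn
import torch.nn.functional as F

class PhotonStyleAttnProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states, attention_mask=None):
        # Query from image tokens; key/value from concatenated text + image tokens,
        # using the module's own projections instead of a fused img_qkv_proj.
        query = attn.to_q(hidden_states)
        tokens = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        key, value = attn.to_k(tokens), attn.to_v(tokens)

        b, l, _ = query.shape
        query = query.view(b, l, attn.heads, -1).transpose(1, 2)
        key = key.view(b, key.shape[1], attn.heads, -1).transpose(1, 2)
        value = value.view(b, value.shape[1], attn.heads, -1).transpose(1, 2)

        out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        out = out.transpose(1, 2).reshape(b, l, -1)
        return attn.to_out(out)

class PhotonStyleAttention(nn.Module):
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.heads = heads
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Linear(dim, dim)
        self.processor = PhotonStyleAttnProcessor()

    def forward(self, hidden_states, encoder_hidden_states, attention_mask=None):
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask)

# quick shape check
attn = PhotonStyleAttention(dim=64, heads=4)
img, txt = torch.randn(2, 16, 64), torch.randn(2, 7, 64)
print(attn(img, txt).shape)  # torch.Size([2, 16, 64])
```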
I made the change and updated both the conversion script and the checkpoints on the hub.