diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index fb4fdf2098e6..c14ff4df0d23 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -570,6 +570,8 @@
       title: Paint by Example
     - local: api/pipelines/pia
       title: Personalized Image Animator (PIA)
+    - local: api/pipelines/photon
+      title: Photon
     - local: api/pipelines/pixart
       title: PixArt-α
     - local: api/pipelines/pixart_sigma
diff --git a/docs/source/en/api/pipelines/photon.md b/docs/source/en/api/pipelines/photon.md
new file mode 100644
index 000000000000..f9d6ba5a1792
--- /dev/null
+++ b/docs/source/en/api/pipelines/photon.md
@@ -0,0 +1,131 @@
+
+
+# Photon
+
+
+Photon generates high-quality images from text using a simplified MMDiT architecture in which text tokens are not updated through the transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multilingual text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between the Flux VAE (8x compression, 16 latent channels) for balanced quality and speed, or DC-AE (32x compression, 32 latent channels) for higher latent compression and faster processing.
+
+## Available models
+
+Photon offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
+
+
+| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
+|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
+| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft)| 256 | Yes | No | Fine-tuned on the [Alchemist](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/photon-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
+| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
+| [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
+
+Refer to [this collection](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) for more information.
+
+## Loading the pipeline
+
+Load the pipeline with [`~DiffusionPipeline.from_pretrained`].
+
+```py
+import torch
+from diffusers.pipelines.photon import PhotonPipeline
+
+# Load pipeline - the VAE and text encoder are downloaded from the Hugging Face Hub
+pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
+image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
+image.save("photon_output.png")
+```
+
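+### Few-Step Generation with Distilled Checkpoints
+
+The distilled checkpoints in the table above are intended for few-step sampling. The snippet below is a minimal sketch that simply swaps in the distilled checkpoint and uses its suggested settings from the table (8 steps, `guidance_scale=1.0`); everything else matches the example above.
+
+```py
+import torch
+from diffusers.pipelines.photon import PhotonPipeline
+
+pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft-distilled", torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
+# Distilled variants are tuned for 8 steps without classifier-free guidance (cfg=1.0)
+image = pipe(prompt, num_inference_steps=8, guidance_scale=1.0).images[0]
+image.save("photon_distilled_output.png")
+```
+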
+### Manual Component Loading
+
+Load components individually to customize the pipeline, for instance to use quantized models.
+
+```py
+import torch
+from diffusers.pipelines.photon import PhotonPipeline
+from diffusers.models import AutoencoderKL, AutoencoderDC
+from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from transformers import T5GemmaModel, GemmaTokenizerFast
+from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
+from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
+
+quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
+text_encoder_quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)
+
+# Load transformer
+transformer = PhotonTransformer2DModel.from_pretrained(
+    "Photoroom/photon-512-t2i-sft",
+    subfolder="transformer",
+    quantization_config=quant_config,
+    torch_dtype=torch.bfloat16,
+)
+
+# Load scheduler
+scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+    "Photoroom/photon-512-t2i-sft", subfolder="scheduler"
+)
+
+# Load T5Gemma text encoder (quantized with the transformers BitsAndBytesConfig)
+t5gemma_model = T5GemmaModel.from_pretrained(
+    "google/t5gemma-2b-2b-ul2",
+    quantization_config=text_encoder_quant_config,
+    torch_dtype=torch.bfloat16,
+)
+text_encoder = t5gemma_model.encoder
+tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
+tokenizer.model_max_length = 256
+
+# Load VAE - choose either Flux VAE or DC-AE
+# Flux VAE
+vae = AutoencoderKL.from_pretrained(
+    "black-forest-labs/FLUX.1-dev",
+    subfolder="vae",
+    quantization_config=quant_config,
+    torch_dtype=torch.bfloat16,
+)
+
+pipe = PhotonPipeline(
+    transformer=transformer,
+    scheduler=scheduler,
+    text_encoder=text_encoder,
+    tokenizer=tokenizer,
+    vae=vae,
+)
+pipe.to("cuda")
+```
+
+
+## Memory Optimization
+
+For memory-constrained environments:
+
+```py
+import torch
+from diffusers.pipelines.photon import PhotonPipeline
+
+pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
+pipe.enable_model_cpu_offload()  # Offload components to CPU when not in use
+
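+# Quantizing components with bitsandbytes (see "Manual Component Loading" above)
+# is another way to reduce memory usage
+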
+# Or use sequential CPU offload for even lower memory +pipe.enable_sequential_cpu_offload() +``` + +## PhotonPipeline + +[[autodoc]] PhotonPipeline + - all + - __call__ + +## PhotonPipelineOutput + +[[autodoc]] pipelines.photon.pipeline_output.PhotonPipelineOutput diff --git a/scripts/convert_photon_to_diffusers.py b/scripts/convert_photon_to_diffusers.py new file mode 100644 index 000000000000..c66bc314181f --- /dev/null +++ b/scripts/convert_photon_to_diffusers.py @@ -0,0 +1,345 @@ +#!/usr/bin/env python3 +""" +Script to convert Photon checkpoint from original codebase to diffusers format. +""" + +import argparse +import json +import os +import sys +from dataclasses import asdict, dataclass +from typing import Dict, Tuple + +import torch +from safetensors.torch import save_file + +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel +from diffusers.pipelines.photon import PhotonPipeline + + +DEFAULT_RESOLUTION = 512 + + +@dataclass(frozen=True) +class PhotonBase: + context_in_dim: int = 2304 + hidden_size: int = 1792 + mlp_ratio: float = 3.5 + num_heads: int = 28 + depth: int = 16 + axes_dim: Tuple[int, int] = (32, 32) + theta: int = 10_000 + time_factor: float = 1000.0 + time_max_period: int = 10_000 + + +@dataclass(frozen=True) +class PhotonFlux(PhotonBase): + in_channels: int = 16 + patch_size: int = 2 + + +@dataclass(frozen=True) +class PhotonDCAE(PhotonBase): + in_channels: int = 32 + patch_size: int = 1 + + +def build_config(vae_type: str) -> Tuple[dict, int]: + if vae_type == "flux": + cfg = PhotonFlux() + elif vae_type == "dc-ae": + cfg = PhotonDCAE() + else: + raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'") + + config_dict = asdict(cfg) + config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index] + return config_dict + + +def create_parameter_mapping(depth: int) -> dict: + """Create mapping from old parameter names to new diffusers names.""" + + # Key mappings for structural changes + mapping = {} + + # Map old structure (layers in PhotonBlock) to new structure (layers in PhotonAttention) + for i in range(depth): + # QKV projections moved to attention module + mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight" + mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight" + + # QK norm moved to attention module and renamed to match Attention's qk_norm structure + mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight" + mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight" + mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight" + mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight" + + # K norm for text tokens moved to attention module + mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight" + mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight" + + # Attention output projection + mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight" + + return mapping + + +def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]: + """Convert old checkpoint parameters to new diffusers format.""" + + print("Converting checkpoint parameters...") + + mapping = create_parameter_mapping(depth) + converted_state_dict = {} + + for key, value in old_state_dict.items(): + new_key = key + + # 
Apply specific mappings if needed + if key in mapping: + new_key = mapping[key] + print(f" Mapped: {key} -> {new_key}") + + converted_state_dict[new_key] = value + + print(f"✓ Converted {len(converted_state_dict)} parameters") + return converted_state_dict + + +def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PhotonTransformer2DModel: + """Create and load PhotonTransformer2DModel from old checkpoint.""" + + print(f"Loading checkpoint from: {checkpoint_path}") + + # Load old checkpoint + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + old_checkpoint = torch.load(checkpoint_path, map_location="cpu") + + # Handle different checkpoint formats + if isinstance(old_checkpoint, dict): + if "model" in old_checkpoint: + state_dict = old_checkpoint["model"] + elif "state_dict" in old_checkpoint: + state_dict = old_checkpoint["state_dict"] + else: + state_dict = old_checkpoint + else: + state_dict = old_checkpoint + + print(f"✓ Loaded checkpoint with {len(state_dict)} parameters") + + # Convert parameter names if needed + model_depth = int(config.get("depth", 16)) + converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth) + + # Create transformer with config + print("Creating PhotonTransformer2DModel...") + transformer = PhotonTransformer2DModel(**config) + + # Load state dict + print("Loading converted parameters...") + missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False) + + if missing_keys: + print(f"⚠ Missing keys: {missing_keys}") + if unexpected_keys: + print(f"⚠ Unexpected keys: {unexpected_keys}") + + if not missing_keys and not unexpected_keys: + print("✓ All parameters loaded successfully!") + + return transformer + + +def create_scheduler_config(output_path: str, shift: float): + """Create FlowMatchEulerDiscreteScheduler config.""" + + scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift} + + scheduler_path = os.path.join(output_path, "scheduler") + os.makedirs(scheduler_path, exist_ok=True) + + with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f: + json.dump(scheduler_config, f, indent=2) + + print("✓ Created scheduler config") + + +def download_and_save_vae(vae_type: str, output_path: str): + """Download and save VAE to local directory.""" + from diffusers import AutoencoderDC, AutoencoderKL + + vae_path = os.path.join(output_path, "vae") + os.makedirs(vae_path, exist_ok=True) + + if vae_type == "flux": + print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...") + vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae") + else: # dc-ae + print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...") + vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers") + + vae.save_pretrained(vae_path) + print(f"✓ Saved VAE to {vae_path}") + + +def download_and_save_text_encoder(output_path: str): + """Download and save T5Gemma text encoder and tokenizer.""" + from transformers import GemmaTokenizerFast + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel + + text_encoder_path = os.path.join(output_path, "text_encoder") + tokenizer_path = os.path.join(output_path, "tokenizer") + os.makedirs(text_encoder_path, exist_ok=True) + os.makedirs(tokenizer_path, exist_ok=True) + + print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...") + t5gemma_model = 
T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2") + + # Extract and save only the encoder + t5gemma_encoder = t5gemma_model.encoder + t5gemma_encoder.save_pretrained(text_encoder_path) + print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}") + + print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...") + tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2") + tokenizer.model_max_length = 256 + tokenizer.save_pretrained(tokenizer_path) + print(f"✓ Saved tokenizer to {tokenizer_path}") + + +def create_model_index(vae_type: str, default_image_size: int, output_path: str): + """Create model_index.json for the pipeline.""" + + if vae_type == "flux": + vae_class = "AutoencoderKL" + else: # dc-ae + vae_class = "AutoencoderDC" + + model_index = { + "_class_name": "PhotonPipeline", + "_diffusers_version": "0.31.0.dev0", + "_name_or_path": os.path.basename(output_path), + "default_sample_size": default_image_size, + "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"], + "text_encoder": ["photon", "T5GemmaEncoder"], + "tokenizer": ["transformers", "GemmaTokenizerFast"], + "transformer": ["diffusers", "PhotonTransformer2DModel"], + "vae": ["diffusers", vae_class], + } + + model_index_path = os.path.join(output_path, "model_index.json") + with open(model_index_path, "w") as f: + json.dump(model_index, f, indent=2) + + +def main(args): + # Validate inputs + if not os.path.exists(args.checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}") + + config = build_config(args.vae_type) + + # Create output directory + os.makedirs(args.output_path, exist_ok=True) + print(f"✓ Output directory: {args.output_path}") + + # Create transformer from checkpoint + transformer = create_transformer_from_checkpoint(args.checkpoint_path, config) + + # Save transformer + transformer_path = os.path.join(args.output_path, "transformer") + os.makedirs(transformer_path, exist_ok=True) + + # Save config + with open(os.path.join(transformer_path, "config.json"), "w") as f: + json.dump(config, f, indent=2) + + # Save model weights as safetensors + state_dict = transformer.state_dict() + save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors")) + print(f"✓ Saved transformer to {transformer_path}") + + # Create scheduler config + create_scheduler_config(args.output_path, args.shift) + + download_and_save_vae(args.vae_type, args.output_path) + download_and_save_text_encoder(args.output_path) + + # Create model_index.json + create_model_index(args.vae_type, args.resolution, args.output_path) + + # Verify the pipeline can be loaded + try: + pipeline = PhotonPipeline.from_pretrained(args.output_path) + print("Pipeline loaded successfully!") + print(f"Transformer: {type(pipeline.transformer).__name__}") + print(f"VAE: {type(pipeline.vae).__name__}") + print(f"Text Encoder: {type(pipeline.text_encoder).__name__}") + print(f"Scheduler: {type(pipeline.scheduler).__name__}") + + # Display model info + num_params = sum(p.numel() for p in pipeline.transformer.parameters()) + print(f"✓ Transformer parameters: {num_params:,}") + + except Exception as e: + print(f"Pipeline verification failed: {e}") + return False + + print("Conversion completed successfully!") + print(f"Converted pipeline saved to: {args.output_path}") + print(f"VAE type: {args.vae_type}") + + return True + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format") + + parser.add_argument( + 
"--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file )" + ) + + parser.add_argument( + "--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline" + ) + + parser.add_argument( + "--vae_type", + type=str, + choices=["flux", "dc-ae"], + required=True, + help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)", + ) + + parser.add_argument( + "--resolution", + type=int, + choices=[256, 512, 1024], + default=DEFAULT_RESOLUTION, + help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.", + ) + + parser.add_argument( + "--shift", + type=float, + default=3.0, + help="Shift for the scheduler", + ) + + args = parser.parse_args() + + try: + success = main(args) + if not success: + sys.exit(1) + except Exception as e: + print(f"Conversion failed: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 686e8d99dabf..602be5d9d5f2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -231,6 +231,7 @@ "MultiControlNetModel", "OmniGenTransformer2DModel", "ParallelConfig", + "PhotonTransformer2DModel", "PixArtTransformer2DModel", "PriorTransformer", "QwenImageControlNetModel", @@ -511,6 +512,7 @@ "MusicLDMPipeline", "OmniGenPipeline", "PaintByExamplePipeline", + "PhotonPipeline", "PIAPipeline", "PixArtAlphaPipeline", "PixArtSigmaPAGPipeline", @@ -1171,6 +1173,7 @@ MusicLDMPipeline, OmniGenPipeline, PaintByExamplePipeline, + PhotonPipeline, PIAPipeline, PixArtAlphaPipeline, PixArtSigmaPAGPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 457f70448af3..d9f4c0148d16 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -95,6 +95,7 @@ _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] + _import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"] _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"] @@ -188,6 +189,7 @@ LuminaNextDiT2DModel, MochiTransformer3DModel, OmniGenTransformer2DModel, + PhotonTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, QwenImageTransformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index b60f0636e6dc..7fdab560a702 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -31,6 +31,7 @@ from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_omnigen import OmniGenTransformer2DModel + from .transformer_photon import PhotonTransformer2DModel from .transformer_qwenimage import QwenImageTransformer2DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_photon.py b/src/diffusers/models/transformers/transformer_photon.py new file mode 100644 index 
000000000000..1a40a829719e --- /dev/null +++ b/src/diffusers/models/transformers/transformer_photon.py @@ -0,0 +1,768 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import Tensor, nn +from torch.nn.functional import fold, unfold + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..attention import AttentionMixin, AttentionModuleMixin +from ..attention_dispatch import dispatch_attention_fn +from ..embeddings import get_timestep_embedding +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm + + +logger = logging.get_logger(__name__) + + +def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> Tensor: + r""" + Generates 2D patch coordinate indices for a batch of images. + + Args: + batch_size (`int`): + Number of images in the batch. + height (`int`): + Height of the input images (in pixels). + width (`int`): + Width of the input images (in pixels). + patch_size (`int`): + Size of the square patches that the image is divided into. + device (`torch.device`): + The device on which to create the tensor. + + Returns: + `torch.Tensor`: + Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) coordinates of each patch in the + image grid. + """ + + img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device) + img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None] + img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :] + return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1) + + +def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor: + r""" + Applies rotary positional embeddings (RoPE) to a query tensor. + + Args: + xq (`torch.Tensor`): + Input tensor of shape `(..., dim)` representing the queries. + freqs_cis (`torch.Tensor`): + Precomputed rotary frequency components of shape `(..., dim/2, 2)` containing cosine and sine pairs. + + Returns: + `torch.Tensor`: + Tensor of the same shape as `xq` with rotary embeddings applied. + """ + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) + # Ensure freqs_cis is on the same device as queries to avoid device mismatches with offloading + freqs_cis = freqs_cis.to(device=xq.device, dtype=xq_.dtype) + xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] + return xq_out.reshape(*xq.shape).type_as(xq) + + +class PhotonAttnProcessor2_0: + r""" + Processor for implementing Photon-style attention with multi-source tokens and RoPE. Supports multiple attention + backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn. 
+ """ + + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(torch.nn.functional, "scaled_dot_product_attention"): + raise ImportError("PhotonAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: "PhotonAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Apply Photon attention using PhotonAttention module. + + Args: + attn: PhotonAttention module containing projection layers + hidden_states: Image tokens [B, L_img, D] + encoder_hidden_states: Text tokens [B, L_txt, D] + attention_mask: Boolean mask for text tokens [B, L_txt] + image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2] + """ + + if encoder_hidden_states is None: + raise ValueError("PhotonAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.") + + # Project image tokens to Q, K, V + img_qkv = attn.img_qkv_proj(hidden_states) + B, L_img, _ = img_qkv.shape + img_qkv = img_qkv.reshape(B, L_img, 3, attn.heads, attn.head_dim) + img_qkv = img_qkv.permute(2, 0, 3, 1, 4) # [3, B, H, L_img, D] + img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2] + + # Apply QK normalization to image tokens + img_q = attn.norm_q(img_q) + img_k = attn.norm_k(img_k) + + # Project text tokens to K, V + txt_kv = attn.txt_kv_proj(encoder_hidden_states) + B, L_txt, _ = txt_kv.shape + txt_kv = txt_kv.reshape(B, L_txt, 2, attn.heads, attn.head_dim) + txt_kv = txt_kv.permute(2, 0, 3, 1, 4) # [2, B, H, L_txt, D] + txt_k, txt_v = txt_kv[0], txt_kv[1] + + # Apply K normalization to text tokens + txt_k = attn.norm_added_k(txt_k) + + # Apply RoPE to image queries and keys + if image_rotary_emb is not None: + img_q = apply_rope(img_q, image_rotary_emb) + img_k = apply_rope(img_k, image_rotary_emb) + + # Concatenate text and image keys/values + k = torch.cat((txt_k, img_k), dim=2) # [B, H, L_txt + L_img, D] + v = torch.cat((txt_v, img_v), dim=2) # [B, H, L_txt + L_img, D] + + # Build attention mask if provided + attn_mask_tensor = None + if attention_mask is not None: + bs, _, l_img, _ = img_q.shape + l_txt = txt_k.shape[2] + + if attention_mask.dim() != 2: + raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}") + if attention_mask.shape[-1] != l_txt: + raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}") + + device = img_q.device + ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device) + attention_mask = attention_mask.to(device=device, dtype=torch.bool) + joint_mask = torch.cat([attention_mask, ones_img], dim=-1) + attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1) + + # Apply attention using dispatch_attention_fn for backend support + # Reshape to match dispatch_attention_fn expectations: [B, L, H, D] + query = img_q.transpose(1, 2) # [B, L_img, H, D] + key = k.transpose(1, 2) # [B, L_txt + L_img, H, D] + value = v.transpose(1, 2) # [B, L_txt + L_img, H, D] + + attn_output = dispatch_attention_fn( + query, + key, + value, + attn_mask=attn_mask_tensor, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + + # Reshape from [B, L_img, H, D] to [B, L_img, H*D] + batch_size, seq_len, num_heads, head_dim = attn_output.shape + attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim) + 
+ # Apply output projection + attn_output = attn.to_out[0](attn_output) + if len(attn.to_out) > 1: + attn_output = attn.to_out[1](attn_output) # dropout if present + + return attn_output + + +class PhotonAttention(nn.Module, AttentionModuleMixin): + r""" + Photon-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for + Photon's architecture. + """ + + _default_processor_cls = PhotonAttnProcessor2_0 + _available_processors = [PhotonAttnProcessor2_0] + + def __init__( + self, + query_dim: int, + heads: int = 8, + dim_head: int = 64, + bias: bool = False, + out_bias: bool = False, + eps: float = 1e-6, + processor=None, + ): + super().__init__() + + self.heads = heads + self.head_dim = dim_head + self.inner_dim = dim_head * heads + self.query_dim = query_dim + + self.img_qkv_proj = nn.Linear(query_dim, query_dim * 3, bias=bias) + + self.norm_q = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True) + self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True) + + self.txt_kv_proj = nn.Linear(query_dim, query_dim * 2, bias=bias) + self.norm_added_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(0.0)) + + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + **kwargs, + ) + + +# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py +class PhotonEmbedND(nn.Module): + r""" + N-dimensional rotary positional embedding. + + This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding + dimension. The embeddings are combined and returned as a single tensor + + Args: + dim (int): + Base embedding dimension (must be even). + theta (int): + Scaling factor that controls the frequency spectrum of the rotary embeddings. + axes_dim (list[int]): + List of embedding dimensions for each axis (each must be even). + """ + + def __init__(self, dim: int, theta: int, axes_dim: List[int]): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor: + assert dim % 2 == 0 + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim + omega = 1.0 / (theta**scale) + out = pos.unsqueeze(-1) * omega.unsqueeze(0) + out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) + # Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2) + # out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2) + out = out.reshape(*out.shape[:-1], 2, 2) + return out.float() + + def forward(self, ids: Tensor) -> Tensor: + n_axes = ids.shape[-1] + emb = torch.cat( + [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)], + dim=-3, + ) + return emb.unsqueeze(1) + + +class MLPEmbedder(nn.Module): + r""" + A simple 2-layer MLP used for embedding inputs. 
+ + Args: + in_dim (`int`): + Dimensionality of the input features. + hidden_dim (`int`): + Dimensionality of the hidden and output embedding space. + + Returns: + `torch.Tensor`: + Tensor of shape `(..., hidden_dim)` containing the embedded representations. + """ + + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.silu = nn.SiLU() + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def forward(self, x: Tensor) -> Tensor: + return self.out_layer(self.silu(self.in_layer(x))) + + +class Modulation(nn.Module): + r""" + Modulation network that generates scale, shift, and gating parameters. + + Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into + two tuples `(shift, scale, gate)`. + + Args: + dim (`int`): + Dimensionality of the input vector. The output will have `6 * dim` features internally. + + Returns: + ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)): + Two tuples `(shift, scale, gate)`. + """ + + def __init__(self, dim: int): + super().__init__() + self.lin = nn.Linear(dim, 6 * dim, bias=True) + nn.init.constant_(self.lin.weight, 0) + nn.init.constant_(self.lin.bias, 0) + + def forward(self, vec: Tensor) -> tuple[tuple[Tensor, Tensor, Tensor], tuple[Tensor, Tensor, Tensor]]: + out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1) + return tuple(out[:3]), tuple(out[3:]) + + +class PhotonBlock(nn.Module): + r""" + Multimodal transformer block with text–image cross-attention, modulation, and MLP. + + Args: + hidden_size (`int`): + Dimension of the hidden representations. + num_heads (`int`): + Number of attention heads. + mlp_ratio (`float`, *optional*, defaults to 4.0): + Expansion ratio for the hidden dimension inside the MLP. + qk_scale (`float`, *optional*): + Scale factor for queries and keys. If not provided, defaults to ``head_dim**-0.5``. + + Attributes: + img_pre_norm (`nn.LayerNorm`): + Pre-normalization applied to image tokens before attention. + attention (`PhotonAttention`): + Multi-head attention module with built-in QKV projections and normalizations for cross-attention between + image and text tokens. + post_attention_layernorm (`nn.LayerNorm`): + Normalization applied after attention. + gate_proj / up_proj / down_proj (`nn.Linear`): + Feedforward layers forming the gated MLP. + mlp_act (`nn.GELU`): + Nonlinear activation used in the MLP. + modulation (`Modulation`): + Produces scale/shift/gating parameters for modulated layers. + + Methods: + The forward method performs cross-attention and the MLP with modulation. 
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: float | None = None, + ): + super().__init__() + + self.hidden_dim = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.hidden_size = hidden_size + + # Pre-attention normalization for image tokens + self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + + # PhotonAttention module with built-in projections and norms + self.attention = PhotonAttention( + query_dim=hidden_size, + heads=num_heads, + dim_head=self.head_dim, + bias=False, + out_bias=False, + eps=1e-6, + processor=PhotonAttnProcessor2_0(), + ) + + # mlp + self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False) + self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False) + self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False) + self.mlp_act = nn.GELU(approximate="tanh") + + self.modulation = Modulation(hidden_size) + + def forward( + self, + hidden_states: Tensor, + encoder_hidden_states: Tensor, + temb: Tensor, + image_rotary_emb: Tensor, + attention_mask: Tensor | None = None, + **kwargs: dict[str, Any], + ) -> Tensor: + r""" + Runs modulation-gated cross-attention and MLP, with residual connections. + + Args: + hidden_states (`torch.Tensor`): + Image tokens of shape `(B, L_img, hidden_size)`. + encoder_hidden_states (`torch.Tensor`): + Text tokens of shape `(B, L_txt, hidden_size)`. + temb (`torch.Tensor`): + Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or + broadcastable). + image_rotary_emb (`torch.Tensor`): + Rotary positional embeddings applied inside attention. + attention_mask (`torch.Tensor`, *optional*): + Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding. + **kwargs: + Additional keyword arguments for API compatibility. + + Returns: + `torch.Tensor`: + Updated image tokens of shape `(B, L_img, hidden_size)`. + """ + + mod_attn, mod_mlp = self.modulation(temb) + attn_shift, attn_scale, attn_gate = mod_attn + mlp_shift, mlp_scale, mlp_gate = mod_mlp + + hidden_states_mod = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift + + attn_out = self.attention( + hidden_states=hidden_states_mod, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + + hidden_states = hidden_states + attn_gate * attn_out + + x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift + hidden_states = hidden_states + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x))) + return hidden_states + + +class FinalLayer(nn.Module): + r""" + Final projection layer with adaptive LayerNorm modulation. + + This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level + outputs. + + Args: + hidden_size (`int`): + Dimensionality of the input tokens. + patch_size (`int`): + Size of the square image patches. + out_channels (`int`): + Number of output channels per pixel (e.g. RGB = 3). + + Forward Inputs: + x (`torch.Tensor`): + Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches. 
+ vec (`torch.Tensor`): + Conditioning vector of shape `(B, hidden_size)` used to generate shift and scale parameters for adaptive + LayerNorm. + + Returns: + `torch.Tensor`: + Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`. + """ + + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) + + def forward(self, x: Tensor, vec: Tensor) -> Tensor: + shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x + + +def img2seq(img: Tensor, patch_size: int) -> Tensor: + r""" + Flattens an image tensor into a sequence of non-overlapping patches. + + Args: + img (`torch.Tensor`): + Input image tensor of shape `(B, C, H, W)`. + patch_size (`int`): + Size of each square patch. Must evenly divide both `H` and `W`. + + Returns: + `torch.Tensor`: + Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W + // patch_size)` is the number of patches. + """ + return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2) + + +def seq2img(seq: Tensor, patch_size: int, shape: Tensor) -> Tensor: + r""" + Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`). + + Args: + seq (`torch.Tensor`): + Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W // + patch_size)`. + patch_size (`int`): + Size of each square patch. + shape (`tuple` or `torch.Tensor`): + The original image spatial shape `(H, W)`. If a tensor is provided, the first two values are interpreted as + height and width. + + Returns: + `torch.Tensor`: + Reconstructed image tensor of shape `(B, C, H, W)`. + """ + if isinstance(shape, tuple): + shape = shape[-2:] + elif isinstance(shape, torch.Tensor): + shape = (int(shape[0]), int(shape[1])) + else: + raise NotImplementedError(f"shape type {type(shape)} not supported") + return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size) + + +class PhotonTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin): + r""" + Transformer-based 2D model for text to image generation. + + Args: + in_channels (`int`, *optional*, defaults to 16): + Number of input channels in the latent image. + patch_size (`int`, *optional*, defaults to 2): + Size of the square patches used to flatten the input image. + context_in_dim (`int`, *optional*, defaults to 2304): + Dimensionality of the text conditioning input. + hidden_size (`int`, *optional*, defaults to 1792): + Dimension of the hidden representation. + mlp_ratio (`float`, *optional*, defaults to 3.5): + Expansion ratio for the hidden dimension inside MLP blocks. + num_heads (`int`, *optional*, defaults to 28): + Number of attention heads. + depth (`int`, *optional*, defaults to 16): + Number of transformer blocks. + axes_dim (`list[int]`, *optional*): + List of dimensions for each positional embedding axis. Defaults to `[32, 32]`. + theta (`int`, *optional*, defaults to 10000): + Frequency scaling factor for rotary embeddings. + time_factor (`float`, *optional*, defaults to 1000.0): + Scaling factor applied in timestep embeddings. 
+ time_max_period (`int`, *optional*, defaults to 10000): + Maximum frequency period for timestep embeddings. + + Attributes: + pe_embedder (`EmbedND`): + Multi-axis rotary embedding generator for positional encodings. + img_in (`nn.Linear`): + Projection layer for image patch tokens. + time_in (`MLPEmbedder`): + Embedding layer for timestep embeddings. + txt_in (`nn.Linear`): + Projection layer for text conditioning. + blocks (`nn.ModuleList`): + Stack of transformer blocks (`PhotonBlock`). + final_layer (`LastLayer`): + Projection layer mapping hidden tokens back to patch outputs. + + Methods: + attn_processors: + Returns a dictionary of all attention processors in the model. + set_attn_processor(processor): + Replaces attention processors across all attention layers. + process_inputs(image_latent, txt): + Converts inputs into patch tokens, encodes text, and produces positional encodings. + compute_timestep_embedding(timestep, dtype): + Creates a timestep embedding of dimension 256, scaled and projected. + forward_transformers(image_latent, cross_attn_conditioning, timestep, time_embedding, attention_mask, + **block_kwargs): + Runs the sequence of transformer blocks over image and text tokens. + forward(image_latent, timestep, cross_attn_conditioning, micro_conditioning, cross_attn_mask=None, + attention_kwargs=None, return_dict=True): + Full forward pass from latent input to reconstructed output image. + + Returns: + `Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing: + - `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`. + """ + + config_name = "config.json" + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 16, + patch_size: int = 2, + context_in_dim: int = 2304, + hidden_size: int = 1792, + mlp_ratio: float = 3.5, + num_heads: int = 28, + depth: int = 16, + axes_dim: list = None, + theta: int = 10000, + time_factor: float = 1000.0, + time_max_period: int = 10000, + ): + super().__init__() + + if axes_dim is None: + axes_dim = [32, 32] + + # Store parameters directly + self.in_channels = in_channels + self.patch_size = patch_size + self.out_channels = self.in_channels * self.patch_size**2 + + self.time_factor = time_factor + self.time_max_period = time_max_period + + if hidden_size % num_heads != 0: + raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}") + + pe_dim = hidden_size // num_heads + + if sum(axes_dim) != pe_dim: + raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}") + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.pe_embedder = PhotonEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim) + self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.txt_in = nn.Linear(context_in_dim, self.hidden_size) + + self.blocks = nn.ModuleList( + [ + PhotonBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=mlp_ratio, + ) + for i in range(depth) + ] + ) + + self.final_layer = FinalLayer(self.hidden_size, 1, self.out_channels) + + self.gradient_checkpointing = False + + def _compute_timestep_embedding(self, timestep: Tensor, dtype: torch.dtype) -> Tensor: + return self.time_in( + get_timestep_embedding( + timesteps=timestep, + embedding_dim=256, + max_period=self.time_max_period, + scale=self.time_factor, + flip_sin_to_cos=True, # Match original cos, sin order + ).to(dtype) + ) + + 
def forward( + self, + hidden_states: Tensor, + timestep: Tensor, + encoder_hidden_states: Tensor, + attention_mask: Optional[Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: + r""" + Forward pass of the PhotonTransformer2DModel. + + The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of + transformer blocks modulated by the timestep. The output is reconstructed into the latent image space. + + Args: + hidden_states (`torch.Tensor`): + Input latent image tensor of shape `(B, C, H, W)`. + timestep (`torch.Tensor`): + Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning. + encoder_hidden_states (`torch.Tensor`): + Text conditioning tensor of shape `(B, L_txt, context_in_dim)`. + attention_mask (`torch.Tensor`, *optional*): + Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence. + attention_kwargs (`dict`, *optional*): + Additional arguments passed to attention layers. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a `Transformer2DModelOutput` or a tuple. + + Returns: + `Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple: + + - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`. + """ + # Process text conditioning + txt = self.txt_in(encoder_hidden_states) + + # Convert image to sequence and embed + img = img2seq(hidden_states, self.patch_size) + img = self.img_in(img) + + # Generate positional embeddings + bs, _, h, w = hidden_states.shape + img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device) + pe = self.pe_embedder(img_ids) + + # Compute time embedding + vec = self._compute_timestep_embedding(timestep, dtype=img.dtype) + + # Apply transformer blocks + for block in self.blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + img = self._gradient_checkpointing_func( + block.__call__, + img, + txt, + vec, + pe, + attention_mask, + ) + else: + img = block( + hidden_states=img, + encoder_hidden_states=txt, + temb=vec, + image_rotary_emb=pe, + attention_mask=attention_mask, + ) + + # Final layer and convert back to image + img = self.final_layer(img, vec) + output = seq2img(img, self.patch_size, hidden_states.shape) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 190c7871d270..12904ae4990f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -144,6 +144,7 @@ "FluxKontextPipeline", "FluxKontextInpaintPipeline", ] + _import_structure["photon"] = ["PhotonPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ "AudioLDM2Pipeline", @@ -715,6 +716,7 @@ StableDiffusionXLPAGPipeline, ) from .paint_by_example import PaintByExamplePipeline + from .photon import PhotonPipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .qwenimage import ( diff --git a/src/diffusers/pipelines/photon/__init__.py b/src/diffusers/pipelines/photon/__init__.py new file mode 100644 index 000000000000..e21e31d4225f --- /dev/null +++ b/src/diffusers/pipelines/photon/__init__.py @@ -0,0 +1,63 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, 
+ get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_additional_imports = {} +_import_structure = {"pipeline_output": ["PhotonPipelineOutput"]} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_photon"] = ["PhotonPipeline"] + +# Import T5GemmaEncoder for pipeline loading compatibility +try: + if is_transformers_available(): + import transformers + from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + + _additional_imports["T5GemmaEncoder"] = T5GemmaEncoder + # Patch transformers module directly for serialization + if not hasattr(transformers, "T5GemmaEncoder"): + transformers.T5GemmaEncoder = T5GemmaEncoder +except ImportError: + pass + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .pipeline_output import PhotonPipelineOutput + from .pipeline_photon import PhotonPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) + for name, value in _additional_imports.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/photon/pipeline_output.py b/src/diffusers/pipelines/photon/pipeline_output.py new file mode 100644 index 000000000000..d4b0ff462983 --- /dev/null +++ b/src/diffusers/pipelines/photon/pipeline_output.py @@ -0,0 +1,35 @@ +# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL.Image + +from ...utils import BaseOutput + + +@dataclass +class PhotonPipelineOutput(BaseOutput): + """ + Output class for Photon pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] diff --git a/src/diffusers/pipelines/photon/pipeline_photon.py b/src/diffusers/pipelines/photon/pipeline_photon.py new file mode 100644 index 000000000000..b394b12d83f4 --- /dev/null +++ b/src/diffusers/pipelines/photon/pipeline_photon.py @@ -0,0 +1,768 @@ +# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable, Dict, List, Optional, Union + +import ftfy +import torch +from transformers import ( + AutoTokenizer, + GemmaTokenizerFast, + T5TokenizerFast, +) +from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + +from diffusers.image_processor import PixArtImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderDC, AutoencoderKL +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel +from diffusers.pipelines.photon.pipeline_output import PhotonPipelineOutput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import ( + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor + + +DEFAULT_RESOLUTION = 512 + +ASPECT_RATIO_256_BIN = { + "0.46": [160, 352], + "0.6": [192, 320], + "0.78": [224, 288], + "1.0": [256, 256], + "1.29": [288, 224], + "1.67": [320, 192], + "2.2": [352, 160], +} + +ASPECT_RATIO_512_BIN = { + "0.5": [352, 704], + "0.57": [384, 672], + "0.6": [384, 640], + "0.68": [416, 608], + "0.78": [448, 576], + "0.88": [480, 544], + "1.0": [512, 512], + "1.13": [544, 480], + "1.29": [576, 448], + "1.46": [608, 416], + "1.67": [640, 384], + "1.75": [672, 384], + "2.0": [704, 352], +} + +logger = logging.get_logger(__name__) + + +class TextPreprocessor: + """Text preprocessing utility for PhotonPipeline.""" + + def __init__(self): + """Initialize text preprocessor.""" + self.bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + r"\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) + + def clean_text(self, text: str) -> str: + """Clean text using comprehensive text processing logic.""" + # See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py + text = str(text) + text = ul.unquote_plus(text) + text = text.strip().lower() + text = re.sub("", "person", text) + + # Remove all urls: + text = re.sub( + r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))", + "", + text, + ) # regex for urls + + # @ + text = re.sub(r"@[\w\d]+\b", "", text) + + # 31C0—31EF CJK Strokes through 4E00—9FFF CJK Unified Ideographs + text = re.sub(r"[\u31c0-\u31ef]+", "", text) + text = re.sub(r"[\u31f0-\u31ff]+", "", text) + text = re.sub(r"[\u3200-\u32ff]+", "", text) + text = re.sub(r"[\u3300-\u33ff]+", "", text) + text = re.sub(r"[\u3400-\u4dbf]+", "", text) + text = re.sub(r"[\u4dc0-\u4dff]+", "", text) + text = re.sub(r"[\u4e00-\u9fff]+", "", text) + + # все виды тире / all types of dash --> "-" + text = re.sub( + 
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", + "-", + text, + ) + + # кавычки к одному стандарту + text = re.sub(r"[`´«»" "¨]", '"', text) + text = re.sub(r"['']", "'", text) + + # " and & + text = re.sub(r""?", "", text) + text = re.sub(r"&", "", text) + + # ip addresses: + text = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", text) + + # article ids: + text = re.sub(r"\d:\d\d\s+$", "", text) + + # \n + text = re.sub(r"\\n", " ", text) + + # "#123", "#12345..", "123456.." + text = re.sub(r"#\d{1,3}\b", "", text) + text = re.sub(r"#\d{5,}\b", "", text) + text = re.sub(r"\b\d{6,}\b", "", text) + + # filenames: + text = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", text) + + # Clean punctuation + text = re.sub(r"[\"\']{2,}", r'"', text) # """AUSVERKAUFT""" + text = re.sub(r"[\.]{2,}", r" ", text) + + text = re.sub(self.bad_punct_regex, r" ", text) # ***AUSVERKAUFT***, #AUSVERKAUFT + text = re.sub(r"\s+\.\s+", r" ", text) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, text)) > 3: + text = re.sub(regex2, " ", text) + + # Basic cleaning + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + text = text.strip() + + # Clean alphanumeric patterns + text = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", text) # jc6640 + text = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", text) # jc6640vc + text = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", text) # 6640vc231 + + # Common spam patterns + text = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", text) + text = re.sub(r"(free\s)?download(\sfree)?", "", text) + text = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", text) + text = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", text) + text = re.sub(r"\bpage\s+\d+\b", "", text) + + text = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", text) # j2d1a2a... + text = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", text) + + # Final cleanup + text = re.sub(r"\b\s+\:\s+", r": ", text) + text = re.sub(r"(\D[,\./])\b", r"\1 ", text) + text = re.sub(r"\s+", " ", text) + + text.strip() + + text = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", text) + text = re.sub(r"^[\'\_,\-\:;]", r"", text) + text = re.sub(r"[\'\_,\-\:\-\+]$", r"", text) + text = re.sub(r"^\.\S+$", "", text) + + return text.strip() + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import PhotonPipeline + + >>> # Load pipeline with from_pretrained + >>> pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft") + >>> pipe.to("cuda") + + >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach" + >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0] + >>> image.save("photon_output.png") + ``` +""" + + +class PhotonPipeline( + DiffusionPipeline, + LoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): + r""" + Pipeline for text-to-image generation using Photon Transformer. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + transformer ([`PhotonTransformer2DModel`]): + The Photon transformer model to denoise the encoded image latents. 
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + text_encoder ([`T5GemmaEncoder`]): + Text encoder model for encoding prompts. + tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]): + Tokenizer for the text encoder. + vae ([`AutoencoderKL`] or [`AutoencoderDC`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + _optional_components = ["vae"] + + def __init__( + self, + transformer: PhotonTransformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder: T5GemmaEncoder, + tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer], + vae: Optional[Union[AutoencoderKL, AutoencoderDC]] = None, + default_sample_size: Optional[int] = DEFAULT_RESOLUTION, + ): + super().__init__() + + if PhotonTransformer2DModel is None: + raise ImportError( + "PhotonTransformer2DModel is not available. Please ensure the transformer_photon module is properly installed." + ) + + self.text_preprocessor = TextPreprocessor() + self.default_sample_size = default_sample_size + self._guidance_scale = 1.0 + + self.register_modules( + transformer=transformer, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + vae=vae, + ) + + self.register_to_config(default_sample_size=self.default_sample_size) + + if vae is not None: + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + else: + self.image_processor = None + + @property + def vae_scale_factor(self): + if self.vae is None: + return 8 + if hasattr(self.vae, "spatial_compression_ratio"): + return self.vae.spatial_compression_ratio + else: # Flux VAE + return 2 ** (len(self.vae.config.block_out_channels) - 1) + + @property + def do_classifier_free_guidance(self): + """Check if classifier-free guidance is enabled based on guidance scale.""" + return self._guidance_scale > 1.0 + + @property + def guidance_scale(self): + return self._guidance_scale + + def get_default_resolution(self): + """Determine the default resolution based on the loaded VAE and config. + + Returns: + int: The default sample size (height/width) to use for generation. 
+ """ + default_from_config = getattr(self.config, "default_sample_size", None) + if default_from_config is not None: + return default_from_config + + return DEFAULT_RESOLUTION + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ): + """Prepare initial latents for the diffusion process.""" + if latents is None: + spatial_compression = self.vae_scale_factor + latent_height, latent_width = ( + height // spatial_compression, + width // spatial_compression, + ) + shape = (batch_size, num_channels_latents, latent_height, latent_width) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + return latents + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.BoolTensor] = None, + negative_prompt_attention_mask: Optional[torch.BoolTensor] = None, + ): + """Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings.""" + if device is None: + device = self._execution_device + + if prompt_embeds is None: + if isinstance(prompt, str): + prompt = [prompt] + # Encode the prompts + prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = ( + self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt) + ) + + # Duplicate embeddings for each generation per prompt + if num_images_per_prompt > 1: + # Repeat prompt embeddings + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if prompt_attention_mask is not None: + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # Repeat negative embeddings if using CFG + if do_classifier_free_guidance and negative_prompt_embeds is not None: + bs_embed, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if negative_prompt_attention_mask is not None: + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds if do_classifier_free_guidance else None, + negative_prompt_attention_mask if do_classifier_free_guidance else None, + ) + + def _tokenize_prompts(self, prompts: List[str], device: torch.device): + """Tokenize and clean prompts.""" + cleaned = [self.text_preprocessor.clean_text(text) for text in prompts] + tokens = self.tokenizer( + cleaned, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + return tokens["input_ids"].to(device), tokens["attention_mask"].bool().to(device) + + def 
_encode_prompt_standard( + self, + prompt: List[str], + device: torch.device, + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + ): + """Encode prompt using standard text encoder and tokenizer with batch processing.""" + batch_size = len(prompt) + + if do_classifier_free_guidance: + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + + prompts_to_encode = negative_prompt + prompt + else: + prompts_to_encode = prompt + + input_ids, attention_mask = self._tokenize_prompts(prompts_to_encode, device) + + with torch.no_grad(): + embeddings = self.text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + )["last_hidden_state"] + + if do_classifier_free_guidance: + uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0) + uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0) + else: + text_embeddings = embeddings + cross_attn_mask = attention_mask + uncond_text_embeddings = None + uncond_cross_attn_mask = None + + return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask + + def check_inputs( + self, + prompt: Union[str, List[str]], + height: int, + width: int, + guidance_scale: float, + callback_on_step_end_tensor_inputs: Optional[List[str]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + """Check that all inputs are in correct format.""" + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + + if prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and guidance_scale > 1.0 and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided and `guidance_scale > 1.0`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + + spatial_compression = self.vae_scale_factor + if height % spatial_compression != 0 or width % spatial_compression != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {spatial_compression} but are {height} and {width}." 
+ ) + + if guidance_scale < 1.0: + raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}") + + if callback_on_step_end_tensor_inputs is not None and not isinstance(callback_on_step_end_tensor_inputs, list): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + timesteps: List[int] = None, + guidance_scale: float = 4.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.BoolTensor] = None, + negative_prompt_attention_mask: Optional[torch.BoolTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + use_resolution_binning: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + negative_prompt (`str`, *optional*, defaults to `""`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 28): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. 
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided and `guidance_scale > 1`, negative embeddings will be generated from an
+                empty string.
+            prompt_attention_mask (`torch.BoolTensor`, *optional*):
+                Pre-generated attention mask for `prompt_embeds`. If not provided, the attention mask will be generated
+                from the `prompt` input argument.
+            negative_prompt_attention_mask (`torch.BoolTensor`, *optional*):
+                Pre-generated attention mask for `negative_prompt_embeds`. If not provided and `guidance_scale > 1`,
+                the attention mask will be generated from an empty string.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.photon.PhotonPipelineOutput`] instead of a plain tuple.
+            use_resolution_binning (`bool`, *optional*, defaults to `True`):
+                If set to `True`, the requested height and width are first mapped to the closest resolutions using
+                predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back
+                to the requested resolution. Useful for generating non-square images at optimal resolutions.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`.
+                `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You will only be able to include tensors that are
+                listed in the `._callback_tensor_inputs` attribute.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.photon.PhotonPipelineOutput`] or `tuple`: [`~pipelines.photon.PhotonPipelineOutput`] if
+            `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+            generated images.
+        """
+
+        # 0. Set height and width
+        default_resolution = self.get_default_resolution()
+        height = height or default_resolution
+        width = width or default_resolution
+
+        if use_resolution_binning:
+            if self.image_processor is None:
+                raise ValueError(
+                    "Resolution binning requires a VAE with image_processor, but VAE is not available. "
+                    "Set use_resolution_binning=False or provide a VAE."
+ ) + if self.default_sample_size <= 256: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + + # Store original dimensions + orig_height, orig_width = height, width + # Map to closest resolution in the bin + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + guidance_scale, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + + if self.vae is None and output_type not in ["latent", "pt"]: + raise ValueError( + f"VAE is required for output_type='{output_type}' but it is not available. " + "Either provide a VAE or set output_type='latent' or 'pt' to get latent outputs." + ) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Use execution device (handles offloading scenarios including group offloading) + device = self._execution_device + + self._guidance_scale = guidance_scale + + # 2. Encode input prompt + text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + # Expose standard names for callbacks parity + prompt_embeds = text_embeddings + negative_prompt_embeds = uncond_text_embeddings + + # 3. Prepare timesteps + if timesteps is not None: + self.scheduler.set_timesteps(timesteps=timesteps, device=device) + timesteps = self.scheduler.timesteps + num_inference_steps = len(timesteps) + else: + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + self.num_timesteps = len(timesteps) + + # 4. Prepare latent variables + if self.vae is not None: + num_channels_latents = self.vae.config.latent_channels + else: + # When vae is None, get latent channels from transformer + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + text_embeddings.dtype, + device, + generator, + latents, + ) + + # 5. Prepare extra step kwargs + extra_step_kwargs = {} + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_eta: + extra_step_kwargs["eta"] = 0.0 + + # 6. Prepare cross-attention embeddings and masks + if self.do_classifier_free_guidance: + ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0) + ca_mask = None + if cross_attn_mask is not None and uncond_cross_attn_mask is not None: + ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0) + else: + ca_embed = text_embeddings + ca_mask = cross_attn_mask + + # 7. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Duplicate latents if using classifier-free guidance + if self.do_classifier_free_guidance: + latents_in = torch.cat([latents, latents], dim=0) + # Normalize timestep for the transformer + t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device) + else: + latents_in = latents + # Normalize timestep for the transformer + t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device) + + # Forward through transformer + noise_pred = self.transformer( + hidden_states=latents_in, + timestep=t_cont, + encoder_hidden_states=ca_embed, + attention_mask=ca_mask, + return_dict=False, + )[0] + + # Apply CFG + if self.do_classifier_free_guidance: + noise_uncond, noise_text = noise_pred.chunk(2, dim=0) + noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_on_step_end(self, i, t, callback_kwargs) + + # Call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + # 8. Post-processing + if output_type == "latent" or (output_type == "pt" and self.vae is None): + image = latents + else: + # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC) + scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) + shift_factor = getattr(self.vae.config, "shift_factor", 0.0) + latents = (latents / scaling_factor) + shift_factor + # Decode using VAE (AutoencoderKL or AutoencoderDC) + image = self.vae.decode(latents, return_dict=False)[0] + # Resize back to original resolution if using binning + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + + # Use standard image processor for post-processing + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return PhotonPipelineOutput(images=image) diff --git a/tests/models/transformers/test_models_transformer_photon.py b/tests/models/transformers/test_models_transformer_photon.py new file mode 100644 index 000000000000..f5185245d399 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_photon.py @@ -0,0 +1,83 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel + +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class PhotonTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = PhotonTransformer2DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + return self.prepare_dummy_input() + + @property + def input_shape(self): + return (16, 16, 16) + + @property + def output_shape(self): + return (16, 16, 16) + + def prepare_dummy_input(self, height=16, width=16): + batch_size = 1 + num_latent_channels = 16 + sequence_length = 16 + embedding_dim = 1792 + + hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) + + return { + "hidden_states": hidden_states, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + } + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 16, + "patch_size": 2, + "context_in_dim": 1792, + "hidden_size": 1792, + "mlp_ratio": 3.5, + "num_heads": 28, + "depth": 4, # Smaller depth for testing + "axes_dim": [32, 32], + "theta": 10_000, + } + inputs_dict = self.prepare_dummy_input() + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"PhotonTransformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/pipelines/photon/__init__.py b/tests/pipelines/photon/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/photon/test_pipeline_photon.py b/tests/pipelines/photon/test_pipeline_photon.py new file mode 100644 index 000000000000..9c5803b5d0f4 --- /dev/null +++ b/tests/pipelines/photon/test_pipeline_photon.py @@ -0,0 +1,258 @@ +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer +from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig +from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder + +from diffusers.models import AutoencoderDC, AutoencoderKL +from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel +from diffusers.pipelines.photon.pipeline_photon import PhotonPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler + +from ..pipeline_params import TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +class PhotonPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = PhotonPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"]) + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + @classmethod + def setUpClass(cls): + # Ensure PhotonPipeline has an _execution_device property expected by __call__ + if not isinstance(getattr(PhotonPipeline, "_execution_device", None), property): + try: + setattr(PhotonPipeline, "_execution_device", property(lambda self: torch.device("cpu"))) + except Exception: + pass + + def get_dummy_components(self): + 
torch.manual_seed(0) + transformer = PhotonTransformer2DModel( + patch_size=1, + in_channels=4, + context_in_dim=8, + hidden_size=8, + mlp_ratio=2.0, + num_heads=2, + depth=1, + axes_dim=[2, 2], + ) + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=4, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0, + scaling_factor=1.0, + ).eval() + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + + torch.manual_seed(0) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma") + tokenizer.model_max_length = 64 + + torch.manual_seed(0) + + encoder_params = { + "vocab_size": tokenizer.vocab_size, + "hidden_size": 8, + "intermediate_size": 16, + "num_hidden_layers": 1, + "num_attention_heads": 2, + "num_key_value_heads": 1, + "head_dim": 4, + "max_position_embeddings": 64, + "layer_types": ["full_attention"], + "attention_bias": False, + "attention_dropout": 0.0, + "dropout_rate": 0.0, + "hidden_activation": "gelu_pytorch_tanh", + "rms_norm_eps": 1e-06, + "attn_logit_softcapping": 50.0, + "final_logit_softcapping": 30.0, + "query_pre_attn_scalar": 4, + "rope_theta": 10000.0, + "sliding_window": 4096, + } + encoder_config = T5GemmaModuleConfig(**encoder_params) + text_encoder_config = T5GemmaConfig(encoder=encoder_config, is_encoder_decoder=False, **encoder_params) + text_encoder = T5GemmaEncoder(text_encoder_config) + + return { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + return { + "prompt": "", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 1.0, + "height": 32, + "width": 32, + "output_type": "pt", + "use_resolution_binning": False, + } + + def test_inference(self): + device = "cpu" + components = self.get_dummy_components() + pipe = PhotonPipeline(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + try: + pipe.register_to_config(_execution_device="cpu") + except Exception: + pass + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs)[0] + generated_image = image[0] + + self.assertEqual(generated_image.shape, (3, 32, 32)) + expected_image = torch.zeros(3, 32, 32) + max_diff = np.abs(generated_image - expected_image).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + components = self.get_dummy_components() + pipe = PhotonPipeline(**components) + pipe = pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + try: + pipe.register_to_config(_execution_device="cpu") + except Exception: + pass + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {PhotonPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + for tensor_name in callback_kwargs.keys(): + assert tensor_name in pipe._callback_tensor_inputs + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + for tensor_name in callback_kwargs.keys(): + assert tensor_name in 
pipe._callback_tensor_inputs + return callback_kwargs + + inputs = self.get_dummy_inputs("cpu") + + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + _ = pipe(**inputs)[0] + + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + _ = pipe(**inputs)[0] + + def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + def to_np_local(tensor): + if isinstance(tensor, torch.Tensor): + return tensor.detach().cpu().numpy() + return tensor + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + max_diff1 = np.abs(to_np_local(output_with_slicing1) - to_np_local(output_without_slicing)).max() + max_diff2 = np.abs(to_np_local(output_with_slicing2) - to_np_local(output_without_slicing)).max() + self.assertLess(max(max_diff1, max_diff2), expected_max_diff) + + def test_inference_with_autoencoder_dc(self): + """Test PhotonPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL.""" + device = "cpu" + + components = self.get_dummy_components() + + torch.manual_seed(0) + vae_dc = AutoencoderDC( + in_channels=3, + latent_channels=4, + attention_head_dim=2, + encoder_block_types=( + "ResBlock", + "EfficientViTBlock", + ), + decoder_block_types=( + "ResBlock", + "EfficientViTBlock", + ), + encoder_block_out_channels=(8, 8), + decoder_block_out_channels=(8, 8), + encoder_qkv_multiscales=((), (5,)), + decoder_qkv_multiscales=((), (5,)), + encoder_layers_per_block=(1, 1), + decoder_layers_per_block=(1, 1), + upsample_block_type="interpolate", + downsample_block_type="stride_conv", + decoder_norm_types="rms_norm", + decoder_act_fns="silu", + ).eval() + + components["vae"] = vae_dc + + pipe = PhotonPipeline(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + expected_scale_factor = vae_dc.spatial_compression_ratio + self.assertEqual(pipe.vae_scale_factor, expected_scale_factor) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs)[0] + generated_image = image[0] + + self.assertEqual(generated_image.shape, (3, 32, 32)) + expected_image = torch.zeros(3, 32, 32) + max_diff = np.abs(generated_image - expected_image).max() + self.assertLessEqual(max_diff, 1e10)