Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions src/diffusers/models/transformers/transformer_qwenimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
],
dim=1,
)
self.rope_cache = {}

# DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
self.scale_rope = scale_rope
Expand All @@ -195,10 +194,20 @@ def rope_params(self, index, dim, theta=10000):
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs

def forward(self, video_fhw, txt_seq_lens, device):
def forward(
self,
video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
txt_seq_lens: List[int],
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
txt_length: [bs] a list of 1 integers representing the length of the text
Args:
video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
A list of 3 integers [frame, height, width] representing the shape of the video.
txt_seq_lens (`List[int]`):
A list of integers of length batch_size representing the length of each text prompt.
device: (`torch.device`):
The device on which to perform the RoPE computation.
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
Expand All @@ -213,14 +222,8 @@ def forward(self, video_fhw, txt_seq_lens, device):
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}"

if not torch.compiler.is_compiling():
if rope_key not in self.rope_cache:
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
video_freq = self.rope_cache[rope_key]
else:
video_freq = self._compute_video_freqs(frame, height, width, idx)
# RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
video_freq = self._compute_video_freqs(frame, height, width, idx)
video_freq = video_freq.to(device)
vid_freqs.append(video_freq)

Expand All @@ -235,8 +238,8 @@ def forward(self, video_fhw, txt_seq_lens, device):

return vid_freqs, txt_freqs

@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
@functools.lru_cache(maxsize=128)
def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
Expand Down
Loading