2 changes: 0 additions & 2 deletions src/diffusers/models/controlnets/controlnet_qwenimage.py
@@ -213,10 +213,8 @@ def forward(
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)

# Construct joint attention mask once to avoid reconstructing in every block
block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
if encoder_hidden_states_mask is not None:
# Build joint mask: [text_mask, all_ones_for_image]
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
3 changes: 0 additions & 3 deletions src/diffusers/models/transformers/transformer_qwenimage.py
@@ -935,11 +935,8 @@ def forward(

image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)

# Construct joint attention mask once to avoid reconstructing in every block
# This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
if encoder_hidden_states_mask is not None:
# Build joint mask: [text_mask, all_ones_for_image]
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
27 changes: 10 additions & 17 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
@@ -27,6 +27,7 @@
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import QwenImagePipelineOutput
from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask


if is_torch_xla_available():
@@ -210,14 +211,7 @@ def _get_qwen_prompt_embeds(
hidden_states = encoder_hidden_states.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
encoder_attention_mask = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)
prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states)

prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

@@ -248,19 +242,18 @@ def encode_prompt(
device = device or self._execution_device

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask(
prompt_embeds, prompt_embeds_mask, max_sequence_length
)
prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask(
prompt_embeds, prompt_embeds_mask, num_images_per_prompt
)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
prompt_embeds_mask = None

return prompt_embeds, prompt_embeds_mask

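The helpers imported from `.utils` (`build_prompt_embeds_and_mask`, `slice_prompt_embeds_and_mask`, `repeat_prompt_embeds_and_mask`) are not part of this diff. A minimal sketch of what they presumably do, reconstructed from the inlined code they replace; the exact signatures and the handling of a `None` mask are assumptions:

```python
# Hypothetical reconstruction of the helpers referenced above; the actual
# implementations live in src/diffusers/pipelines/qwenimage/utils.py and are
# not shown in this diff.
from typing import List, Optional, Tuple

import torch


def build_prompt_embeds_and_mask(
    split_hidden_states: List[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Right-pad the variable-length per-prompt hidden states to a common length
    # and build the matching attention mask (1 = real token, 0 = padding).
    max_seq_len = max(e.size(0) for e in split_hidden_states)
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
    )
    encoder_attention_mask = torch.stack(
        [
            torch.cat(
                [
                    torch.ones(u.size(0), dtype=torch.long, device=u.device),
                    torch.zeros(max_seq_len - u.size(0), dtype=torch.long, device=u.device),
                ]
            )
            for u in split_hidden_states
        ]
    )
    return prompt_embeds, encoder_attention_mask


def slice_prompt_embeds_and_mask(
    prompt_embeds: torch.Tensor,
    prompt_embeds_mask: Optional[torch.Tensor],
    max_sequence_length: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Truncate the embeddings (and the mask, if one exists) to max_sequence_length tokens.
    prompt_embeds = prompt_embeds[:, :max_sequence_length]
    if prompt_embeds_mask is not None:
        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
    return prompt_embeds, prompt_embeds_mask


def repeat_prompt_embeds_and_mask(
    prompt_embeds: torch.Tensor,
    prompt_embeds_mask: Optional[torch.Tensor],
    num_images_per_prompt: int,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Duplicate each prompt num_images_per_prompt times along the batch dimension,
    # keeping all copies of a prompt adjacent (same layout as the replaced code).
    batch_size, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    if prompt_embeds_mask is not None:
        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt)
        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
    return prompt_embeds, prompt_embeds_mask
```

Note that `pipeline_qwenimage.py` additionally drops the mask (`prompt_embeds_mask = None`) once every position is valid, so downstream code can skip building a joint text/image attention mask altogether.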
20 changes: 5 additions & 15 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py
@@ -28,6 +28,7 @@
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import QwenImagePipelineOutput
from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask


if is_torch_xla_available():
@@ -274,14 +275,7 @@ def _get_qwen_prompt_embeds(
hidden_states = encoder_hidden_states.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
encoder_attention_mask = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)
prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states)

prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

@@ -313,16 +307,12 @@ def encode_prompt(
device = device or self._execution_device

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask(
prompt_embeds, prompt_embeds_mask, num_images_per_prompt
)

return prompt_embeds, prompt_embeds_mask

Changes to an additional QwenImage pipeline file (file name not captured in this view)
@@ -28,6 +28,7 @@
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import QwenImagePipelineOutput
from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask


if is_torch_xla_available():
@@ -256,14 +257,7 @@ def _get_qwen_prompt_embeds(
hidden_states = encoder_hidden_states.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
encoder_attention_mask = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)
prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states)

prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

@@ -294,16 +288,12 @@ def encode_prompt(
device = device or self._execution_device

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask(
prompt_embeds, prompt_embeds_mask, num_images_per_prompt
)

return prompt_embeds, prompt_embeds_mask

20 changes: 5 additions & 15 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
@@ -28,6 +28,7 @@
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import QwenImagePipelineOutput
from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask


if is_torch_xla_available():
@@ -257,14 +258,7 @@ def _get_qwen_prompt_embeds(
hidden_states = outputs.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
encoder_attention_mask = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)
prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states)

prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

@@ -298,16 +292,12 @@ def encode_prompt(
device = device or self._execution_device

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask(
prompt_embeds, prompt_embeds_mask, num_images_per_prompt
)

return prompt_embeds, prompt_embeds_mask

Changes to an additional QwenImage pipeline file (file name not captured in this view)
@@ -29,6 +29,7 @@
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import QwenImagePipelineOutput
from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask


if is_torch_xla_available():
@@ -268,14 +269,7 @@ def _get_qwen_prompt_embeds(
hidden_states = outputs.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
encoder_attention_mask = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)
prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states)

prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

@@ -310,16 +304,12 @@ def encode_prompt(
device = device or self._execution_device

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask(
prompt_embeds, prompt_embeds_mask, num_images_per_prompt
)

return prompt_embeds, prompt_embeds_mask

65 changes: 42 additions & 23 deletions src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -28,6 +28,7 @@
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import QwenImagePipelineOutput
from .utils import build_prompt_embeds_and_mask, concat_prompt_embeds_for_cfg, repeat_prompt_embeds_and_mask


if is_torch_xla_available():
@@ -270,14 +271,7 @@ def _get_qwen_prompt_embeds(
hidden_states = outputs.hidden_states[-1]
split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
max_seq_len = max([e.size(0) for e in split_hidden_states])
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
encoder_attention_mask = torch.stack(
[torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)
prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states)

prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

@@ -312,16 +306,12 @@ def encode_prompt(
device = device or self._execution_device

prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

if prompt_embeds is None:
prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)

_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask(
prompt_embeds, prompt_embeds_mask, num_images_per_prompt
)

return prompt_embeds, prompt_embeds_mask

@@ -724,6 +714,15 @@ def __call__(
max_sequence_length=max_sequence_length,
)

use_batch_cfg = do_true_cfg and not self.transformer.is_cache_enabled
if use_batch_cfg:
prompt_embeds, prompt_embeds_mask = concat_prompt_embeds_for_cfg(
prompt_embeds,
prompt_embeds_mask,
negative_prompt_embeds,
negative_prompt_embeds_mask,
)

# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, image_latents = self.prepare_latents(
@@ -799,7 +798,11 @@ def __call__(

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
with self.transformer.cache_context("cond"):
if use_batch_cfg:
latent_model_input = torch.cat([latent_model_input] * 2)
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

if use_batch_cfg:
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
@@ -811,20 +814,36 @@
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]

if do_true_cfg:
with self.transformer.cache_context("uncond"):
neg_noise_pred = self.transformer(
neg_noise_pred, noise_pred = noise_pred.chunk(2)
else:
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
encoder_hidden_states=prompt_embeds,
img_shapes=img_shapes,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
noise_pred = noise_pred[:, : latents.size(1)]

if do_true_cfg:
with self.transformer.cache_context("uncond"):
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
encoder_hidden_states_mask=negative_prompt_embeds_mask,
encoder_hidden_states=negative_prompt_embeds,
img_shapes=img_shapes,
attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]

if do_true_cfg:
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
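`concat_prompt_embeds_for_cfg`, used only by the edit-plus pipeline above, is also not shown in this diff. Because the denoising loop unpacks `neg_noise_pred, noise_pred = noise_pred.chunk(2)`, the helper must place the negative embeddings first in the batch, and since the conditional and unconditional prompts can pad to different lengths it presumably also pads both to a common length. A rough sketch under those assumptions (the padding behavior and the `None`-mask handling are guesses):

```python
# Hypothetical sketch of concat_prompt_embeds_for_cfg; the real helper lives in
# src/diffusers/pipelines/qwenimage/utils.py and is not shown in this diff.
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def concat_prompt_embeds_for_cfg(
    prompt_embeds: torch.Tensor,
    prompt_embeds_mask: Optional[torch.Tensor],
    negative_prompt_embeds: torch.Tensor,
    negative_prompt_embeds_mask: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    def _mask_or_ones(embeds, mask):
        # Treat a missing mask as "every token is valid".
        if mask is not None:
            return mask
        return torch.ones(embeds.shape[:2], dtype=torch.long, device=embeds.device)

    pos_mask = _mask_or_ones(prompt_embeds, prompt_embeds_mask)
    neg_mask = _mask_or_ones(negative_prompt_embeds, negative_prompt_embeds_mask)

    # Pad both sequences to the longer of the two so they can be batched together.
    seq_len = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])

    def _pad_to(embeds, mask, target_len):
        pad = target_len - embeds.shape[1]
        if pad > 0:
            embeds = F.pad(embeds, (0, 0, 0, pad))  # zero-pad the sequence dimension
            mask = F.pad(mask, (0, pad))            # padded positions are masked out
        return embeds, mask

    prompt_embeds, pos_mask = _pad_to(prompt_embeds, pos_mask, seq_len)
    negative_prompt_embeds, neg_mask = _pad_to(negative_prompt_embeds, neg_mask, seq_len)

    # Negative first, so that noise_pred.chunk(2) yields (uncond, cond) as in the loop above.
    embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    mask = torch.cat([neg_mask, pos_mask], dim=0)
    return embeds, mask
```

Folding the conditional and unconditional passes into one batched call halves the number of transformer forward passes per denoising step; the diff enables this only when the cond/uncond cache contexts are not in use (`not self.transformer.is_cache_enabled`), presumably because cached predictions have to stay in separate passes.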