From 661febb64b3f99def7f042871816e197745985f3 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 16 Jan 2026 14:22:36 +0000 Subject: [PATCH 1/4] avoid creating attention masks when there is no padding --- .../pipelines/qwenimage/pipeline_qwenimage.py | 26 ++---- .../pipeline_qwenimage_controlnet.py | 20 ++--- .../pipeline_qwenimage_controlnet_inpaint.py | 20 ++--- .../qwenimage/pipeline_qwenimage_edit.py | 20 ++--- .../pipeline_qwenimage_edit_inpaint.py | 20 ++--- .../qwenimage/pipeline_qwenimage_edit_plus.py | 65 +++++++++----- .../qwenimage/pipeline_qwenimage_img2img.py | 25 ++---- .../qwenimage/pipeline_qwenimage_inpaint.py | 25 ++---- .../qwenimage/pipeline_qwenimage_layered.py | 25 ++---- src/diffusers/pipelines/qwenimage/utils.py | 89 +++++++++++++++++++ 10 files changed, 183 insertions(+), 152 deletions(-) create mode 100644 src/diffusers/pipelines/qwenimage/utils.py diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index bc3ce84e1019..bf289be905ff 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -27,6 +27,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -210,14 +211,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -248,19 +242,15 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] - - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, max_sequence_length + ) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py 
b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index ce6fc974a56e..28803542867a 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -274,14 +275,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -313,16 +307,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 77d78a5ca7a1..4c0a96a4eb3d 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -256,14 +257,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in 
attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -294,16 +288,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index dd723460a59e..e65be467df54 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -257,14 +258,7 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -298,16 +292,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index cf467203a9d2..40a0d9f35464 100644 --- 
a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -29,6 +29,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -268,14 +269,7 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -310,16 +304,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index 257e2d846c7c..a33366d7d1df 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, concat_prompt_embeds_for_cfg, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -270,14 +271,7 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = 
prompt_embeds.to(dtype=dtype, device=device) @@ -312,16 +306,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask @@ -724,6 +714,15 @@ def __call__( max_sequence_length=max_sequence_length, ) + use_batch_cfg = do_true_cfg and not self.transformer.is_cache_enabled + if use_batch_cfg: + prompt_embeds, prompt_embeds_mask = concat_prompt_embeds_for_cfg( + prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds, + negative_prompt_embeds_mask, + ) + # 4. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, image_latents = self.prepare_latents( @@ -799,7 +798,11 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - with self.transformer.cache_context("cond"): + if use_batch_cfg: + latent_model_input = torch.cat([latent_model_input] * 2) + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + + if use_batch_cfg: noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, @@ -811,20 +814,36 @@ def __call__( return_dict=False, )[0] noise_pred = noise_pred[:, : latents.size(1)] - - if do_true_cfg: - with self.transformer.cache_context("uncond"): - neg_noise_pred = self.transformer( + neg_noise_pred, noise_pred = noise_pred.chunk(2) + else: + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + + if do_true_cfg: comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index e0b41b8b8799..c0aa0d56dd8f 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ 
b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -13,6 +13,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -217,14 +218,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -291,19 +285,16 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, max_sequence_length + ) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 83f02539b1ba..52326f9001eb 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -14,6 +14,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -227,14 +228,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = 
torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -302,19 +296,16 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, max_sequence_length + ) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 53d2c169ee63..4da2406e046f 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -275,14 +276,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -314,19 +308,16 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, max_sequence_length + ) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = 
prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/utils.py b/src/diffusers/pipelines/qwenimage/utils.py new file mode 100644 index 000000000000..7271fce9304e --- /dev/null +++ b/src/diffusers/pipelines/qwenimage/utils.py @@ -0,0 +1,89 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def build_prompt_embeds_and_mask(split_hidden_states): + seq_lens = [e.size(0) for e in split_hidden_states] + max_seq_len = max(seq_lens) + if all(seq_len == max_seq_len for seq_len in seq_lens): + prompt_embeds = torch.stack(split_hidden_states) + return prompt_embeds, None + + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + return prompt_embeds, encoder_attention_mask + + +def slice_prompt_embeds_and_mask(prompt_embeds, prompt_embeds_mask, max_sequence_length): + prompt_embeds = prompt_embeds[:, :max_sequence_length] + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + return prompt_embeds, prompt_embeds_mask + + +def repeat_prompt_embeds_and_mask(prompt_embeds, prompt_embeds_mask, num_images_per_prompt): + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + if prompt_embeds_mask is not None: + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + return prompt_embeds, prompt_embeds_mask + + +def concat_prompt_embeds_for_cfg( + prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask +): + pos_len = prompt_embeds.shape[1] + neg_len = negative_prompt_embeds.shape[1] + max_len = max(pos_len, neg_len) + + def _pad_prompt(embeds, mask): + orig_len = embeds.shape[1] + if orig_len != max_len: + pad_len = max_len - orig_len + embeds = torch.cat([embeds, embeds.new_zeros(embeds.shape[0], pad_len, embeds.shape[2])], dim=1) + if mask is None and orig_len != max_len: + mask = torch.ones((embeds.shape[0], orig_len), dtype=torch.long, device=embeds.device) + if mask is not None and mask.shape[1] != max_len: + pad_len = max_len - mask.shape[1] + mask = torch.cat([mask, 
mask.new_zeros(mask.shape[0], pad_len)], dim=1) + return embeds, mask + + prompt_embeds, prompt_embeds_mask = _pad_prompt(prompt_embeds, prompt_embeds_mask) + negative_prompt_embeds, negative_prompt_embeds_mask = _pad_prompt( + negative_prompt_embeds, negative_prompt_embeds_mask + ) + + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + if prompt_embeds_mask is None and negative_prompt_embeds_mask is None: + prompt_embeds_mask = None + else: + batch_half = prompt_embeds.shape[0] // 2 + if prompt_embeds_mask is None: + prompt_embeds_mask = torch.ones((batch_half, max_len), dtype=torch.long, device=prompt_embeds.device) + if negative_prompt_embeds_mask is None: + negative_prompt_embeds_mask = torch.ones( + (batch_half, max_len), dtype=torch.long, device=prompt_embeds.device + ) + prompt_embeds_mask = torch.cat([negative_prompt_embeds_mask, prompt_embeds_mask], dim=0) + + return prompt_embeds, prompt_embeds_mask From 5507b5e8b982cb837df5354f3d4ca3f3fefecedd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 16 Jan 2026 14:28:05 +0000 Subject: [PATCH 2/4] make fix-copies --- src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 4da2406e046f..11d11167d359 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -314,7 +314,6 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( prompt_embeds, prompt_embeds_mask, max_sequence_length ) - prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) From 4839fcfc312c7e85a76cd99df7b616ea813fa372 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 17 Jan 2026 23:34:48 +0000 Subject: [PATCH 3/4] torch compile tests --- .../controlnets/controlnet_qwenimage.py | 2 - .../transformers/transformer_qwenimage.py | 3 -- .../pipelines/qwenimage/pipeline_qwenimage.py | 3 ++ src/diffusers/pipelines/qwenimage/utils.py | 3 ++ .../test_models_transformer_qwenimage.py | 42 +++++++++++++++++++ 5 files changed, 48 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index fa374285eec1..78a566549377 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -213,10 +213,8 @@ def forward( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - # Construct joint attention mask once to avoid reconstructing in every block block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {} if encoder_hidden_states_mask is not None: - # Build joint mask: [text_mask, all_ones_for_image] batch_size, image_seq_len = hidden_states.shape[:2] image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index cf11d8e01fb4..8cf0b19d09d0 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ 
b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -935,11 +935,8 @@ def forward( image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) - # Construct joint attention mask once to avoid reconstructing in every block - # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} if encoder_hidden_states_mask is not None: - # Build joint mask: [text_mask, all_ones_for_image] batch_size, image_seq_len = hidden_states.shape[:2] image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index bf289be905ff..88c6d74f92a8 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -252,6 +252,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( diff --git a/src/diffusers/pipelines/qwenimage/utils.py b/src/diffusers/pipelines/qwenimage/utils.py index 7271fce9304e..7c91fec05a0a 100644 --- a/src/diffusers/pipelines/qwenimage/utils.py +++ b/src/diffusers/pipelines/qwenimage/utils.py @@ -86,4 +86,7 @@ def _pad_prompt(embeds, mask): ) prompt_embeds_mask = torch.cat([negative_prompt_embeds_mask, prompt_embeds_mask], dim=0) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index 384954dfbad7..6acd7fb500ee 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -276,3 +276,45 @@ def prepare_dummy_input(self, height, width): def test_torch_compile_recompilation_and_graph_break(self): super().test_torch_compile_recompilation_and_graph_break() + + def test_torch_compile_with_and_without_mask(self): + """Test that torch.compile works with both None mask and padding mask.""" + init_dict, inputs = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict).to(torch_device) + model.eval() + + compiled_model = torch.compile(model, mode="default", fullgraph=False) + + # Test 1: Run with None mask (no padding, all tokens are valid) + inputs_no_mask = inputs.copy() + inputs_no_mask["encoder_hidden_states_mask"] = None + + with torch.no_grad(): + output_no_mask = compiled_model(**inputs_no_mask) + + self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Test 2: Run with all-ones mask (should behave like None) + inputs_all_ones = inputs.copy() + # Keep the all-ones mask + self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item()) + + with torch.no_grad(): + output_all_ones = compiled_model(**inputs_all_ones) + + self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Test 3: Run with actual padding mask (has zeros) + inputs_with_padding = inputs.copy() + mask_with_padding = 
inputs["encoder_hidden_states_mask"].clone() + mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding + + inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding + + with torch.no_grad(): + output_with_padding = compiled_model(**inputs_with_padding) + + self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1]) + + # Verify that outputs are different (mask should affect results) + self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)) From 23150e46ab7ee87132fe4d23808d0c0b2961631b Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Jan 2026 11:35:24 +0000 Subject: [PATCH 4/4] set all ones mask to none --- .../pipelines/qwenimage/pipeline_qwenimage_controlnet.py | 3 +++ .../qwenimage/pipeline_qwenimage_controlnet_inpaint.py | 3 +++ src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py | 3 +++ .../pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py | 3 +++ .../pipelines/qwenimage/pipeline_qwenimage_edit_plus.py | 3 +++ .../pipelines/qwenimage/pipeline_qwenimage_img2img.py | 3 +++ .../pipelines/qwenimage/pipeline_qwenimage_inpaint.py | 3 +++ .../pipelines/qwenimage/pipeline_qwenimage_layered.py | 3 +++ 8 files changed, 24 insertions(+) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index 28803542867a..4ee51151701f 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -314,6 +314,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 4c0a96a4eb3d..1f39ce08246e 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -295,6 +295,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index e65be467df54..fd278def7245 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -299,6 +299,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index 40a0d9f35464..4b56e8c9daa8 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -311,6 +311,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = 
None + return prompt_embeds, prompt_embeds_mask # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.check_inputs diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index a33366d7d1df..d5fc2b78ae73 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -313,6 +313,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index c0aa0d56dd8f..4da613e4d6a2 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -296,6 +296,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 52326f9001eb..109148c0f923 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -307,6 +307,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def check_inputs( diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 11d11167d359..c8c5994e612b 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -318,6 +318,9 @@ def encode_prompt( prompt_embeds, prompt_embeds_mask, num_images_per_prompt ) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None + return prompt_embeds, prompt_embeds_mask def get_image_caption(self, prompt_image, use_en_prompt=True, device=None):
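
A note for reviewers on the core fast path introduced in patch 1: build_prompt_embeds_and_mask returns None in place of an all-ones attention mask whenever every prompt in the batch has the same token length, and patch 4 applies the same collapse to already-built masks at the end of every encode_prompt. Below is a minimal, self-contained sketch of that behavior; the helper body is copied from the new src/diffusers/pipelines/qwenimage/utils.py, while the dummy tensors and sizes are illustrative only:

import torch

def build_prompt_embeds_and_mask(split_hidden_states):
    # Fast path: if no prompt needs padding, skip mask construction entirely
    # and signal "no padding" to downstream code by returning None.
    seq_lens = [e.size(0) for e in split_hidden_states]
    max_seq_len = max(seq_lens)
    if all(seq_len == max_seq_len for seq_len in seq_lens):
        return torch.stack(split_hidden_states), None

    # Slow path: right-pad each prompt to max_seq_len and build a 0/1 mask.
    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
    )
    encoder_attention_mask = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
    )
    return prompt_embeds, encoder_attention_mask

# Illustrative batch of two prompts with hidden size 8.
same_len = [torch.randn(5, 8), torch.randn(5, 8)]
ragged = [torch.randn(5, 8), torch.randn(3, 8)]

_, mask = build_prompt_embeds_and_mask(same_len)
assert mask is None  # no padding -> no mask

embeds, mask = build_prompt_embeds_and_mask(ragged)
assert embeds.shape == (2, 5, 8) and mask.sum(dim=1).tolist() == [5, 3]

# Patch 4 adds the equivalent collapse after repeat_prompt_embeds_and_mask:
mask = torch.ones(2, 5, dtype=torch.long)
if mask is not None and mask.all():
    mask = None

Returning None here is what lets the transformer skip joint-mask construction, and the associated GPU syncs, whenever there is no padding.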
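On the consumer side, the transformer_qwenimage.py hunk in patch 3 only materializes the joint text+image mask when a text mask is actually present. A condensed sketch of that branch follows, assuming the joint mask is forwarded to the blocks under an "attention_mask" key and that the text mask is cast to bool before concatenation; neither detail is visible in the hunk context, so both are assumptions:

import torch

def make_block_attention_kwargs(hidden_states, encoder_hidden_states_mask, attention_kwargs=None):
    # Mirrors the transformer forward(): when the text mask is None (no padding),
    # no joint mask is built and attention runs fully unmasked.
    block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
    if encoder_hidden_states_mask is not None:
        batch_size, image_seq_len = hidden_states.shape[:2]
        image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
        # Joint mask layout: [text_mask, all_ones_for_image]
        joint_attention_mask = torch.cat([encoder_hidden_states_mask.bool(), image_mask], dim=1)
        block_attention_kwargs["attention_mask"] = joint_attention_mask  # kwarg name assumed
    return block_attention_kwargs

hidden_states = torch.randn(2, 16, 8)           # (batch, image tokens, dim) -- illustrative
text_mask = torch.ones(2, 7, dtype=torch.bool)
text_mask[1, 4:] = False                        # second prompt is padded after token 4

assert make_block_attention_kwargs(hidden_states, None) == {}               # maskless fast path
assert make_block_attention_kwargs(hidden_states, text_mask)["attention_mask"].shape == (2, 23)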
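Finally, the batched-CFG path added to pipeline_qwenimage_edit_plus.py in patch 1 relies on concat_prompt_embeds_for_cfg to right-pad the shorter of the positive/negative embeddings before stacking them, so a single transformer call serves both guidance branches when caching is disabled. A toy sketch of the tensor plumbing (the real helper also materializes an attention mask for whichever half it pads, and patch 3 collapses that mask back to None if nothing ended up masked):

import torch

pos = torch.randn(1, 6, 8)  # conditional prompt embeddings (batch, seq, dim) -- illustrative
neg = torch.randn(1, 4, 8)  # negative prompt is shorter, so it gets right-padded

max_len = max(pos.shape[1], neg.shape[1])
neg = torch.cat([neg, neg.new_zeros(1, max_len - neg.shape[1], neg.shape[2])], dim=1)

batched = torch.cat([neg, pos], dim=0)  # order matches the helper: [uncond, cond]
# ... a single transformer call would consume `batched` here ...
neg_noise_pred, noise_pred = batched.chunk(2)  # mirrors noise_pred.chunk(2) in __call__
assert neg_noise_pred.shape == noise_pred.shape == (1, 6, 8)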