diff --git a/src/diffusers/models/controlnets/controlnet_qwenimage.py b/src/diffusers/models/controlnets/controlnet_qwenimage.py index fa374285eec1..78a566549377 100644 --- a/src/diffusers/models/controlnets/controlnet_qwenimage.py +++ b/src/diffusers/models/controlnets/controlnet_qwenimage.py @@ -213,10 +213,8 @@ def forward( encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - # Construct joint attention mask once to avoid reconstructing in every block block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {} if encoder_hidden_states_mask is not None: - # Build joint mask: [text_mask, all_ones_for_image] batch_size, image_seq_len = hidden_states.shape[:2] image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py index cf11d8e01fb4..8cf0b19d09d0 100644 --- a/src/diffusers/models/transformers/transformer_qwenimage.py +++ b/src/diffusers/models/transformers/transformer_qwenimage.py @@ -935,11 +935,8 @@ def forward( image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device) - # Construct joint attention mask once to avoid reconstructing in every block - # This eliminates 60 GPU syncs during training while maintaining torch.compile compatibility block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {} if encoder_hidden_states_mask is not None: - # Build joint mask: [text_mask, all_ones_for_image] batch_size, image_seq_len = hidden_states.shape[:2] image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device) joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py index bc3ce84e1019..88c6d74f92a8 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py @@ -27,6 +27,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -210,14 +211,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -248,19 +242,18 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if 
isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, max_sequence_length + ) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + if prompt_embeds_mask is not None and prompt_embeds_mask.all(): + prompt_embeds_mask = None return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py index ce6fc974a56e..28803542867a 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -274,14 +275,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -313,16 +307,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py 
b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py index 77d78a5ca7a1..4c0a96a4eb3d 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet_inpaint.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -256,14 +257,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -294,16 +288,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py index dd723460a59e..e65be467df54 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -257,14 +258,7 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, 
encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -298,16 +292,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py index cf467203a9d2..40a0d9f35464 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_inpaint.py @@ -29,6 +29,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -268,14 +269,7 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -310,16 +304,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py index 257e2d846c7c..a33366d7d1df 100644 --- 
a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, concat_prompt_embeds_for_cfg, repeat_prompt_embeds_and_mask if is_torch_xla_available(): @@ -270,14 +271,7 @@ def _get_qwen_prompt_embeds( hidden_states = outputs.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -312,16 +306,12 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask @@ -724,6 +714,15 @@ def __call__( max_sequence_length=max_sequence_length, ) + use_batch_cfg = do_true_cfg and not self.transformer.is_cache_enabled + if use_batch_cfg: + prompt_embeds, prompt_embeds_mask = concat_prompt_embeds_for_cfg( + prompt_embeds, + prompt_embeds_mask, + negative_prompt_embeds, + negative_prompt_embeds_mask, + ) + # 4. 
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, image_latents = self.prepare_latents( @@ -799,7 +798,11 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - with self.transformer.cache_context("cond"): + if use_batch_cfg: + latent_model_input = torch.cat([latent_model_input] * 2) + timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) + + if use_batch_cfg: noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, @@ -811,20 +814,36 @@ def __call__( return_dict=False, )[0] noise_pred = noise_pred[:, : latents.size(1)] - - if do_true_cfg: - with self.transformer.cache_context("uncond"): - neg_noise_pred = self.transformer( + neg_noise_pred, noise_pred = noise_pred.chunk(2) + else: + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, - encoder_hidden_states_mask=negative_prompt_embeds_mask, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_mask=prompt_embeds_mask, + encoder_hidden_states=prompt_embeds, img_shapes=img_shapes, attention_kwargs=self.attention_kwargs, return_dict=False, )[0] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=negative_prompt_embeds_mask, + encoder_hidden_states=negative_prompt_embeds, + img_shapes=img_shapes, + attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + + if do_true_cfg: comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py index e0b41b8b8799..c0aa0d56dd8f 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py @@ -13,6 +13,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -217,14 +218,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -291,19 +285,16 @@ def encode_prompt( device = device 
or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, max_sequence_length + ) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py index 83f02539b1ba..52326f9001eb 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py @@ -14,6 +14,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -227,14 +228,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -302,19 +296,16 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, max_sequence_length + ) - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + 
prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py index 53d2c169ee63..11d11167d359 100644 --- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py +++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_layered.py @@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import QwenImagePipelineOutput +from .utils import build_prompt_embeds_and_mask, repeat_prompt_embeds_and_mask, slice_prompt_embeds_and_mask if is_torch_xla_available(): @@ -275,14 +276,7 @@ def _get_qwen_prompt_embeds( hidden_states = encoder_hidden_states.hidden_states[-1] split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) split_hidden_states = [e[drop_idx:] for e in split_hidden_states] - attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] - max_seq_len = max([e.size(0) for e in split_hidden_states]) - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] - ) - encoder_attention_mask = torch.stack( - [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] - ) + prompt_embeds, encoder_attention_mask = build_prompt_embeds_and_mask(split_hidden_states) prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) @@ -314,19 +308,15 @@ def encode_prompt( device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] - if prompt_embeds is None: prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) - prompt_embeds = prompt_embeds[:, :max_sequence_length] - prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] - - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) - prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + prompt_embeds, prompt_embeds_mask = slice_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, max_sequence_length + ) + prompt_embeds, prompt_embeds_mask = repeat_prompt_embeds_and_mask( + prompt_embeds, prompt_embeds_mask, num_images_per_prompt + ) return prompt_embeds, prompt_embeds_mask diff --git a/src/diffusers/pipelines/qwenimage/utils.py b/src/diffusers/pipelines/qwenimage/utils.py new file mode 100644 index 000000000000..7c91fec05a0a --- /dev/null +++ b/src/diffusers/pipelines/qwenimage/utils.py @@ -0,0 +1,92 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import torch
+
+
+def build_prompt_embeds_and_mask(split_hidden_states):
+    seq_lens = [e.size(0) for e in split_hidden_states]
+    max_seq_len = max(seq_lens)
+    if all(seq_len == max_seq_len for seq_len in seq_lens):
+        prompt_embeds = torch.stack(split_hidden_states)
+        return prompt_embeds, None
+
+    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+    prompt_embeds = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+    )
+    encoder_attention_mask = torch.stack(
+        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+    )
+    return prompt_embeds, encoder_attention_mask
+
+
+def slice_prompt_embeds_and_mask(prompt_embeds, prompt_embeds_mask, max_sequence_length):
+    prompt_embeds = prompt_embeds[:, :max_sequence_length]
+    if prompt_embeds_mask is not None:
+        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+    return prompt_embeds, prompt_embeds_mask
+
+
+def repeat_prompt_embeds_and_mask(prompt_embeds, prompt_embeds_mask, num_images_per_prompt):
+    batch_size, seq_len, _ = prompt_embeds.shape
+    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+    if prompt_embeds_mask is not None:
+        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+    return prompt_embeds, prompt_embeds_mask
+
+
+def concat_prompt_embeds_for_cfg(
+    prompt_embeds, prompt_embeds_mask, negative_prompt_embeds, negative_prompt_embeds_mask
+):
+    pos_len = prompt_embeds.shape[1]
+    neg_len = negative_prompt_embeds.shape[1]
+    max_len = max(pos_len, neg_len)
+
+    def _pad_prompt(embeds, mask):
+        orig_len = embeds.shape[1]
+        if orig_len != max_len:
+            pad_len = max_len - orig_len
+            embeds = torch.cat([embeds, embeds.new_zeros(embeds.shape[0], pad_len, embeds.shape[2])], dim=1)
+        if mask is None and orig_len != max_len:
+            mask = torch.ones((embeds.shape[0], orig_len), dtype=torch.long, device=embeds.device)
+        if mask is not None and mask.shape[1] != max_len:
+            pad_len = max_len - mask.shape[1]
+            mask = torch.cat([mask, mask.new_zeros(mask.shape[0], pad_len)], dim=1)
+        return embeds, mask
+
+    prompt_embeds, prompt_embeds_mask = _pad_prompt(prompt_embeds, prompt_embeds_mask)
+    negative_prompt_embeds, negative_prompt_embeds_mask = _pad_prompt(
+        negative_prompt_embeds, negative_prompt_embeds_mask
+    )
+
+    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+    if prompt_embeds_mask is None and negative_prompt_embeds_mask is None:
+        prompt_embeds_mask = None
+    else:
+        batch_half = prompt_embeds.shape[0] // 2
+        if prompt_embeds_mask is None:
+            prompt_embeds_mask = torch.ones((batch_half, max_len), dtype=torch.long, device=prompt_embeds.device)
+        if negative_prompt_embeds_mask is None:
+            negative_prompt_embeds_mask = torch.ones(
+                (batch_half, max_len), dtype=torch.long, device=prompt_embeds.device
+            )
+        prompt_embeds_mask = torch.cat([negative_prompt_embeds_mask, prompt_embeds_mask], dim=0)
+
+    if prompt_embeds_mask is not None and prompt_embeds_mask.all():
+        prompt_embeds_mask = None
+
+    return prompt_embeds, prompt_embeds_mask
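A usage sketch of the helpers introduced in `utils.py` above (illustrative only, not part of the patch). It assumes the patch is applied so that `diffusers.pipelines.qwenimage.utils` is importable; the feature size of 16, the sequence lengths, and the random tensors standing in for Qwen2.5-VL text-encoder states are arbitrary.

```python
import torch

# Import path matches the new file added above; only valid once this patch is applied.
from diffusers.pipelines.qwenimage.utils import (
    build_prompt_embeds_and_mask,
    concat_prompt_embeds_for_cfg,
    repeat_prompt_embeds_and_mask,
    slice_prompt_embeds_and_mask,
)

# Two prompts with different unpadded lengths (stand-ins for per-prompt encoder states).
ragged = [torch.randn(5, 16), torch.randn(8, 16)]
embeds, mask = build_prompt_embeds_and_mask(ragged)
print(embeds.shape, mask.shape)  # torch.Size([2, 8, 16]) torch.Size([2, 8])

# Equal lengths: no padding is needed, so the mask comes back as None and the
# transformer can skip building a joint attention mask altogether.
_, no_mask = build_prompt_embeds_and_mask([torch.randn(6, 16), torch.randn(6, 16)])
print(no_mask)  # None

# The other helpers tolerate a None mask, so encode_prompt can chain them unconditionally.
embeds, mask = slice_prompt_embeds_and_mask(embeds, mask, max_sequence_length=6)
embeds, mask = repeat_prompt_embeds_and_mask(embeds, mask, num_images_per_prompt=2)
print(embeds.shape, mask.shape)  # torch.Size([4, 6, 16]) torch.Size([4, 6])

# Batched true CFG (used by the edit-plus pipeline changes above): negative and positive
# embeds are padded to a common length and stacked as [negative, positive].
pos, neg = torch.randn(2, 6, 16), torch.randn(2, 4, 16)
cfg_embeds, cfg_mask = concat_prompt_embeds_for_cfg(pos, None, neg, None)
print(cfg_embeds.shape, cfg_mask.shape)  # torch.Size([4, 6, 16]) torch.Size([4, 6])
```

The `[negative, positive]` ordering matches the `neg_noise_pred, noise_pred = noise_pred.chunk(2)` split in the edit-plus denoising loop above.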
diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py
index 384954dfbad7..6acd7fb500ee 100644
--- a/tests/models/transformers/test_models_transformer_qwenimage.py
+++ b/tests/models/transformers/test_models_transformer_qwenimage.py
@@ -276,3 +276,45 @@ def prepare_dummy_input(self, height, width):
 
     def test_torch_compile_recompilation_and_graph_break(self):
         super().test_torch_compile_recompilation_and_graph_break()
+
+    def test_torch_compile_with_and_without_mask(self):
+        """Test that torch.compile works with both None mask and padding mask."""
+        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+        model.eval()
+
+        compiled_model = torch.compile(model, mode="default", fullgraph=False)
+
+        # Test 1: Run with None mask (no padding, all tokens are valid)
+        inputs_no_mask = inputs.copy()
+        inputs_no_mask["encoder_hidden_states_mask"] = None
+
+        with torch.no_grad():
+            output_no_mask = compiled_model(**inputs_no_mask)
+
+        self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
+
+        # Test 2: Run with all-ones mask (should behave like None)
+        inputs_all_ones = inputs.copy()
+        # Keep the all-ones mask
+        self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
+
+        with torch.no_grad():
+            output_all_ones = compiled_model(**inputs_all_ones)
+
+        self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
+
+        # Test 3: Run with actual padding mask (has zeros)
+        inputs_with_padding = inputs.copy()
+        mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
+        mask_with_padding[:, 4:] = 0  # Last 3 tokens are padding
+
+        inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
+
+        with torch.no_grad():
+            output_with_padding = compiled_model(**inputs_with_padding)
+
+        self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
+
+        # Verify that outputs are different (mask should affect results)
+        self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
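A standalone sketch of the joint-mask logic shared by the transformer and controlnet hunks at the top of this diff: the joint `[text, image]` mask is built once per forward and handed to every block through the attention kwargs, and when `encoder_hidden_states_mask` is `None` no mask is constructed at all. The helper name and the `attention_mask` kwarg key are assumptions; the assignment into `block_attention_kwargs` is not visible in the hunks shown.

```python
import torch

def build_block_attention_kwargs(hidden_states, encoder_hidden_states_mask, attention_kwargs=None):
    # Illustrative helper (not in the patch): copy the caller's kwargs once per forward
    # and attach a joint [text, image] mask only when some text tokens are padding.
    # The "attention_mask" key is an assumption about what the blocks consume.
    block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
    if encoder_hidden_states_mask is not None:
        batch_size, image_seq_len = hidden_states.shape[:2]
        image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
        joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
        block_attention_kwargs["attention_mask"] = joint_attention_mask
    return block_attention_kwargs

image_tokens = torch.randn(2, 10, 8)                 # packed latent tokens
text_mask = torch.ones(2, 6, dtype=torch.bool)
text_mask[:, 4:] = False                             # last two text tokens are padding

print(build_block_attention_kwargs(image_tokens, text_mask)["attention_mask"].shape)  # torch.Size([2, 16])
print(build_block_attention_kwargs(image_tokens, None))  # {} -> blocks run without a mask
```

Because the text-to-image pipeline now folds an all-ones mask to `None` before calling the transformer (see `encode_prompt` in `pipeline_qwenimage.py` above), the common no-padding case takes the maskless branch; `test_torch_compile_with_and_without_mask` exercises both paths under `torch.compile`.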