diff --git a/examples/z_image/predict_t2i_omni.py b/examples/z_image/predict_t2i_omni.py new file mode 100644 index 00000000..58e6b2f3 --- /dev/null +++ b/examples/z_image/predict_t2i_omni.py @@ -0,0 +1,233 @@ +import os +import sys + +import torch +from diffusers import FlowMatchEulerDiscreteScheduler + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.dist import set_multi_gpus_devices, shard_model +from videox_fun.models import (AutoencoderKL, AutoProcessor, AutoTokenizer, + Qwen3ForCausalLM, Siglip2VisionModel, + ZImageOmniTransformer2DModel) +from videox_fun.models.cache_utils import get_teacache_coefficients +from videox_fun.pipeline import ZImageOmniPipeline +from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler +from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper) +from videox_fun.utils.lora_utils import merge_lora, unmerge_lora +from videox_fun.utils.utils import (filter_kwargs, get_image, get_image_latent, + get_image_to_video_latent, + get_video_to_video_latent, + save_videos_grid) + +# GPU memory mode, which can be chosen in [model_full_load, model_full_load_and_qfloat8, model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload]. +# model_full_load means that the entire model will be moved to the GPU. +# +# model_full_load_and_qfloat8 means that the entire model will be moved to the GPU, +# and the transformer model has been quantized to float8, which can save more GPU memory. 
+# +# model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory. +# +# model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use, +# and the transformer model has been quantized to float8, which can save more GPU memory. +# +# sequential_cpu_offload means that each layer of the model will be moved to the CPU after use, +# resulting in slower speeds but saving a large amount of GPU memory. +GPU_memory_mode = "model_cpu_offload" +# Multi GPUs config +# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used. +# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4. +# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1. +ulysses_degree = 1 +ring_degree = 1 +# Use FSDP to save more GPU memory in multi gpus. +fsdp_dit = False +fsdp_text_encoder = False +# Compile will give a speedup in fixed resolution and need a little GPU memory. +# The compile_dit is not compatible with the fsdp_dit and sequential_cpu_offload. +compile_dit = False + +# model path +model_name = "models/Diffusion_Transformer/Z-Image-Base-Omni" + +# Choose the sampler in "Flow", "Flow_Unipc", "Flow_DPM++" +sampler_name = "Flow" + +# Load pretrained model if needed +transformer_path = None +vae_path = None +lora_path = None + +# Other params +sample_size = [1568, 1184] + +# Use torch.float16 if GPU does not support torch.bfloat16 +# Some graphics cards, such as v100, 2080ti, do not support torch.bfloat16 +weight_dtype = torch.bfloat16 +image = None + +# Please use as detailed a prompt as possible to describe the object that needs to be generated. 
+prompt = "这是一张充满东方古典韵味的人像摄影作品,画面中的年轻女子身着一袭精致的香槟色旗袍蹲在地上,面料上点缀着精美的白色刺绣花纹,在阳光照射下泛着柔和的光泽。" +negative_prompt = "" +guidance_scale = 5.00 +seed = 42 +num_inference_steps = 40 +lora_weight = 0.55 +save_path = "samples/z-image-omni" + +device = set_multi_gpus_devices(ulysses_degree, ring_degree) + +transformer = ZImageOmniTransformer2DModel.from_pretrained( + model_name, + subfolder="transformer", + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, +).to(weight_dtype) + +if transformer_path is not None: + print(f"From checkpoint: {transformer_path}") + if transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(transformer_path) + else: + state_dict = torch.load(transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Get Vae +vae = AutoencoderKL.from_pretrained( + model_name, + subfolder="vae" +).to(weight_dtype) + +if vae_path is not None: + print(f"From checkpoint: {vae_path}") + if vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(vae_path) + else: + state_dict = torch.load(vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Get tokenizer and text_encoder +tokenizer = AutoTokenizer.from_pretrained( + model_name, subfolder="tokenizer" +) +text_encoder = Qwen3ForCausalLM.from_pretrained( + model_name, subfolder="text_encoder", torch_dtype=weight_dtype, + low_cpu_mem_usage=True, +) + +siglip = Siglip2VisionModel.from_pretrained( + model_name, subfolder="clip_encoder", + torch_dtype=weight_dtype, +) +siglip_processor = AutoProcessor.from_pretrained( + model_name, 
subfolder="clip_encoder", +) + +# Get Scheduler +Chosen_Scheduler = scheduler_dict = { + "Flow": FlowMatchEulerDiscreteScheduler, + "Flow_Unipc": FlowUniPCMultistepScheduler, + "Flow_DPM++": FlowDPMSolverMultistepScheduler, +}[sampler_name] +scheduler = Chosen_Scheduler.from_pretrained( + model_name, + subfolder="scheduler" +) + +pipeline = ZImageOmniPipeline( + vae=vae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + siglip=siglip, + siglip_processor=siglip_processor, + scheduler=scheduler, +) + +if ulysses_degree > 1 or ring_degree > 1: + from functools import partial + transformer.enable_multi_gpus_inference() + if fsdp_dit: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, module_to_wrapper=list(transformer.layers)) + pipeline.transformer = shard_fn(pipeline.transformer) + print("Add FSDP DIT") + if fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, module_to_wrapper=list(text_encoder.model.layers)) + text_encoder = shard_fn(text_encoder) + print("Add FSDP TEXT ENCODER") + +if compile_dit: + for i in range(len(pipeline.transformer.transformer_blocks)): + pipeline.transformer.transformer_blocks[i] = torch.compile(pipeline.transformer.transformer_blocks[i]) + print("Add Compile") + +if GPU_memory_mode == "sequential_cpu_offload": + pipeline.enable_sequential_cpu_offload(device=device) +elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_cpu_offload": + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device) + 
convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.to(device=device) +else: + pipeline.to(device=device) + +generator = torch.Generator(device=device).manual_seed(seed) + +if lora_path is not None: + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +with torch.no_grad(): + if image is not None: + if not isinstance(image, list): + image = get_image(image).convert("RGB") + else: + image = [get_image(_image).convert("RGB") for _image in image] + + sample = pipeline( + image = image, + prompt = prompt, + negative_prompt = negative_prompt, + height = sample_size[0], + width = sample_size[1], + generator = generator, + guidance_scale = guidance_scale, + num_inference_steps = num_inference_steps, + ).images + +if lora_path is not None: + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +def save_results(): + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + + index = len([path for path in os.listdir(save_path)]) + 1 + prefix = str(index).zfill(8) + video_path = os.path.join(save_path, prefix + ".png") + image = sample[0] + image.save(video_path) + +if ulysses_degree * ring_degree > 1: + import torch.distributed as dist + if dist.get_rank() == 0: + save_results() +else: + save_results() \ No newline at end of file diff --git a/scripts/z_image/train_lora_omni.py b/scripts/z_image/train_lora_omni.py new file mode 100644 index 00000000..bf5103f5 --- /dev/null +++ b/scripts/z_image/train_lora_omni.py @@ -0,0 +1,1707 @@ +"""Modified from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py +""" +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import gc +import logging +import math +import os +import pickle +import random +import shutil +import sys +from typing import (Any, Callable, Dict, List, NamedTuple, Optional, Tuple, + Union) + +import accelerate +import diffusers +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms.functional as TF +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.state import AcceleratorState +from accelerate.utils import ProjectConfiguration, set_seed +from diffusers import DDIMScheduler, FlowMatchEulerDiscreteScheduler +from diffusers.optimization import get_scheduler +from diffusers.training_utils import (EMAModel, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3) +from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.utils.torch_utils import is_compiled_module +from einops import rearrange +from omegaconf import OmegaConf +from packaging import version +from PIL import Image +from torch.utils.data import RandomSampler +from torch.utils.tensorboard import SummaryWriter +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from transformers.utils import ContextManagers + +import datasets + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in 
project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.data.bucket_sampler import (ASPECT_RATIO_512, + ASPECT_RATIO_RANDOM_CROP_512, + ASPECT_RATIO_RANDOM_CROP_PROB, + AspectRatioBatchImageVideoSampler, + RandomSampler, get_closest_ratio) +from videox_fun.data.dataset_image import ImageEditDataset +from videox_fun.data.dataset_image_video import (ImageVideoControlDataset, + ImageVideoDataset, + ImageVideoSampler, + get_random_mask) +from videox_fun.dist import set_multi_gpus_devices, shard_model +from videox_fun.models import (AutoencoderKL, AutoProcessor, AutoTokenizer, + CLIPImageProcessor, + CLIPVisionModelWithProjection, Qwen2Tokenizer, + Qwen3ForCausalLM, QwenImageTransformer2DModel, + Siglip2VisionModel, + ZImageOmniTransformer2DModel) +from videox_fun.models.flux2_image_processor import Flux2ImageProcessor +from videox_fun.pipeline import ZImageOmniPipeline +from videox_fun.utils.discrete_sampler import DiscreteSampling +from videox_fun.utils.lora_utils import (create_network, merge_lora, + unmerge_lora) +from videox_fun.utils.utils import get_image_to_video_latent, save_videos_grid + +if is_wandb_available(): + import wandb + +def filter_kwargs(cls, kwargs): + import inspect + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + +def linear_decay(initial_value, final_value, total_steps, current_step): + if current_step >= total_steps: + return final_value + current_step = max(0, current_step) + step_size = (final_value - initial_value) / total_steps + current_value = initial_value + step_size * current_step + return current_value + +def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=None): + u = torch.normal(mean=0.0, std=1.0, size=shape, device=device, generator=generator) + t = 1 / (1 + torch.exp(-u)) * (high - low) + low + 
return torch.clip(t.to(torch.int32), low, high - 1) + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu +def encode_prompt( + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + text_encoder = None, + tokenizer = None, + max_sequence_length: int = 512, + num_condition_images = 0, +) -> List[torch.FloatTensor]: + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + + if num_condition_images == 0: + prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"] + elif num_condition_images > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1) + prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|im_end|>"] + prompt[i] = prompt_list + + flattened_prompt = [] + prompt_list_lengths = [] + + for i in range(len(prompt)): + prompt_list_lengths.append(len(prompt[i])) + flattened_prompt.extend(prompt[i]) + + text_inputs = tokenizer( + flattened_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = 
text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + start_idx = 0 + for i in range(len(prompt_list_lengths)): + batch_embeddings = [] + end_idx = start_idx + prompt_list_lengths[i] + for j in range(start_idx, end_idx): + batch_embeddings.append(prompt_embeds[j][prompt_masks[j]]) + embeddings_list.append(batch_embeddings) + start_idx = end_idx + + return embeddings_list + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.18.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +def log_validation(vae, text_encoder, tokenizer, transformer3d, network, args, accelerator, weight_dtype, global_step): + try: + logger.info("Running validation... ") + + transformer3d_val = ZImageOmniTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer", torch_dtype=weight_dtype, + low_cpu_mem_usage=True, + ).to(weight_dtype) + transformer3d_val.load_state_dict(accelerator.unwrap_model(transformer3d).state_dict()) + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler" + ) + transformer3d = transformer3d.to("cpu") + pipeline = ZImageOmniPipeline( + vae=accelerator.unwrap_model(vae).to(weight_dtype), + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + transformer=transformer3d_val, + scheduler=scheduler, + siglip=siglip, + siglip_processor=siglip_processor, + ) + pipeline = pipeline.to(accelerator.device) + pipeline = merge_lora( + pipeline, None, 1, accelerator.device, state_dict=accelerator.unwrap_model(network).state_dict(), transformer_only=True + ) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + for i in range(len(args.validation_prompts)): + with torch.no_grad(): + image = 
[get_image(args.validation_image_paths[i])] + sample = pipeline( + prompt = args.validation_prompts[i], + negative_prompt = "bad detailed", + height = args.image_sample_size, + width = args.image_sample_size, + generator = generator, + image = image + ).images + os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True) + image = sample[0].save(os.path.join(args.output_dir, f"sample/sample-{global_step}-image-{i}.gif")) + + del pipeline + del transformer3d_val + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + transformer3d = transformer3d.to(accelerator.device) + except Exception as e: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + print(f"Eval error with info {e}") + transformer3d = transformer3d.to(accelerator.device) + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. " + ), + ) + parser.add_argument( + "--train_data_meta", + type=str, + default=None, + help=( + "A csv containing the training data. 
" + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--validation_image_paths", + type=str, + default=None, + nargs="+", + help=("A set of images evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--use_came", + action="store_true", + help="whether to use came", + ) + parser.add_argument( + "--multi_stream", + action="store_true", + help="whether to use cuda multi-stream", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--vae_mini_batch", type=int, default=32, help="mini batch size for vae." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." 
+ ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. 
Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=2000, + help="Run validation every X steps.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + parser.add_argument( + "--rank", + type=int, + default=128, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--network_alpha", + type=int, + default=64, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--use_peft_lora", action="store_true", help="Whether or not to use peft lora." + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--snr_loss", action="store_true", help="Whether or not to use snr_loss." + ) + parser.add_argument( + "--uniform_sampling", action="store_true", help="Whether or not to use uniform_sampling." + ) + parser.add_argument( + "--enable_text_encoder_in_dataloader", action="store_true", help="Whether or not to use text encoder in dataloader." + ) + parser.add_argument( + "--enable_bucket", action="store_true", help="Whether enable bucket sample in datasets." + ) + parser.add_argument( + "--random_ratio_crop", action="store_true", help="Whether enable random ratio crop sample in datasets." + ) + parser.add_argument( + "--random_hw_adapt", action="store_true", help="Whether enable random adapt height and width in datasets." 
+ ) + parser.add_argument( + "--train_sampling_steps", + type=int, + default=1000, + help="Run train_sampling_steps.", + ) + parser.add_argument( + "--image_sample_size", + type=int, + default=512, + help="Sample size of the image.", + ) + parser.add_argument( + "--fix_sample_size", + nargs=2, type=int, default=None, + help="Fix Sample size [height, width] when using bucket and collate_fn." + ) + parser.add_argument( + "--transformer_path", + type=str, + default=None, + help=("If you want to load the weight from other transformers, input its path."), + ) + parser.add_argument( + "--vae_path", + type=str, + default=None, + help=("If you want to load the weight from other vaes, input its path."), + ) + parser.add_argument("--save_state", action="store_true", help="Whether or not to save state.") + + parser.add_argument( + "--use_deepspeed", action="store_true", help="Whether or not to use deepspeed." + ) + parser.add_argument( + "--use_fsdp", action="store_true", help="Whether or not to use fsdp." + ) + parser.add_argument( + "--low_vram", action="store_true", help="Whether enable low_vram mode." + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--lora_skip_name", + type=str, + default=None, + help=("The module is not trained in loras. 
"), + ) + parser.add_argument( + "--target_name", + type=str, + default=None, + help=("The module is trained in loras. "), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def main(): + args = parse_args() + if args.train_batch_size >= 2: + raise ValueError("This code does not support args.train_batch_size >= 2 now.") + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." 
+ ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + deepspeed_plugin = accelerator.state.deepspeed_plugin if hasattr(accelerator.state, "deepspeed_plugin") else None + fsdp_plugin = accelerator.state.fsdp_plugin if hasattr(accelerator.state, "fsdp_plugin") else None + if deepspeed_plugin is not None: + zero_stage = int(deepspeed_plugin.zero_stage) + fsdp_stage = 0 + print(f"Using DeepSpeed Zero stage: {zero_stage}") + + args.use_deepspeed = True + if zero_stage == 3: + print(f"Auto set save_state to True because zero_stage == 3") + args.save_state = True + elif fsdp_plugin is not None: + from torch.distributed.fsdp import ShardingStrategy + zero_stage = 0 + if fsdp_plugin.sharding_strategy is ShardingStrategy.FULL_SHARD: + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is None: # The fsdp_plugin.sharding_strategy is None in FSDP 2. + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is ShardingStrategy.SHARD_GRAD_OP: + fsdp_stage = 2 + else: + fsdp_stage = 0 + print(f"Using FSDP stage: {fsdp_stage}") + + args.use_fsdp = True + if fsdp_stage == 3: + print(f"Auto set save_state to True because fsdp_stage == 3") + args.save_state = True + else: + zero_stage = 0 + fsdp_stage = 0 + print("DeepSpeed is not enabled.") + + if accelerator.is_main_process: + writer = SummaryWriter(log_dir=logging_dir) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + rng = np.random.default_rng(np.random.PCG64(args.seed + accelerator.process_index)) + torch_rng = torch.Generator(accelerator.device).manual_seed(args.seed + accelerator.process_index) + else: + rng = None + torch_rng = None + index_rng = np.random.default_rng(np.random.PCG64(43)) + print(f"Init rng with seed {args.seed + accelerator.process_index}. Process_index is {accelerator.process_index}") + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer3d) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + # Load scheduler, tokenizer and models. 
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler" + ) + + # Get Tokenizer + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer" + ) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So Qwen3ForCausalLM and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # Get Text encoder + text_encoder = Qwen3ForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype + ) + text_encoder = text_encoder.eval() + # Get Vae + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae" + ).to(weight_dtype) + vae.eval() + + siglip_path = "models/siglip2-so400m-patch16-naflex" + siglip = Siglip2VisionModel.from_pretrained(siglip_path, torch_dtype=torch.bfloat16).to(accelerator.device) + siglip_processor = AutoProcessor.from_pretrained(siglip_path) + + # Get Transformer + transformer3d = ZImageOmniTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=weight_dtype, + ).to(weight_dtype) + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + image_processor = Flux2ImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + # Freeze vae and text_encoder and set transformer3d to trainable + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + transformer3d.requires_grad_(False) + + # Lora will work with this... 
+ if args.use_peft_lora: + from peft import (LoraConfig, get_peft_model_state_dict, + inject_adapter_in_model) + lora_config = LoraConfig(r=args.rank, lora_alpha=args.network_alpha, target_modules=args.target_name.split(",")) + transformer3d = inject_adapter_in_model(lora_config, transformer3d) + + network = None + else: + network = create_network( + 1.0, + args.rank, + args.network_alpha, + text_encoder, + transformer3d, + neuron_dropout=None, + target_name=args.target_name, + skip_name=args.lora_skip_name, + ) + network = network.to(weight_dtype) + network.apply_to(text_encoder, transformer3d, args.train_text_encoder and not args.training_with_video_token_length, True) + + if args.transformer_path is not None: + print(f"From checkpoint: {args.transformer_path}") + if args.transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.transformer_path) + else: + state_dict = torch.load(args.transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer3d.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + if args.vae_path is not None: + print(f"From checkpoint: {args.vae_path}") + if args.vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.vae_path) + else: + state_dict = torch.load(args.vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice 
format + if fsdp_stage != 0: + def save_model_hook(models, weights, output_dir): + accelerate_state_dict = accelerator.get_state_dict(models[-1], unwrap=True) + if accelerator.is_main_process: + from safetensors.torch import save_file + safetensor_save_path = os.path.join(output_dir, f"lora_diffusion_pytorch_model.safetensors") + if args.use_peft_lora: + network_state_dict = get_peft_model_state_dict(accelerator.unwrap_model(models[-1]), accelerate_state_dict) + else: + network_state_dict = {} + for key in accelerate_state_dict: + if "network" in key: + network_state_dict[key.replace("network.", "")] = accelerate_state_dict[key].to(weight_dtype) + save_file(network_state_dict, safetensor_save_path, metadata={"format": "pt"}) + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. 
Get loaded_number = {loaded_number}.") + + elif zero_stage == 3: + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + accelerate_state_dict = accelerator.get_state_dict(models[-1], unwrap=True) + if accelerator.is_main_process: + from safetensors.torch import save_file + safetensor_save_path = os.path.join(output_dir, f"lora_diffusion_pytorch_model.safetensors") + if args.use_peft_lora: + network_state_dict = get_peft_model_state_dict(accelerator.unwrap_model(models[-1]), accelerate_state_dict) + else: + network_state_dict = accelerate_state_dict + save_file(network_state_dict, safetensor_save_path, metadata={"format": "pt"}) + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. 
Get loaded_number = {loaded_number}.") + else: + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + safetensor_save_path = os.path.join(output_dir, f"lora_diffusion_pytorch_model.safetensors") + if args.use_peft_lora: + save_model(safetensor_save_path, get_peft_model_state_dict(accelerator.unwrap_model(models[-1]))) + else: + save_model(safetensor_save_path, accelerator.unwrap_model(models[-1])) + + if not args.use_deepspeed: + for _ in range(len(weights)): + weights.pop() + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.") + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + transformer3d.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + elif args.use_came: + try: + from came_pytorch import CAME + except: + raise ImportError( + "Please install came_pytorch to use CAME. You can do so by running `pip install came_pytorch`" + ) + + optimizer_cls = CAME + else: + optimizer_cls = torch.optim.AdamW + + if args.use_peft_lora: + logging.info("Add peft parameters") + trainable_params = list(filter(lambda p: p.requires_grad, transformer3d.parameters())) + trainable_params_optim = list(filter(lambda p: p.requires_grad, transformer3d.parameters())) + else: + logging.info("Add network parameters") + trainable_params = list(filter(lambda p: p.requires_grad, network.parameters())) + trainable_params_optim = network.prepare_optimizer_params(args.learning_rate / 2, args.learning_rate, args.learning_rate) + + if args.use_came: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + # weight_decay=args.adam_weight_decay, + betas=(0.9, 0.999, 0.9999), + eps=(1e-30, 1e-16) + ) + else: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the training dataset + if args.fix_sample_size is not None and args.enable_bucket: + args.image_sample_size = max(max(args.fix_sample_size), args.image_sample_size) + args.random_hw_adapt = False + + # Get the dataset + train_dataset = ImageEditDataset( + args.train_data_meta, args.train_data_dir, + image_sample_size=args.image_sample_size, + enable_bucket=args.enable_bucket, + ) + + def worker_init_fn(_seed): + _seed = _seed * 256 + def _worker_init_fn(worker_id): + print(f"worker_init_fn with {_seed + worker_id}") + np.random.seed(_seed + worker_id) + random.seed(_seed + worker_id) + return _worker_init_fn + + if args.enable_bucket: + aspect_ratio_sample_size = {key : [x / 512 * args.image_sample_size for x in 
ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + batch_sampler_generator = torch.Generator().manual_seed(args.seed) + batch_sampler = AspectRatioBatchImageVideoSampler( + sampler=RandomSampler(train_dataset, generator=batch_sampler_generator), dataset=train_dataset.dataset, + batch_size=args.train_batch_size, train_folder = args.train_data_dir, drop_last=True, + aspect_ratios=aspect_ratio_sample_size, + ) + + + def collate_fn(examples): + def get_random_downsample_ratio(sample_size, image_ratio=[], + all_choices=False, rng=None): + def _create_special_list(length): + if length == 1: + return [1.0] + first_element = 0.90 + remaining_sum = 1.0 - first_element + other_elements_value = remaining_sum / (length - 1) + return [first_element] + [other_elements_value] * (length - 1) + + MIN_TARGET = 1024 + + if sample_size < MIN_TARGET: + number_list = [1.0] + else: + max_allowed_ratio = sample_size / MIN_TARGET + base_ratios = [ + 1.0, + 1.1, 1.2, 1.25, 1.33, 1.5, + 1.75, 2.0, 2.25, 2.5, 2.75, + 3.0, 3.5, 4.0, 5.0, 6.0, 8.0 + ] + candidate_ratios = set(base_ratios + list(image_ratio)) + number_list = sorted([r for r in candidate_ratios if 1.0 <= r <= max_allowed_ratio]) + + if not number_list: + number_list = [1.0] + + if all_choices: + return number_list + + probs = np.array(_create_special_list(len(number_list))) + if rng is None: + return np.random.choice(number_list, p=probs) + else: + return rng.choice(number_list, p=probs) + + # Create new output + new_examples = {} + new_examples["pixel_values"] = [] + new_examples["source_pixel_values"] = [] + new_examples["text"] = [] + + # Get downsample ratio in image + pixel_value = examples[0]["pixel_values"] + source_pixel_values = examples[0]["source_pixel_values"] + data_type = examples[0]["data_type"] + f, h, w, c = np.shape(pixel_value) + + random_downsample_ratio = 1 if not args.random_hw_adapt else get_random_downsample_ratio(args.image_sample_size) + + aspect_ratio_sample_size = {key : [x / 512 * 
args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()} + + if args.fix_sample_size is not None: + fix_sample_size = [int(x / 16) * 16 for x in args.fix_sample_size] + elif args.random_ratio_crop: + if rng is None: + random_sample_size = aspect_ratio_random_crop_sample_size[ + np.random.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + else: + random_sample_size = aspect_ratio_random_crop_sample_size[ + rng.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + random_sample_size = [int(x / 16) * 16 for x in random_sample_size] + else: + closest_size, closest_ratio = get_closest_ratio(h, w, ratios=aspect_ratio_sample_size) + closest_size = [int(x / 16) * 16 for x in closest_size] + + for example in examples: + # To 0~1 + pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. 
+ + if args.fix_sample_size is not None: + # Get adapt hw for resize + fix_sample_size = list(map(lambda x: int(x), fix_sample_size)) + transform = transforms.Compose([ + transforms.Resize(fix_sample_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(fix_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + elif args.random_ratio_crop: + # Get adapt hw for resize + b, c, h, w = pixel_values.size() + th, tw = random_sample_size + if th / tw > h / w: + nh = int(th) + nw = int(w / h * nh) + else: + nw = int(tw) + nh = int(h / w * nw) + + transform = transforms.Compose([ + transforms.Resize([nh, nw]), + transforms.CenterCrop([int(x) for x in random_sample_size]), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + else: + # Get adapt hw for resize + closest_size = list(map(lambda x: int(x), closest_size)) + if closest_size[0] / h > closest_size[1] / w: + resize_size = closest_size[0], int(w * closest_size[0] / h) + else: + resize_size = int(h * closest_size[1] / w), closest_size[1] + + transform = transforms.Compose([ + transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(closest_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + + source_pixel_values = [] + for _source_pixel_value in example["source_pixel_values"]: + source_pixel_values.append(np.array(_source_pixel_value)) + + new_examples["pixel_values"].append(transform(pixel_values)) + new_examples["source_pixel_values"].append(source_pixel_values) + new_examples["text"].append(example["text"]) + + # Limit the number of frames to the same + new_examples["pixel_values"] = torch.stack([example for example in new_examples["pixel_values"]]) + + # Encode prompts when enable_text_encoder_in_dataloader=True + if args.enable_text_encoder_in_dataloader: + prompt_embeds = 
encode_prompt( + batch['text'], device="cpu", + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + + new_examples['prompt_embeds'] = prompt_embeds + + return new_examples + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + persistent_workers=True if args.dataloader_num_workers != 0 else False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + else: + # DataLoaders creation: + batch_sampler_generator = torch.Generator().manual_seed(args.seed) + batch_sampler = ImageVideoSampler(RandomSampler(train_dataset, generator=batch_sampler_generator), train_dataset, args.train_batch_size) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + persistent_workers=True if args.dataloader_num_workers != 0 else False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + # Prepare everything with our `accelerator`. 
+ if args.use_peft_lora: + transformer3d, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer3d, optimizer, train_dataloader, lr_scheduler + ) + elif fsdp_stage != 0: + transformer3d.network = network + transformer3d = transformer3d.to(dtype=weight_dtype) + transformer3d, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer3d, optimizer, train_dataloader, lr_scheduler + ) + else: + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + network, optimizer, train_dataloader, lr_scheduler + ) + + if zero_stage != 0 and not args.use_peft_lora: + from functools import partial + + from videox_fun.dist import set_multi_gpus_devices, shard_model + shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype, module_to_wrapper=list(transformer3d.layers)) + transformer3d = shard_fn(transformer3d) + + if fsdp_stage != 0 or zero_stage != 0: + from functools import partial + + from videox_fun.dist import set_multi_gpus_devices, shard_model + shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype, module_to_wrapper=text_encoder.model.layers) + text_encoder = shard_fn(text_encoder) + + # Move text_encode and vae to gpu and cast to weight_dtype + vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + transformer3d.to(accelerator.device, dtype=weight_dtype) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + keys_to_pop = [k for k, v in tracker_config.items() if isinstance(v, list)] + for k in keys_to_pop: + tracker_config.pop(k) + print(f"Removed tracker_config['{k}']") + accelerator.init_trackers(args.tracker_project_name, tracker_config) + + # Function for unwrapping if model was compiled with `torch.compile`. + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + + pkl_path = os.path.join(os.path.join(args.output_dir, path), "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + _, first_epoch = pickle.load(file) + else: + first_epoch = global_step // num_update_steps_per_epoch + print(f"Load pkl from {pkl_path}. 
Get first_epoch = {first_epoch}.") + + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + else: + initial_global_step = 0 + + # function for saving/removing + def save_model(ckpt_file, unwrapped_nw): + os.makedirs(args.output_dir, exist_ok=True) + accelerator.print(f"\nsaving checkpoint: {ckpt_file}") + if isinstance(unwrapped_nw, dict): + from safetensors.torch import save_file + save_file(unwrapped_nw, ckpt_file, metadata={"format": "pt"}) + return ckpt_file + unwrapped_nw.save_weights(ckpt_file, weight_dtype, None) + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + if args.multi_stream: + # create extra cuda streams to speedup inpaint vae computation + vae_stream_1 = torch.cuda.Stream() + vae_stream_2 = torch.cuda.Stream() + else: + vae_stream_1 = None + vae_stream_2 = None + + # Calculate the index we need】 + idx_sampling = DiscreteSampling(args.train_sampling_steps, uniform_sampling=args.uniform_sampling) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + batch_sampler.sampler.generator = torch.Generator().manual_seed(args.seed + epoch) + for step, batch in enumerate(train_dataloader): + if epoch == first_epoch and step == 0: + pixel_values, texts = batch['pixel_values'].cpu(), batch['text'] + source_pixel_values = batch['source_pixel_values'] + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + + os.makedirs(os.path.join(args.output_dir, "sanity_check"), exist_ok=True) + for idx, (pixel_value, source_pixel_value, text) in enumerate(zip(pixel_values, source_pixel_values, texts)): + pixel_value = pixel_value[None, ...] 
+ gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}.gif", rescale=True) + for local_index, _source_pixel_value in enumerate(source_pixel_value): + _source_pixel_value = Image.fromarray(np.uint8(_source_pixel_value)) + _source_pixel_value.save(f"{args.output_dir}/sanity_check/source_{local_index}_{gif_name[:10]}.jpg") + + with accelerator.accumulate(transformer3d): + # Convert images to latent space + pixel_values = batch["pixel_values"].to(weight_dtype) + source_pixel_values = batch["source_pixel_values"] + + if args.low_vram: + torch.cuda.empty_cache() + vae.to(accelerator.device) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to("cpu") + + with torch.no_grad(): + # This way is quicker when batch grows up + def _batch_encode_vae(pixel_values): + pixel_values = pixel_values.squeeze(1) + bs = args.vae_mini_batch + new_pixel_values = [] + for i in range(0, pixel_values.shape[0], bs): + pixel_values_bs = pixel_values[i : i + bs] + pixel_values_bs = vae.encode(pixel_values_bs)[0] + pixel_values_bs = pixel_values_bs.sample() + new_pixel_values.append(pixel_values_bs) + return torch.cat(new_pixel_values, dim = 0) + if vae_stream_1 is not None: + vae_stream_1.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(vae_stream_1): + latents = _batch_encode_vae(pixel_values).unsqueeze(2) + else: + latents = _batch_encode_vae(pixel_values).unsqueeze(2) + + def prepare_siglip_embeds( + images, + batch_size, + device, + dtype, + ): + siglip_embeds = [] + for image in images: + siglip_inputs = siglip_processor(images=[image], return_tensors="pt").to(device) + shape = siglip_inputs.spatial_shapes[0] + hidden_state = siglip(**siglip_inputs).last_hidden_state + B, N, C = hidden_state.shape + hidden_state = hidden_state[:, :shape[0] * shape[1]] + hidden_state = hidden_state.view(shape[0], shape[1], C) + 
siglip_embeds.append(hidden_state.to(dtype)) + + # siglip_embeds = [siglip_embeds] * batch_size + siglip_embeds = [siglip_embeds.copy() for _ in range(batch_size)] + + return siglip_embeds + + def prepare_image_latents( + images: List[torch.Tensor], + batch_size, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + image_latent = (vae.encode(image.bfloat16()).latent_dist.mode()[0] - vae.config.shift_factor) * vae.config.scaling_factor + image_latent = image_latent.unsqueeze(1).to(dtype) + image_latents.append(image_latent) # (16, 128, 128) + + image_latents = [image_latents.copy() for _ in range(batch_size)] + + return image_latents + + condition_latents = [] + condition_siglip_embeds = [] + omni_images = [ + [Image.fromarray(np.uint8(source_image)) for source_image in source_pixel_value] \ + for source_pixel_value in source_pixel_values + ] + bsz, channel, f, height, width = latents.size() + + if len(omni_images[0]) == 0: + condition_latents = [[] for i in range(bsz)] + condition_siglip_embeds = [None for i in range(bsz)] + omni_images = [] + else: + image_height, image_width = pixel_values.size()[-2], pixel_values.size()[-1] + for i in range(latents.size()[0]): + condition_images = [] + resized_images = [] + for img in omni_images[i]: + img = image_processor._resize_to_target_area(img, image_width * image_height) + image_width, image_height = img.size + resized_images.append(img) + + multiple_of = vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + + local_condition_siglip_embeds = prepare_siglip_embeds( + images=resized_images, + batch_size=1, + device=accelerator.device, + dtype=torch.float32, + ) + local_condition_siglip_embeds = [[se.to(transformer3d.dtype) for se in sels] for 
sels in local_condition_siglip_embeds] + condition_siglip_embeds += local_condition_siglip_embeds + + local_condition_latents = prepare_image_latents( + images=condition_images, + batch_size=1, + device=accelerator.device, + dtype=torch.float32, + ) + local_condition_latents = [[lat.to(transformer3d.dtype) for lat in lats] for lats in local_condition_latents] + condition_latents += local_condition_latents + + # wait for latents = vae.encode(pixel_values) to complete + if vae_stream_1 is not None: + torch.cuda.current_stream().wait_stream(vae_stream_1) + + if args.low_vram: + vae.to('cpu') + torch.cuda.empty_cache() + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device) + + if args.enable_text_encoder_in_dataloader: + prompt_embeds = batch['prompt_embeds'].to(dtype=latents.dtype, device=accelerator.device) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + batch['text'], device=accelerator.device, + text_encoder=text_encoder, + tokenizer=tokenizer, + num_condition_images=len(omni_images), + ) + + if args.low_vram and not args.enable_text_encoder_in_dataloader: + text_encoder.to('cpu') + torch.cuda.empty_cache() + + bsz, channel, f, height, width = latents.size() + latents = ((latents - vae.config.shift_factor) * vae.config.scaling_factor).to(dtype=weight_dtype) + noise = torch.randn(latents.size(), device=latents.device, generator=torch_rng, dtype=weight_dtype) + + if not args.uniform_sampling: + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + else: + # Sample a random timestep for each image + # timesteps = generate_timestep_with_lognorm(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng) + # timesteps = torch.randint(0, args.train_sampling_steps, (bsz,), 
device=latents.device, generator=torch_rng) + indices = idx_sampling(bsz, generator=torch_rng, device=latents.device) + indices = indices.long().cpu() + + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + noise_scheduler.config.get("base_image_seq_len", 256), + noise_scheduler.config.get("max_image_seq_len", 4096), + noise_scheduler.config.get("base_shift", 0.5), + noise_scheduler.config.get("max_shift", 1.15), + ) + noise_scheduler.sigma_min = 0.0 + noise_scheduler.set_timesteps(args.train_sampling_steps, device=latents.device, mu=mu) + timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + # Add noise according to flow matching. 
+ # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) + noisy_latents = (1.0 - sigmas) * latents + sigmas * noise + + # Add noise + target = noise - latents + timesteps = ((1000 - timesteps) / 1000).to(device=accelerator.device, dtype=weight_dtype) + + # Create noise mask: 0 for condition images (clean), 1 for target image (noisy) + image_noise_mask = [ + [0] * len(condition_latents[i]) + [1] for i in range(bsz) + ] + x_combined = [ + condition_latents[i] + [noisy_latents[i]] for i in range(bsz) + ] + # Predict the noise residual + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): + noise_pred = transformer3d( + x = x_combined, + t = timesteps, + cap_feats = prompt_embeds, + siglip_feats = condition_siglip_embeds, + image_noise_mask = image_noise_mask, + )[0] + + def custom_mse_loss(noise_pred, target, weighting=None, threshold=50): + noise_pred = noise_pred.float() + target = target.float() + diff = noise_pred - target + mse_loss = F.mse_loss(noise_pred, target, reduction='none') + mask = (diff.abs() <= threshold).float() + masked_loss = mse_loss * mask + if weighting is not None: + masked_loss = masked_loss * weighting + final_loss = masked_loss.mean() + return final_loss + + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + loss = custom_mse_loss(-noise_pred.float(), target.float(), weighting.float()) + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). 
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + if not args.save_state: + if args.use_peft_lora: + safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors") + save_model(safetensor_save_path, 
get_peft_model_state_dict(accelerator.unwrap_model(transformer3d))) + logger.info(f"Saved safetensor to {safetensor_save_path}") + else: + safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors") + save_model(safetensor_save_path, accelerator.unwrap_model(network)) + logger.info(f"Saved safetensor to {safetensor_save_path}") + else: + accelerator_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(accelerator_save_path) + logger.info(f"Saved state to {accelerator_save_path}") + + if accelerator.is_main_process: + if args.validation_prompts is not None and global_step % args.validation_steps == 0: + log_validation( + vae, + text_encoder, + tokenizer, + transformer3d, + network, + args, + accelerator, + weight_dtype, + global_step, + ) + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + log_validation( + vae, + text_encoder, + tokenizer, + transformer3d, + network, + args, + accelerator, + weight_dtype, + global_step, + ) + + # Create the pipeline using the trained modules and save it. 
+ accelerator.wait_for_everyone() + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + if not args.save_state: + if args.use_peft_lora: + safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors") + save_model(safetensor_save_path, get_peft_model_state_dict(accelerator.unwrap_model(transformer3d))) + logger.info(f"Saved safetensor to {safetensor_save_path}") + else: + safetensor_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.safetensors") + save_model(safetensor_save_path, accelerator.unwrap_model(network)) + logger.info(f"Saved safetensor to {safetensor_save_path}") + else: + accelerator_save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(accelerator_save_path) + logger.info(f"Saved state to {accelerator_save_path}") + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/scripts/z_image/train_lora_omni.sh b/scripts/z_image/train_lora_omni.sh new file mode 100644 index 00000000..09a9e4ee --- /dev/null +++ b/scripts/z_image/train_lora_omni.sh @@ -0,0 +1,33 @@ +export MODEL_NAME="models/Diffusion_Transformer/Z-Image-Omni" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata_edit.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
+# export NCCL_IB_DISABLE=1 +# export NCCL_P2P_DISABLE=1 +NCCL_DEBUG=INFO + +accelerate launch --mixed_precision="bf16" scripts/z_image/train_lora_omni.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --train_data_dir=$DATASET_NAME \ + --train_data_meta=$DATASET_META_NAME \ + --train_batch_size=1 \ + --image_sample_size=1328 \ + --gradient_accumulation_steps=1 \ + --dataloader_num_workers=8 \ + --num_train_epochs=1000 \ + --checkpointing_steps=50 \ + --learning_rate=1e-04 \ + --seed=42 \ + --output_dir="output_dir_z_image_lora_omni" \ + --gradient_checkpointing \ + --mixed_precision="bf16" \ + --adam_weight_decay=3e-2 \ + --adam_epsilon=1e-10 \ + --vae_mini_batch=1 \ + --max_grad_norm=0.05 \ + --enable_bucket \ + --rank=128 \ + --network_alpha=128 \ + --target_name="to_q,to_k,to_v,feed_forward" \ + --uniform_sampling \ + --resume_from_checkpoint="latest" \ No newline at end of file diff --git a/scripts/z_image/train_omni.py b/scripts/z_image/train_omni.py new file mode 100644 index 00000000..9cea7b32 --- /dev/null +++ b/scripts/z_image/train_omni.py @@ -0,0 +1,1750 @@ +"""Modified from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py +""" +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import gc
import inspect
import logging
import math
import os
import pickle
import random
import shutil
import sys
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, Tuple,
                    Union)

import accelerate
import diffusers
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision.transforms.functional as TF
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers import DDIMScheduler, FlowMatchEulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (EMAModel,
                                      compute_density_for_timestep_sampling,
                                      compute_loss_weighting_for_sd3)
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.torch_utils import is_compiled_module
from einops import rearrange
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from torch.utils.data import RandomSampler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from transformers.utils import ContextManagers

import datasets

# Make the repository importable regardless of the launch directory: prepend
# this file's directory and its two parent directories to sys.path.
current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

# NOTE(review): `RandomSampler` below shadows the earlier
# `torch.utils.data.RandomSampler` import, and `AutoTokenizer` from
# `videox_fun.models` shadows the `transformers` one. The later (project)
# bindings win — presumably intentional, but worth confirming.
from videox_fun.data.bucket_sampler import (ASPECT_RATIO_512,
                                            ASPECT_RATIO_RANDOM_CROP_512,
                                            ASPECT_RATIO_RANDOM_CROP_PROB,
                                            AspectRatioBatchImageVideoSampler,
                                            RandomSampler, get_closest_ratio)
from videox_fun.data.dataset_image import ImageEditDataset
from videox_fun.data.dataset_image_video import (ImageVideoControlDataset,
                                                 ImageVideoDataset,
                                                 ImageVideoSampler,
                                                 get_random_mask)
from videox_fun.dist import set_multi_gpus_devices, shard_model
from videox_fun.models import (AutoencoderKL, AutoProcessor, AutoTokenizer,
                               Qwen2Tokenizer, Qwen3ForCausalLM,
                               QwenImageTransformer2DModel, Siglip2VisionModel,
                               ZImageOmniTransformer2DModel)
from videox_fun.models.flux2_image_processor import Flux2ImageProcessor
from videox_fun.pipeline import ZImageOmniPipeline
from videox_fun.utils.discrete_sampler import DiscreteSampling
from videox_fun.utils.utils import get_image_to_video_latent, save_videos_grid

if is_wandb_available():
    import wandb


def filter_kwargs(cls, kwargs):
    """Return the subset of ``kwargs`` accepted by ``cls.__init__``.

    Keys that do not match a named constructor parameter (``self``/``cls``
    excluded) are silently dropped.
    """
    sig = inspect.signature(cls.__init__)
    valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
    return {k: v for k, v in kwargs.items() if k in valid_params}


def linear_decay(initial_value, final_value, total_steps, current_step):
    """Linearly interpolate from ``initial_value`` to ``final_value``.

    ``current_step`` is clamped below at 0; at or beyond ``total_steps``
    the final value is returned directly.
    """
    if current_step >= total_steps:
        return final_value
    current_step = max(0, current_step)
    step_size = (final_value - initial_value) / total_steps
    return initial_value + step_size * current_step


def generate_timestep_with_lognorm(low, high, shape, device="cpu", generator=None):
    """Sample integer timesteps in ``[low, high)`` with a logit-normal density.

    A standard-normal draw is squashed through a sigmoid, rescaled to
    ``[low, high)``, truncated to int32, and clipped to ``[low, high - 1]``.
    """
    u = torch.normal(mean=0.0, std=1.0, size=shape, device=device, generator=generator)
    t = 1 / (1 + torch.exp(-u)) * (high - low) + low
    return torch.clip(t.to(torch.int32), low, high - 1)


def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Empirical shift ``mu`` as a function of sequence length and step count.

    The (a1, b1) and (a2, b2) pairs are empirical linear fits of mu vs.
    sequence length — apparently calibrated at 10 and 200 sampling steps
    respectively (hence the 190.0 divisor and the 200-step anchor below);
    TODO confirm the calibration points. For long sequences
    (> 4300 tokens) only the (a2, b2) fit is used; otherwise mu is
    linearly interpolated in ``num_steps`` between the two fits.
    """
    a1, b1 = 8.73809524e-05, 1.89833333
    a2, b2 = 0.00016927, 0.45666666

    if image_seq_len > 4300:
        return float(a2 * image_seq_len + b2)

    m_200 = a2 * image_seq_len + b2
    m_10 = a1 * image_seq_len + b1

    # Line through (10, m_10) and (200, m_200), evaluated at num_steps.
    a = (m_200 - m_10) / 190.0
    b = m_200 - 200.0 * a
    return float(a * num_steps + b)
+ +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + +def encode_prompt( + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + text_encoder = None, + tokenizer = None, + max_sequence_length: int = 512, + num_condition_images = 0, +) -> List[torch.FloatTensor]: + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + + if num_condition_images == 0: + prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"] + elif num_condition_images > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1) + prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|im_end|>"] + prompt[i] = prompt_list + + flattened_prompt = [] + prompt_list_lengths = [] + + for i in range(len(prompt)): + prompt_list_lengths.append(len(prompt[i])) + flattened_prompt.extend(prompt[i]) + + text_inputs = tokenizer( + flattened_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + start_idx = 0 + for i in range(len(prompt_list_lengths)): + batch_embeddings = [] + end_idx = start_idx + prompt_list_lengths[i] + for j in range(start_idx, end_idx): + batch_embeddings.append(prompt_embeds[j][prompt_masks[j]]) + embeddings_list.append(batch_embeddings) + start_idx = end_idx + + 
return embeddings_list + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.18.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + +def log_validation(vae, text_encoder, tokenizer, transformer3d, args, accelerator, weight_dtype, global_step): + try: + logger.info("Running validation... ") + + transformer3d_val = ZImageOmniTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer", torch_dtype=weight_dtype, + low_cpu_mem_usage=True, + ).to(weight_dtype) + transformer3d_val.load_state_dict(accelerator.unwrap_model(transformer3d).state_dict()) + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler" + ) + transformer3d = transformer3d.to("cpu") + pipeline = ZImageOmniPipeline( + vae=accelerator.unwrap_model(vae).to(weight_dtype), + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + transformer=transformer3d_val, + scheduler=scheduler, + siglip=siglip, + siglip_processor=siglip_processor, + ) + pipeline = pipeline.to(accelerator.device) + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + for i in range(len(args.validation_prompts)): + with torch.no_grad(): + image = [get_image(args.validation_image_paths[i])] + sample = pipeline( + prompt = args.validation_prompts[i], + negative_prompt = "bad detailed", + height = args.image_sample_size, + width = args.image_sample_size, + generator = generator, + image = image + ).images + os.makedirs(os.path.join(args.output_dir, "sample"), exist_ok=True) + image = sample[0].save(os.path.join(args.output_dir, f"sample/sample-{global_step}-image-{i}.gif")) + + del pipeline + del transformer3d_val + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + transformer3d = transformer3d.to(accelerator.device) + except Exception as e: + 
gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + print(f"Eval error with info {e}") + transformer3d = transformer3d.to(accelerator.device) + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. " + ), + ) + parser.add_argument( + "--train_data_meta", + type=str, + default=None, + help=( + "A csv containing the training data. " + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." 
+ ), + ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--validation_image_paths", + type=str, + default=None, + nargs="+", + help=("A set of images evaluated every `--validation_epochs` and logged to `--report_to`."), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--use_came", + action="store_true", + help="whether to use came", + ) + parser.add_argument( + "--multi_stream", + action="store_true", + help="whether to use cuda multi-stream", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--vae_mini_batch", type=int, default=32, help="mini batch size for vae." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." 
+ ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_model_info", action="store_true", help="Whether or not to report more info about model (such as norm, grad)." + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. 
Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
+ ), + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=2000, + help="Run validation every X steps.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + parser.add_argument( + "--snr_loss", action="store_true", help="Whether or not to use snr_loss." + ) + parser.add_argument( + "--uniform_sampling", action="store_true", help="Whether or not to use uniform_sampling." + ) + parser.add_argument( + "--enable_text_encoder_in_dataloader", action="store_true", help="Whether or not to use text encoder in dataloader." + ) + parser.add_argument( + "--enable_bucket", action="store_true", help="Whether enable bucket sample in datasets." + ) + parser.add_argument( + "--random_ratio_crop", action="store_true", help="Whether enable random ratio crop sample in datasets." + ) + parser.add_argument( + "--random_hw_adapt", action="store_true", help="Whether enable random adapt height and width in datasets." + ) + parser.add_argument( + "--train_sampling_steps", + type=int, + default=1000, + help="Run train_sampling_steps.", + ) + parser.add_argument( + "--image_sample_size", + type=int, + default=512, + help="Sample size of the image.", + ) + parser.add_argument( + "--fix_sample_size", + nargs=2, type=int, default=None, + help="Fix Sample size [height, width] when using bucket and collate_fn." 
+ ) + parser.add_argument( + "--transformer_path", + type=str, + default=None, + help=("If you want to load the weight from other transformers, input its path."), + ) + parser.add_argument( + "--vae_path", + type=str, + default=None, + help=("If you want to load the weight from other vaes, input its path."), + ) + + parser.add_argument( + '--trainable_modules', + nargs='+', + help='Enter a list of trainable modules' + ) + parser.add_argument( + '--trainable_modules_low_learning_rate', + nargs='+', + default=[], + help='Enter a list of trainable modules with lower learning rate' + ) + parser.add_argument( + '--tokenizer_max_length', + type=int, + default=512, + help='Max length of tokenizer' + ) + parser.add_argument( + "--use_deepspeed", action="store_true", help="Whether or not to use deepspeed." + ) + parser.add_argument( + "--use_fsdp", action="store_true", help="Whether or not to use fsdp." + ) + parser.add_argument( + "--low_vram", action="store_true", help="Whether enable low_vram mode." + ) + parser.add_argument( + "--prompt_template_encode", + type=str, + default="<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n", + help=( + 'The prompt template for text encoder.' + ), + ) + parser.add_argument( + "--prompt_template_encode_start_idx", + type=int, + default=34, + help=( + 'The start idx for prompt template.' + ), + ) + parser.add_argument( + "--abnormal_norm_clip_start", + type=int, + default=1000, + help=( + 'When do we start doing additional processing on abnormal gradients. ' + ), + ) + parser.add_argument( + "--initial_grad_norm_ratio", + type=int, + default=5, + help=( + 'The initial gradient is relative to the multiple of the max_grad_norm. 
' + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def main(): + args = parse_args() + if args.train_batch_size >= 2: + raise ValueError("This code does not support args.train_batch_size >= 2 now.") + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." 
+ ), + ) + logging_dir = os.path.join(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + deepspeed_plugin = accelerator.state.deepspeed_plugin if hasattr(accelerator.state, "deepspeed_plugin") else None + fsdp_plugin = accelerator.state.fsdp_plugin if hasattr(accelerator.state, "fsdp_plugin") else None + if deepspeed_plugin is not None: + zero_stage = int(deepspeed_plugin.zero_stage) + fsdp_stage = 0 + print(f"Using DeepSpeed Zero stage: {zero_stage}") + + args.use_deepspeed = True + if zero_stage == 3: + print(f"Auto set save_state to True because zero_stage == 3") + args.save_state = True + elif fsdp_plugin is not None: + from torch.distributed.fsdp import ShardingStrategy + zero_stage = 0 + if fsdp_plugin.sharding_strategy is ShardingStrategy.FULL_SHARD: + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is None: # The fsdp_plugin.sharding_strategy is None in FSDP 2. + fsdp_stage = 3 + elif fsdp_plugin.sharding_strategy is ShardingStrategy.SHARD_GRAD_OP: + fsdp_stage = 2 + else: + fsdp_stage = 0 + print(f"Using FSDP stage: {fsdp_stage}") + + args.use_fsdp = True + if fsdp_stage == 3: + print(f"Auto set save_state to True because fsdp_stage == 3") + args.save_state = True + else: + zero_stage = 0 + fsdp_stage = 0 + print("DeepSpeed is not enabled.") + + if accelerator.is_main_process: + writer = SummaryWriter(log_dir=logging_dir) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + rng = np.random.default_rng(np.random.PCG64(args.seed + accelerator.process_index)) + torch_rng = torch.Generator(accelerator.device).manual_seed(args.seed + accelerator.process_index) + else: + rng = None + torch_rng = None + index_rng = np.random.default_rng(np.random.PCG64(43)) + print(f"Init rng with seed {args.seed + accelerator.process_index}. Process_index is {accelerator.process_index}") + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora transformer3d) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + args.mixed_precision = accelerator.mixed_precision + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + args.mixed_precision = accelerator.mixed_precision + + # Load scheduler, tokenizer and models. 
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler" + ) + + # Get Tokenizer + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer" + ) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So Qwen3ForCausalLM and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + # Get Text encoder + text_encoder = Qwen3ForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype + ) + text_encoder = text_encoder.eval() + # Get Vae + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae" + ).to(weight_dtype) + vae.eval() + + siglip_path = "models/siglip2-so400m-patch16-naflex" + siglip = Siglip2VisionModel.from_pretrained(siglip_path, torch_dtype=torch.bfloat16).to(accelerator.device) + siglip_processor = AutoProcessor.from_pretrained(siglip_path) + + # Get Transformer + transformer3d = ZImageOmniTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=weight_dtype, + ).to(weight_dtype) + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) + image_processor = Flux2ImageProcessor(vae_scale_factor=vae_scale_factor * 2) + + # Freeze vae and text_encoder and set transformer3d to trainable + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + transformer3d.requires_grad_(False) + + if args.transformer_path is not None: + print(f"From checkpoint: {args.transformer_path}") + if args.transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.transformer_path) + else: + state_dict = torch.load(args.transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer3d.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + if args.vae_path is not None: + print(f"From checkpoint: {args.vae_path}") + if args.vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(args.vae_path) + else: + state_dict = torch.load(args.vae_path, 
map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + # A good trainable modules is showed below now. + # For 3D Patch: trainable_modules = ['ff.net', 'pos_embed', 'attn2', 'proj_out', 'timepositionalencoding', 'h_position', 'w_position'] + # For 2D Patch: trainable_modules = ['ff.net', 'attn2', 'timepositionalencoding', 'h_position', 'w_position'] + transformer3d.train() + if accelerator.is_main_process: + accelerator.print( + f"Trainable modules '{args.trainable_modules}'." + ) + for name, param in transformer3d.named_parameters(): + for trainable_module_name in args.trainable_modules + args.trainable_modules_low_learning_rate: + if trainable_module_name in name: + param.requires_grad = True + break + + # Create EMA for the transformer3d. + if args.use_ema: + if zero_stage == 3: + raise NotImplementedError("FSDP does not support EMA.") + + ema_transformer3d = ZImageOmniTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + torch_dtype=weight_dtype, + ).to(weight_dtype) + + ema_transformer3d = EMAModel(ema_transformer3d.parameters(), model_cls=ZImageOmniTransformer2DModel, model_config=ema_transformer3d.config) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + if fsdp_stage != 0: + def save_model_hook(models, weights, output_dir): + accelerate_state_dict = accelerator.get_state_dict(models[-1], unwrap=True) + if accelerator.is_main_process: + from safetensors.torch import save_file + + safetensor_save_path = os.path.join(output_dir, f"diffusion_pytorch_model.safetensors") + accelerate_state_dict = {k: v.to(dtype=weight_dtype) 
for k, v in accelerate_state_dict.items()} + save_file(accelerate_state_dict, safetensor_save_path, metadata={"format": "pt"}) + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. Get loaded_number = {loaded_number}.") + + elif zero_stage == 3: + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + accelerate_state_dict = accelerator.get_state_dict(models[-1], unwrap=True) + if accelerator.is_main_process: + from safetensors.torch import save_file + safetensor_save_path = os.path.join(output_dir, f"diffusion_pytorch_model.safetensors") + save_file(accelerate_state_dict, safetensor_save_path, metadata={"format": "pt"}) + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. 
Get loaded_number = {loaded_number}.") + else: + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_transformer3d.save_pretrained(os.path.join(output_dir, "transformer_ema")) + + models[0].save_pretrained(os.path.join(output_dir, "transformer")) + if not args.use_deepspeed: + weights.pop() + + with open(os.path.join(output_dir, "sampler_pos_start.pkl"), 'wb') as file: + pickle.dump([batch_sampler.sampler._pos_start, first_epoch], file) + + def load_model_hook(models, input_dir): + if args.use_ema: + ema_path = os.path.join(input_dir, "transformer_ema") + _, ema_kwargs = ZImageOmniTransformer2DModel.load_config(ema_path, return_unused_kwargs=True) + load_model = ZImageOmniTransformer2DModel.from_pretrained( + input_dir, subfolder="transformer_ema", + ) + load_model = EMAModel(load_model.parameters(), model_cls=ZImageOmniTransformer2DModel, model_config=load_model.config) + load_model.load_state_dict(ema_kwargs) + + ema_transformer3d.load_state_dict(load_model.state_dict()) + ema_transformer3d.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = ZImageOmniTransformer2DModel.from_pretrained( + input_dir, subfolder="transformer" + ) + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + pkl_path = os.path.join(input_dir, "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + loaded_number, _ = pickle.load(file) + batch_sampler.sampler._pos_start = max(loaded_number - args.dataloader_num_workers * accelerator.num_processes * 2, 0) + print(f"Load pkl from {pkl_path}. 
Get loaded_number = {loaded_number}.") + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + transformer3d.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + elif args.use_came: + try: + from came_pytorch import CAME + except: + raise ImportError( + "Please install came_pytorch to use CAME. 
You can do so by running `pip install came_pytorch`" + ) + + optimizer_cls = CAME + else: + optimizer_cls = torch.optim.AdamW + + trainable_params = list(filter(lambda p: p.requires_grad, transformer3d.parameters())) + trainable_params_optim = [ + {'params': [], 'lr': args.learning_rate}, + {'params': [], 'lr': args.learning_rate / 2}, + ] + in_already = [] + for name, param in transformer3d.named_parameters(): + high_lr_flag = False + if name in in_already: + continue + for trainable_module_name in args.trainable_modules: + if trainable_module_name in name: + in_already.append(name) + high_lr_flag = True + trainable_params_optim[0]['params'].append(param) + if accelerator.is_main_process: + print(f"Set {name} to lr : {args.learning_rate}") + break + if high_lr_flag: + continue + for trainable_module_name in args.trainable_modules_low_learning_rate: + if trainable_module_name in name: + in_already.append(name) + trainable_params_optim[1]['params'].append(param) + if accelerator.is_main_process: + print(f"Set {name} to lr : {args.learning_rate / 2}") + break + + if args.use_came: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + # weight_decay=args.adam_weight_decay, + betas=(0.9, 0.999, 0.9999), + eps=(1e-30, 1e-16) + ) + else: + optimizer = optimizer_cls( + trainable_params_optim, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the training dataset + if args.fix_sample_size is not None and args.enable_bucket: + args.image_sample_size = max(max(args.fix_sample_size), args.image_sample_size) + args.random_hw_adapt = False + + # Get the dataset + train_dataset = ImageEditDataset( + args.train_data_meta, args.train_data_dir, + image_sample_size=args.image_sample_size, + enable_bucket=args.enable_bucket, + ) + + def worker_init_fn(_seed): + _seed = _seed * 256 + def _worker_init_fn(worker_id): + print(f"worker_init_fn with {_seed + 
worker_id}") + np.random.seed(_seed + worker_id) + random.seed(_seed + worker_id) + return _worker_init_fn + + if args.enable_bucket: + aspect_ratio_sample_size = {key : [x / 512 * args.image_sample_size for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + batch_sampler_generator = torch.Generator().manual_seed(args.seed) + batch_sampler = AspectRatioBatchImageVideoSampler( + sampler=RandomSampler(train_dataset, generator=batch_sampler_generator), dataset=train_dataset.dataset, + batch_size=args.train_batch_size, train_folder = args.train_data_dir, drop_last=True, + aspect_ratios=aspect_ratio_sample_size, + ) + + + def collate_fn(examples): + def get_random_downsample_ratio(sample_size, image_ratio=[], + all_choices=False, rng=None): + def _create_special_list(length): + if length == 1: + return [1.0] + first_element = 0.90 + remaining_sum = 1.0 - first_element + other_elements_value = remaining_sum / (length - 1) + return [first_element] + [other_elements_value] * (length - 1) + + MIN_TARGET = 1024 + + if sample_size < MIN_TARGET: + number_list = [1.0] + else: + max_allowed_ratio = sample_size / MIN_TARGET + base_ratios = [ + 1.0, + 1.1, 1.2, 1.25, 1.33, 1.5, + 1.75, 2.0, 2.25, 2.5, 2.75, + 3.0, 3.5, 4.0, 5.0, 6.0, 8.0 + ] + candidate_ratios = set(base_ratios + list(image_ratio)) + number_list = sorted([r for r in candidate_ratios if 1.0 <= r <= max_allowed_ratio]) + + if not number_list: + number_list = [1.0] + + if all_choices: + return number_list + + probs = np.array(_create_special_list(len(number_list))) + if rng is None: + return np.random.choice(number_list, p=probs) + else: + return rng.choice(number_list, p=probs) + + # Create new output + new_examples = {} + new_examples["pixel_values"] = [] + new_examples["source_pixel_values"] = [] + new_examples["text"] = [] + + # Get downsample ratio in image + pixel_value = examples[0]["pixel_values"] + source_pixel_values = examples[0]["source_pixel_values"] + data_type = examples[0]["data_type"] 
+ f, h, w, c = np.shape(pixel_value) + + random_downsample_ratio = 1 if not args.random_hw_adapt else get_random_downsample_ratio(args.image_sample_size) + + aspect_ratio_sample_size = {key : [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + aspect_ratio_random_crop_sample_size = {key : [x / 512 * args.image_sample_size / random_downsample_ratio for x in ASPECT_RATIO_RANDOM_CROP_512[key]] for key in ASPECT_RATIO_RANDOM_CROP_512.keys()} + + if args.fix_sample_size is not None: + fix_sample_size = [int(x / 16) * 16 for x in args.fix_sample_size] + elif args.random_ratio_crop: + if rng is None: + random_sample_size = aspect_ratio_random_crop_sample_size[ + np.random.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + else: + random_sample_size = aspect_ratio_random_crop_sample_size[ + rng.choice(list(aspect_ratio_random_crop_sample_size.keys()), p = ASPECT_RATIO_RANDOM_CROP_PROB) + ] + random_sample_size = [int(x / 16) * 16 for x in random_sample_size] + else: + closest_size, closest_ratio = get_closest_ratio(h, w, ratios=aspect_ratio_sample_size) + closest_size = [int(x / 16) * 16 for x in closest_size] + + for example in examples: + # To 0~1 + pixel_values = torch.from_numpy(example["pixel_values"]).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. 
+ + if args.fix_sample_size is not None: + # Get adapt hw for resize + fix_sample_size = list(map(lambda x: int(x), fix_sample_size)) + transform = transforms.Compose([ + transforms.Resize(fix_sample_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(fix_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + elif args.random_ratio_crop: + # Get adapt hw for resize + b, c, h, w = pixel_values.size() + th, tw = random_sample_size + if th / tw > h / w: + nh = int(th) + nw = int(w / h * nh) + else: + nw = int(tw) + nh = int(h / w * nw) + + transform = transforms.Compose([ + transforms.Resize([nh, nw]), + transforms.CenterCrop([int(x) for x in random_sample_size]), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + else: + # Get adapt hw for resize + closest_size = list(map(lambda x: int(x), closest_size)) + if closest_size[0] / h > closest_size[1] / w: + resize_size = closest_size[0], int(w * closest_size[0] / h) + else: + resize_size = int(h * closest_size[1] / w), closest_size[1] + + transform = transforms.Compose([ + transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BILINEAR), # Image.BICUBIC + transforms.CenterCrop(closest_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + + source_pixel_values = [] + for _source_pixel_value in example["source_pixel_values"]: + source_pixel_values.append(np.array(_source_pixel_value)) + + new_examples["pixel_values"].append(transform(pixel_values)) + new_examples["source_pixel_values"].append(source_pixel_values) + new_examples["text"].append(example["text"]) + + # Limit the number of frames to the same + new_examples["pixel_values"] = torch.stack([example for example in new_examples["pixel_values"]]) + + # Encode prompts when enable_text_encoder_in_dataloader=True + if args.enable_text_encoder_in_dataloader: + prompt_embeds = 
encode_prompt( + batch['text'], device="cpu", + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + + new_examples['prompt_embeds'] = prompt_embeds + + return new_examples + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + persistent_workers=True if args.dataloader_num_workers != 0 else False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + else: + # DataLoaders creation: + batch_sampler_generator = torch.Generator().manual_seed(args.seed) + batch_sampler = ImageVideoSampler(RandomSampler(train_dataset, generator=batch_sampler_generator), train_dataset, args.train_batch_size) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_sampler=batch_sampler, + persistent_workers=True if args.dataloader_num_workers != 0 else False, + num_workers=args.dataloader_num_workers, + worker_init_fn=worker_init_fn(args.seed + accelerator.process_index) + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. 
    # Hand the trainable transformer, optimizer, dataloader and LR schedule to
    # Accelerate (wraps for DDP/DeepSpeed/FSDP as configured).
    transformer3d, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        transformer3d, optimizer, train_dataloader, lr_scheduler
    )

    # Under FSDP or ZeRO, shard the frozen text encoder manually — it is not
    # part of accelerator.prepare above.
    if fsdp_stage != 0 or zero_stage != 0:
        from functools import partial

        from videox_fun.dist import set_multi_gpus_devices, shard_model

        shard_fn = partial(shard_model, device_id=accelerator.device, param_dtype=weight_dtype, module_to_wrapper=text_encoder.model.layers)
        text_encoder = shard_fn(text_encoder)

    if args.use_ema:
        ema_transformer3d.to(accelerator.device)

    # Move text_encode and vae to gpu and cast to weight_dtype
    # (kept on CPU when low_vram is set; they are swapped in/out per step).
    vae.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)
    if not args.enable_text_encoder_in_dataloader:
        text_encoder.to(accelerator.device if not args.low_vram else "cpu", dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        # Trackers can't serialize list-valued args; drop them from the logged config.
        keys_to_pop = [k for k, v in tracker_config.items() if isinstance(v, list)]
        for k in keys_to_pop:
            tracker_config.pop(k)
            print(f"Removed tracker_config['{k}']")
        accelerator.init_trackers(args.tracker_project_name, tracker_config)

    # Function for unwrapping if model was compiled with `torch.compile`.
    def unwrap_model(model):
        model = accelerator.unwrap_model(model)
        model = model._orig_mod if is_compiled_module(model) else model
        return model

    # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + + pkl_path = os.path.join(os.path.join(args.output_dir, path), "sampler_pos_start.pkl") + if os.path.exists(pkl_path): + with open(pkl_path, 'rb') as file: + _, first_epoch = pickle.load(file) + else: + first_epoch = global_step // num_update_steps_per_epoch + print(f"Load pkl from {pkl_path}. 
Get first_epoch = {first_epoch}.") + + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + if args.multi_stream: + # create extra cuda streams to speedup inpaint vae computation + vae_stream_1 = torch.cuda.Stream() + vae_stream_2 = torch.cuda.Stream() + else: + vae_stream_1 = None + vae_stream_2 = None + + # Calculate the index we need】 + idx_sampling = DiscreteSampling(args.train_sampling_steps, uniform_sampling=args.uniform_sampling) + + for epoch in range(first_epoch, args.num_train_epochs): + train_loss = 0.0 + batch_sampler.sampler.generator = torch.Generator().manual_seed(args.seed + epoch) + for step, batch in enumerate(train_dataloader): + if epoch == first_epoch and step == 0: + pixel_values, texts = batch['pixel_values'].cpu(), batch['text'] + source_pixel_values = batch['source_pixel_values'] + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + + os.makedirs(os.path.join(args.output_dir, "sanity_check"), exist_ok=True) + for idx, (pixel_value, source_pixel_value, text) in enumerate(zip(pixel_values, source_pixel_values, texts)): + pixel_value = pixel_value[None, ...] 
+ gif_name = '-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_step}-{idx}' + save_videos_grid(pixel_value, f"{args.output_dir}/sanity_check/{gif_name[:10]}.gif", rescale=True) + for local_index, _source_pixel_value in enumerate(source_pixel_value): + _source_pixel_value = Image.fromarray(np.uint8(_source_pixel_value)) + _source_pixel_value.save(f"{args.output_dir}/sanity_check/source_{local_index}_{gif_name[:10]}.jpg") + + with accelerator.accumulate(transformer3d): + # Convert images to latent space + pixel_values = batch["pixel_values"].to(weight_dtype) + source_pixel_values = batch["source_pixel_values"] + + if args.low_vram: + torch.cuda.empty_cache() + vae.to(accelerator.device) + if not args.enable_text_encoder_in_dataloader: + text_encoder.to("cpu") + + with torch.no_grad(): + # This way is quicker when batch grows up + def _batch_encode_vae(pixel_values): + pixel_values = pixel_values.squeeze(1) + bs = args.vae_mini_batch + new_pixel_values = [] + for i in range(0, pixel_values.shape[0], bs): + pixel_values_bs = pixel_values[i : i + bs] + pixel_values_bs = vae.encode(pixel_values_bs)[0] + pixel_values_bs = pixel_values_bs.sample() + new_pixel_values.append(pixel_values_bs) + return torch.cat(new_pixel_values, dim = 0) + if vae_stream_1 is not None: + vae_stream_1.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(vae_stream_1): + latents = _batch_encode_vae(pixel_values).unsqueeze(2) + else: + latents = _batch_encode_vae(pixel_values).unsqueeze(2) + + def prepare_siglip_embeds( + images, + batch_size, + device, + dtype, + ): + siglip_embeds = [] + for image in images: + siglip_inputs = siglip_processor(images=[image], return_tensors="pt").to(device) + shape = siglip_inputs.spatial_shapes[0] + hidden_state = siglip(**siglip_inputs).last_hidden_state + B, N, C = hidden_state.shape + hidden_state = hidden_state[:, :shape[0] * shape[1]] + hidden_state = hidden_state.view(shape[0], shape[1], C) + 
siglip_embeds.append(hidden_state.to(dtype)) + + # siglip_embeds = [siglip_embeds] * batch_size + siglip_embeds = [siglip_embeds.copy() for _ in range(batch_size)] + + return siglip_embeds + + def prepare_image_latents( + images: List[torch.Tensor], + batch_size, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + image_latent = (vae.encode(image.bfloat16()).latent_dist.mode()[0] - vae.config.shift_factor) * vae.config.scaling_factor + image_latent = image_latent.unsqueeze(1).to(dtype) + image_latents.append(image_latent) # (16, 128, 128) + + image_latents = [image_latents.copy() for _ in range(batch_size)] + + return image_latents + + condition_latents = [] + condition_siglip_embeds = [] + omni_images = [ + [Image.fromarray(np.uint8(source_image)) for source_image in source_pixel_value] \ + for source_pixel_value in source_pixel_values + ] + bsz, channel, f, height, width = latents.size() + + if len(omni_images[0]) == 0: + condition_latents = [[] for i in range(bsz)] + condition_siglip_embeds = [None for i in range(bsz)] + omni_images = [] + else: + image_height, image_width = pixel_values.size()[-2], pixel_values.size()[-1] + for i in range(latents.size()[0]): + condition_images = [] + resized_images = [] + for img in omni_images[i]: + img = image_processor._resize_to_target_area(img, image_width * image_height) + image_width, image_height = img.size + resized_images.append(img) + + multiple_of = vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + + local_condition_siglip_embeds = prepare_siglip_embeds( + images=resized_images, + batch_size=1, + device=accelerator.device, + dtype=torch.float32, + ) + local_condition_siglip_embeds = [[se.to(transformer3d.dtype) for se in sels] for 
sels in local_condition_siglip_embeds] + condition_siglip_embeds += local_condition_siglip_embeds + + local_condition_latents = prepare_image_latents( + images=condition_images, + batch_size=1, + device=accelerator.device, + dtype=torch.float32, + ) + local_condition_latents = [[lat.to(transformer3d.dtype) for lat in lats] for lats in local_condition_latents] + condition_latents += local_condition_latents + + # wait for latents = vae.encode(pixel_values) to complete + if vae_stream_1 is not None: + torch.cuda.current_stream().wait_stream(vae_stream_1) + + if args.low_vram: + vae.to('cpu') + torch.cuda.empty_cache() + if not args.enable_text_encoder_in_dataloader: + text_encoder.to(accelerator.device) + + if args.enable_text_encoder_in_dataloader: + prompt_embeds = batch['prompt_embeds'].to(dtype=latents.dtype, device=accelerator.device) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + batch['text'], device=accelerator.device, + text_encoder=text_encoder, + tokenizer=tokenizer, + num_condition_images=len(omni_images), + ) + + if args.low_vram and not args.enable_text_encoder_in_dataloader: + text_encoder.to('cpu') + torch.cuda.empty_cache() + + bsz, channel, f, height, width = latents.size() + latents = ((latents - vae.config.shift_factor) * vae.config.scaling_factor).to(dtype=weight_dtype) + noise = torch.randn(latents.size(), device=latents.device, generator=torch_rng, dtype=weight_dtype) + + if not args.uniform_sampling: + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler.config.num_train_timesteps).long() + else: + # Sample a random timestep for each image + # timesteps = generate_timestep_with_lognorm(0, args.train_sampling_steps, (bsz,), device=latents.device, generator=torch_rng) + # timesteps = torch.randint(0, args.train_sampling_steps, (bsz,), 
device=latents.device, generator=torch_rng) + indices = idx_sampling(bsz, generator=torch_rng, device=latents.device) + indices = indices.long().cpu() + + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + noise_scheduler.config.get("base_image_seq_len", 256), + noise_scheduler.config.get("max_image_seq_len", 4096), + noise_scheduler.config.get("base_shift", 0.5), + noise_scheduler.config.get("max_shift", 1.15), + ) + noise_scheduler.sigma_min = 0.0 + noise_scheduler.set_timesteps(args.train_sampling_steps, device=latents.device, mu=mu) + timesteps = noise_scheduler.timesteps[indices].to(device=latents.device) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + # Add noise according to flow matching. 
+ # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype) + noisy_latents = (1.0 - sigmas) * latents + sigmas * noise + + # Add noise + target = noise - latents + timesteps = ((1000 - timesteps) / 1000).to(device=accelerator.device, dtype=weight_dtype) + + # Create noise mask: 0 for condition images (clean), 1 for target image (noisy) + image_noise_mask = [ + [0] * len(condition_latents[i]) + [1] for i in range(bsz) + ] + x_combined = [ + condition_latents[i] + [noisy_latents[i]] for i in range(bsz) + ] + # Predict the noise residual + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=accelerator.device): + noise_pred = transformer3d( + x = x_combined, + t = timesteps, + cap_feats = prompt_embeds, + siglip_feats = condition_siglip_embeds, + image_noise_mask = image_noise_mask, + )[0] + + def custom_mse_loss(noise_pred, target, weighting=None, threshold=50): + noise_pred = noise_pred.float() + target = target.float() + diff = noise_pred - target + mse_loss = F.mse_loss(noise_pred, target, reduction='none') + mask = (diff.abs() <= threshold).float() + masked_loss = mse_loss * mask + if weighting is not None: + masked_loss = masked_loss * weighting + final_loss = masked_loss.mean() + return final_loss + + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + loss = custom_mse_loss(-noise_pred.float(), target.float(), weighting.float()) + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). 
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + if not args.use_deepspeed and not args.use_fsdp: + trainable_params_grads = [p.grad for p in trainable_params if p.grad is not None] + trainable_params_total_norm = torch.norm(torch.stack([torch.norm(g.detach(), 2) for g in trainable_params_grads]), 2) + max_grad_norm = linear_decay(args.max_grad_norm * args.initial_grad_norm_ratio, args.max_grad_norm, args.abnormal_norm_clip_start, global_step) + if trainable_params_total_norm / max_grad_norm > 5 and global_step > args.abnormal_norm_clip_start: + actual_max_grad_norm = max_grad_norm / min((trainable_params_total_norm / max_grad_norm), 10) + else: + actual_max_grad_norm = max_grad_norm + else: + actual_max_grad_norm = args.max_grad_norm + + if not args.use_deepspeed and not args.use_fsdp and args.report_model_info and accelerator.is_main_process: + if trainable_params_total_norm > 1 and global_step > args.abnormal_norm_clip_start: + for name, param in transformer3d.named_parameters(): + if param.requires_grad: + writer.add_scalar(f'gradients/before_clip_norm/{name}', param.grad.norm(), global_step=global_step) + + norm_sum = accelerator.clip_grad_norm_(trainable_params, actual_max_grad_norm) + if not args.use_deepspeed and not args.use_fsdp and args.report_model_info and accelerator.is_main_process: + writer.add_scalar(f'gradients/norm_sum', norm_sum, global_step=global_step) + writer.add_scalar(f'gradients/actual_max_grad_norm', actual_max_grad_norm, global_step=global_step) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + + if args.use_ema: + ema_transformer3d.step(transformer3d.parameters()) + progress_bar.update(1) + global_step += 1 + 
accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % args.checkpointing_steps == 0: + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if accelerator.is_main_process: + if args.validation_prompts is not None and global_step % args.validation_steps == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_transformer3d.store(transformer3d.parameters()) + ema_transformer3d.copy_to(transformer3d.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + transformer3d, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original transformer3d parameters. 
+ ema_transformer3d.restore(transformer3d.parameters()) + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_transformer3d.store(transformer3d.parameters()) + ema_transformer3d.copy_to(transformer3d.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + transformer3d, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original transformer3d parameters. + ema_transformer3d.restore(transformer3d.parameters()) + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if args.use_deepspeed or args.use_fsdp or accelerator.is_main_process: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/scripts/z_image/train_omni.sh b/scripts/z_image/train_omni.sh new file mode 100644 index 00000000..335c11ac --- /dev/null +++ b/scripts/z_image/train_omni.sh @@ -0,0 +1,33 @@ +export MODEL_NAME="models/Diffusion_Transformer/Z-Image-Omni" +export DATASET_NAME="datasets/internal_datasets/" +export DATASET_META_NAME="datasets/internal_datasets/metadata_edit.json" +# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
# Fix: must be exported — a bare `NCCL_DEBUG=INFO` only sets a shell-local
# variable and never reaches the environment of the launched trainer.
export NCCL_DEBUG=INFO

accelerate launch --mixed_precision="bf16" scripts/z_image/train_omni.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATASET_NAME \
  --train_data_meta=$DATASET_META_NAME \
  --train_batch_size=1 \
  --image_sample_size=1328 \
  --gradient_accumulation_steps=1 \
  --dataloader_num_workers=8 \
  --num_train_epochs=100 \
  --checkpointing_steps=100 \
  --learning_rate=2e-05 \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=100 \
  --seed=42 \
  --output_dir="output_dir_z_image_omni" \
  --gradient_checkpointing \
  --mixed_precision="bf16" \
  --adam_weight_decay=3e-2 \
  --adam_epsilon=1e-10 \
  --vae_mini_batch=1 \
  --max_grad_norm=0.05 \
  --enable_bucket \
  --uniform_sampling \
  --random_hw_adapt \
  --trainable_modules "."
.z_image_transformer2d import ZImageTransformer2DModel from .z_image_transformer2d_control import ZImageControlTransformer2DModel +from .z_image_transformer2d_omni import ZImageOmniTransformer2DModel # The pai_fuser is an internally developed acceleration package, which can be used on PAI. if importlib.util.find_spec("paifuser") is not None: diff --git a/videox_fun/models/z_image_transformer2d.py b/videox_fun/models/z_image_transformer2d.py index 78f22905..aa38c5ae 100644 --- a/videox_fun/models/z_image_transformer2d.py +++ b/videox_fun/models/z_image_transformer2d.py @@ -52,17 +52,9 @@ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256): if mid_size is None: mid_size = out_size self.mlp = nn.Sequential( - nn.Linear( - frequency_embedding_size, - mid_size, - bias=True, - ), + nn.Linear(frequency_embedding_size, mid_size, bias=True), nn.SiLU(), - nn.Linear( - mid_size, - out_size, - bias=True, - ), + nn.Linear(mid_size, out_size, bias=True), ) self.frequency_embedding_size = frequency_embedding_size @@ -83,8 +75,11 @@ def timestep_embedding(t, dim, max_period=10000): def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) weight_dtype = self.mlp[0].weight.dtype + compute_dtype = getattr(self.mlp[0], "compute_dtype", None) if weight_dtype.is_floating_point: t_freq = t_freq.to(weight_dtype) + elif compute_dtype is not None: + t_freq = t_freq.to(compute_dtype) t_emb = self.mlp(t_freq) return t_emb diff --git a/videox_fun/models/z_image_transformer2d_omni.py b/videox_fun/models/z_image_transformer2d_omni.py new file mode 100644 index 00000000..2eb0b1a0 --- /dev/null +++ b/videox_fun/models/z_image_transformer2d_omni.py @@ -0,0 +1,922 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def select_per_token(
    value_noisy: torch.Tensor,
    value_clean: torch.Tensor,
    noise_mask: torch.Tensor,
    seq_len: int,
) -> torch.Tensor:
    """Pick, for every token, between two per-sample vectors.

    Broadcasts ``value_noisy`` and ``value_clean`` (both ``(batch, dim)``)
    across the sequence dimension and selects ``value_noisy`` wherever
    ``noise_mask`` equals 1, ``value_clean`` elsewhere.

    Returns a tensor of shape ``(batch, seq_len, dim)``.
    """
    pick_noisy = noise_mask.unsqueeze(-1) == 1  # (batch, seq_len, 1)
    broadcast_noisy = value_noisy.unsqueeze(1).expand(-1, seq_len, -1)
    broadcast_clean = value_clean.unsqueeze(1).expand(-1, seq_len, -1)
    return torch.where(pick_noisy, broadcast_noisy, broadcast_clean)
def extract_seqlens_from_mask(attn_mask):
    """Recover per-sample valid sequence lengths from an attention mask.

    Supports a 2D mask of shape ``(batch, seq_len)`` or a 4D mask of shape
    ``(batch, 1, 1, seq_len)``. Boolean masks mark valid positions with
    ``True``; floating masks mark invalid positions with ``-inf`` (additive
    mask convention).

    Returns a ``(batch,)`` tensor of lengths, or ``None`` if ``attn_mask``
    is ``None``.

    Fix: the original raised only for 3D input and let 2D masks fall
    through to an undefined ``valid_mask`` (NameError), even though its
    error message promised 2D-or-4D support.
    """
    if attn_mask is None:
        return None

    if len(attn_mask.shape) == 4:
        # (batch, 1, 1, seq_len) -> (batch, seq_len)
        squeezed = attn_mask.squeeze(1).squeeze(1)
    elif len(attn_mask.shape) == 2:
        squeezed = attn_mask
    else:
        raise ValueError(
            "attn_mask should be 2D or 4D tensor, but got {}".format(
                attn_mask.shape))

    if squeezed.dtype == torch.bool:
        valid_mask = squeezed
    else:
        valid_mask = ~torch.isinf(squeezed)

    seqlens = valid_mask.sum(dim=1)
    return seqlens
class ZSingleStreamAttnProcessor:
    """
    Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
    original Z-ImageAttention module.

    ``encoder_hidden_states`` is accepted for interface compatibility but is
    never read — q, k and v are all projected from ``hidden_states``
    (self-attention only).
    """

    # Hooks read by diffusers' attention-dispatch machinery; intentionally unset.
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        # The `attention` helper relies on torch SDPA (PyTorch >= 2.0).
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,  # unused, kept for API parity
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,  # complex RoPE frequencies, (batch, seq, head_dim/2)? — TODO confirm shape
    ) -> torch.Tensor:
        # Project q/k/v from the same sequence (single-stream self-attention).
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # (batch, seq, heads * head_dim) -> (batch, seq, heads, head_dim)
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        # Apply Norms (per-head QK norm, if configured on the Attention module)
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE
        def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
            # Rotation is done in fp32 complex space for numerical stability,
            # then cast back to the input dtype.
            with torch.amp.autocast("cuda", enabled=False):
                x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
                freqs_cis = freqs_cis.unsqueeze(2)  # broadcast over heads
                x_out = torch.view_as_real(x * freqs_cis).flatten(3)
            return x_out.type_as(x_in)  # todo

        if freqs_cis is not None:
            query = apply_rotary_emb(query, freqs_cis)
            key = apply_rotary_emb(key, freqs_cis)

        # Cast to correct dtype
        dtype = query.dtype
        query, key = query.to(dtype), key.to(dtype)

        # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = attention_mask[:, None, None, :]

        if attention_mask is not None:
            # Variable-length path: instead of a masked dense attention, run
            # each sample separately on its valid prefix (padding assumed to
            # be right-aligned — TODO confirm) and leave padded rows zero.
            q_lens = k_lens = extract_seqlens_from_mask(attention_mask)

            hidden_states = torch.zeros_like(query)
            for i in range(len(q_lens)):
                hidden_states[i][:q_lens[i]] = attention(
                    query[i][:q_lens[i]].unsqueeze(0),
                    key[i][:q_lens[i]].unsqueeze(0),
                    value[i][:q_lens[i]].unsqueeze(0),
                    attn_mask=None,
                )
        else:
            hidden_states = attention(
                query, key, value, attn_mask=attention_mask,
            )

        # Reshape back: (batch, seq, heads, head_dim) -> (batch, seq, heads*head_dim)
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(dtype)

        output = attn.to_out[0](hidden_states)
        if len(attn.to_out) > 1:  # dropout
            output = attn.to_out[1](output)

        return output
@maybe_allow_in_graph
class ZImageTransformerBlock(nn.Module):
    """Single-stream transformer block with optional adaLN modulation.

    When ``modulation=True`` the block applies scale/gate modulation derived
    from a timestep embedding; in Omni mode the modulation can differ per
    token (noisy target tokens vs. clean condition tokens, selected by
    ``noise_mask``).
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,   # accepted but not forwarded to Attention — TODO confirm GQA is intentionally unused
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
    ):
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads

        # Refactored to use diffusers Attention with custom processor
        # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
        self.attention = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // n_heads,
            heads=n_heads,
            qk_norm="rms_norm" if qk_norm else None,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=ZSingleStreamAttnProcessor(),
        )

        # FFN hidden dim follows the original 8/3 ratio.
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
        self.layer_id = layer_id

        # Pre-norms (before attention / FFN)
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)

        # Post-norms applied to the sub-layer outputs before the residual add
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.modulation = modulation
        if modulation:
            # Produces (scale_msa, gate_msa, scale_mlp, gate_mlp), each of width dim.
            self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,   # global modulation embedding
        noise_mask: Optional[torch.Tensor] = None,    # (batch, seq) 1=noisy/0=clean; enables per-token modulation
        adaln_noisy: Optional[torch.Tensor] = None,   # embedding for noisy tokens
        adaln_clean: Optional[torch.Tensor] = None,   # embedding for clean tokens
    ):
        if self.modulation:
            seq_len = x.shape[1]

            if noise_mask is not None:
                # Per-token modulation: different modulation for noisy/clean tokens
                mod_noisy = self.adaLN_modulation(adaln_noisy)
                mod_clean = self.adaLN_modulation(adaln_clean)

                scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
                scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)

                # Gates are squashed with tanh; scales are residual (1 + delta).
                gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
                gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()

                scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
                scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean

                # Blend noisy/clean values token-by-token via the mask.
                scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
                scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
                gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
                gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
            else:
                # Global modulation: same modulation for all tokens (avoid double select)
                mod = self.adaLN_modulation(adaln_input)
                scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
                gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
                scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp

            # Attention block: pre-norm scaled input, gated post-norm residual.
            attn_out = self.attention(
                self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
            )
            x = x + gate_msa * self.attention_norm2(attn_out)

            # FFN block
            x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
        else:
            # Attention block (unmodulated variant, used by context/siglip refiners)
            attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
            x = x + self.attention_norm2(attn_out)

            # FFN block
            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))

        return x
class FinalLayer(nn.Module):
    """Final norm + projection with adaLN scale modulation.

    Supports a single global conditioning vector ``c`` or, in Omni mode,
    a per-token blend of ``c_noisy`` / ``c_clean`` driven by ``noise_mask``.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
        )

    def forward(self, x, c=None, noise_mask=None, c_noisy=None, c_clean=None):
        num_tokens = x.shape[1]

        if noise_mask is None:
            # Global path: one (1 + delta) scale per sample, broadcast over tokens.
            assert c is not None, "Either c or (c_noisy, c_clean) must be provided"
            scale = (1.0 + self.adaLN_modulation(c)).unsqueeze(1)
        else:
            # Per-token path: choose the noisy or clean scale for each token.
            scale = select_per_token(
                1.0 + self.adaLN_modulation(c_noisy),
                1.0 + self.adaLN_modulation(c_clean),
                noise_mask,
                num_tokens,
            )

        return self.linear(self.norm_final(x) * scale)
    @register_to_config
    def __init__(
        self,
        all_patch_size=(2,),        # spatial patch sizes; one embedder/final-layer pair per entry
        all_f_patch_size=(1,),      # temporal (frame) patch sizes, zipped with all_patch_size
        in_channels=16,
        dim=3840,
        n_layers=30,
        n_refiner_layers=2,
        n_heads=30,
        n_kv_heads=30,
        norm_eps=1e-5,
        qk_norm=True,
        cap_feat_dim=2560,          # caption (text encoder) feature width
        siglip_feat_dim=None,       # if set, enables the SigLIP branch (Omni conditioning)
        rope_theta=256.0,
        t_scale=1000.0,             # multiplier applied to t before the timestep embedder
        axes_dims=[32, 48, 48],     # per-axis RoPE dims; must sum to head_dim
        axes_lens=[1024, 512, 512],
    ) -> None:
        """Build the Omni transformer: patch embedders, refiner stacks, main layers."""
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.all_patch_size = all_patch_size
        self.all_f_patch_size = all_f_patch_size
        self.dim = dim
        self.n_heads = n_heads

        self.rope_theta = rope_theta
        self.t_scale = t_scale
        self.gradient_checkpointing = False

        assert len(all_patch_size) == len(all_f_patch_size)

        # One x-embedder and one final layer per (patch_size, f_patch_size)
        # combination, keyed "P-F" and selected at forward time.
        all_x_embedder = {}
        all_final_layer = {}
        for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
            # Input width per token: f * p * p * C.
            x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
            all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder

            final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
            all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer

        self.all_x_embedder = nn.ModuleDict(all_x_embedder)
        self.all_final_layer = nn.ModuleDict(all_final_layer)
        # Modulated refiner for image tokens (layer ids offset by 1000).
        self.noise_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    1000 + layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )
        # Unmodulated refiner for caption tokens.
        self.context_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    norm_eps,
                    qk_norm,
                    modulation=False,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )
        self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
        self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))

        # Optional SigLIP components (for Omni variant)
        if siglip_feat_dim is not None:
            self.siglip_embedder = nn.Sequential(
                RMSNorm(siglip_feat_dim, eps=norm_eps), nn.Linear(siglip_feat_dim, dim, bias=True)
            )
            # Unmodulated refiner for SigLIP tokens (layer ids offset by 2000).
            self.siglip_refiner = nn.ModuleList(
                [
                    ZImageTransformerBlock(
                        2000 + layer_id,
                        dim,
                        n_heads,
                        n_kv_heads,
                        norm_eps,
                        qk_norm,
                        modulation=False,
                    )
                    for layer_id in range(n_refiner_layers)
                ]
            )
            # NOTE(review): pad tokens are created with torch.empty (uninitialized);
            # presumably overwritten by loaded weights — confirm no from-scratch path
            # relies on them.
            self.siglip_pad_token = nn.Parameter(torch.empty((1, dim)))
        else:
            self.siglip_embedder = None
            self.siglip_refiner = None
            self.siglip_pad_token = None

        # Learnable fillers written into padded positions of each sequence.
        self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
        self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))

        # Main single-stream transformer stack.
        self.layers = nn.ModuleList(
            [
                ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
                for layer_id in range(n_layers)
            ]
        )
        head_dim = dim // n_heads
        assert head_dim == sum(axes_dims)
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens

        self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
    def unpatchify(
        self,
        x: List[torch.Tensor],
        size: List[Tuple],
        patch_size,
        f_patch_size,
        x_pos_offsets: Optional[List[Tuple[int, int]]] = None,
    ) -> List[torch.Tensor]:
        """Rebuild per-sample images from the flat token sequence.

        ``x_pos_offsets[i]`` delimits the image-token span inside sample i's
        unified sequence; ``size[i]`` lists one (F, H, W) per image (or None
        for an absent image slot). Walks the span with a running cursor,
        skipping each image's padding, and keeps only the LAST decoded image
        (the target; earlier entries are conditions).
        """
        pH = pW = patch_size
        pF = f_patch_size
        bsz = len(x)
        assert len(size) == bsz

        # Omni: extract target image from unified sequence (cond_images + target)
        result = []
        for i in range(bsz):
            unified_x = x[i][x_pos_offsets[i][0] : x_pos_offsets[i][1]]
            cu_len = 0
            # Stays None if every size entry is None for this sample.
            x_item = None
            for j in range(len(size[i])):
                if size[i][j] is None:
                    # Absent image slot still occupies SEQ_MULTI_OF pad tokens.
                    ori_len = 0
                    pad_len = SEQ_MULTI_OF
                    cu_len += pad_len + ori_len
                else:
                    # NOTE: F/H/W intentionally shadow the module-level
                    # `torch.nn.functional as F` alias inside this scope.
                    F, H, W = size[i][j]
                    ori_len = (F // pF) * (H // pH) * (W // pW)
                    pad_len = (-ori_len) % SEQ_MULTI_OF
                    # Invert the patchify permutation: tokens -> (C, F, H, W).
                    x_item = (
                        unified_x[cu_len : cu_len + ori_len]
                        .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
                        .permute(6, 0, 3, 1, 4, 2, 5)
                        .reshape(self.out_channels, F, H, W)
                    )
                    cu_len += ori_len + pad_len
            result.append(x_item)  # Return only the last (target) image
        return result

    @staticmethod
    def create_coordinate_grid(size, start=None, device=None):
        """Dense int32 coordinate grid of shape (*size, len(size)).

        Axis k runs from start[k] to start[k] + size[k] - 1.
        """
        if start is None:
            # Generator is fine here: consumed exactly once by zip below.
            start = (0 for _ in size)

        axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
        grids = torch.meshgrid(axes, indexing="ij")
        return torch.stack(grids, dim=-1)
f_patch_size: int, + images_noise_mask: List[List[int]], + ): + bsz = len(all_x) + pH = pW = patch_size + pF = f_patch_size + device = all_x[0][-1].device + dtype = all_x[0][-1].dtype + + all_x_padded = [] + all_x_size = [] + all_x_pos_ids = [] + all_x_pad_mask = [] + all_x_len = [] + all_x_noise_mask = [] + all_cap_padded_feats = [] + all_cap_pos_ids = [] + all_cap_pad_mask = [] + all_cap_len = [] + all_cap_noise_mask = [] + all_siglip_padded_feats = [] + all_siglip_pos_ids = [] + all_siglip_pad_mask = [] + all_siglip_len = [] + all_siglip_noise_mask = [] + + for i in range(bsz): + # Process captions + num_images = len(all_x[i]) + cap_padded_feats = [] + cap_item_cu_len = 1 + cap_start_pos = [] + cap_end_pos = [] + cap_padded_pos_ids = [] + cap_pad_mask = [] + cap_len = [] + cap_noise_mask = [] + + for j, cap_item in enumerate(all_cap_feats[i]): + cap_item_ori_len = len(cap_item) + cap_item_padding_len = (-cap_item_ori_len) % SEQ_MULTI_OF + cap_len.append(cap_item_ori_len + cap_item_padding_len) + + cap_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(cap_item_padding_len, 1) + ) + cap_start_pos.append(cap_item_cu_len) + cap_item_ori_pos_ids = self.create_coordinate_grid( + size=(cap_item_ori_len, 1, 1), start=(cap_item_cu_len, 0, 0), device=device + ).flatten(0, 2) + cap_padded_pos_ids.append(cap_item_ori_pos_ids) + cap_padded_pos_ids.append(cap_item_padding_pos_ids) + cap_item_cu_len += cap_item_ori_len + cap_end_pos.append(cap_item_cu_len) + cap_item_cu_len += 2 # for image vae tokens and siglip tokens + + cap_pad_mask.append(torch.zeros((cap_item_ori_len,), dtype=torch.bool, device=device)) + cap_pad_mask.append(torch.ones((cap_item_padding_len,), dtype=torch.bool, device=device)) + cap_item_padded_feat = torch.cat([cap_item, cap_item[-1:].repeat(cap_item_padding_len, 1)], dim=0) + cap_padded_feats.append(cap_item_padded_feat) + + if j < len(images_noise_mask[i]): + 
cap_noise_mask.extend([images_noise_mask[i][j]] * (cap_item_ori_len + cap_item_padding_len)) + else: + cap_noise_mask.extend([1] * (cap_item_ori_len + cap_item_padding_len)) + + all_cap_noise_mask.append(cap_noise_mask) + cap_padded_pos_ids = torch.cat(cap_padded_pos_ids, dim=0) + all_cap_pos_ids.append(cap_padded_pos_ids) + cap_pad_mask = torch.cat(cap_pad_mask, dim=0) + all_cap_pad_mask.append(cap_pad_mask) + all_cap_padded_feats.append(torch.cat(cap_padded_feats, dim=0)) + all_cap_len.append(cap_len) + + # Process images (x) + x_padded = [] + x_padded_pos_ids = [] + x_pad_mask = [] + x_len = [] + x_size = [] + x_noise_mask = [] + + for j, x_item in enumerate(all_x[i]): + if x_item is not None: + C, F, H, W = x_item.size() + x_size.append((F, H, W)) + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW + + x_item = x_item.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) + x_item = x_item.permute(1, 3, 5, 2, 4, 6, 0).reshape( + F_tokens * H_tokens * W_tokens, pF * pH * pW * C + ) + + x_item_ori_len = len(x_item) + x_item_padding_len = (-x_item_ori_len) % SEQ_MULTI_OF + x_len.append(x_item_ori_len + x_item_padding_len) + + x_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(x_item_padding_len, 1) + ) + x_item_ori_pos_ids = self.create_coordinate_grid( + size=(F_tokens, H_tokens, W_tokens), start=(cap_end_pos[j], 0, 0), device=device + ).flatten(0, 2) + x_padded_pos_ids.append(x_item_ori_pos_ids) + x_padded_pos_ids.append(x_item_padding_pos_ids) + + x_pad_mask.append(torch.zeros((x_item_ori_len,), dtype=torch.bool, device=device)) + x_pad_mask.append(torch.ones((x_item_padding_len,), dtype=torch.bool, device=device)) + x_item_padded_feat = torch.cat([x_item, x_item[-1:].repeat(x_item_padding_len, 1)], dim=0) + x_padded.append(x_item_padded_feat) + x_noise_mask.extend([images_noise_mask[i][j]] * (x_item_ori_len + x_item_padding_len)) + else: + x_pad_dim = 64 + x_item_padding_len = 
SEQ_MULTI_OF + x_size.append(None) + x_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(x_item_padding_len, 1) + ) + x_len.append(x_item_padding_len) + x_padded_pos_ids.append(x_item_padding_pos_ids) + x_pad_mask.append(torch.ones((x_item_padding_len,), dtype=torch.bool, device=device)) + x_padded.append(torch.zeros((x_item_padding_len, x_pad_dim), dtype=dtype, device=device)) + x_noise_mask.extend([images_noise_mask[i][j]] * x_item_padding_len) + + all_x_noise_mask.append(x_noise_mask) + all_x_size.append(x_size) + x_padded_pos_ids = torch.cat(x_padded_pos_ids, dim=0) + all_x_pos_ids.append(x_padded_pos_ids) + x_pad_mask = torch.cat(x_pad_mask, dim=0) + all_x_pad_mask.append(x_pad_mask) + all_x_padded.append(torch.cat(x_padded, dim=0)) + all_x_len.append(x_len) + + # Process siglip_feats + if all_siglip_feats[i] is None: + all_siglip_len.append([0 for _ in range(num_images)]) + all_siglip_padded_feats.append(None) + else: + sig_padded_feats = [] + sig_padded_pos_ids = [] + sig_pad_mask = [] + sig_len = [] + sig_noise_mask = [] + + for j, sig_item in enumerate(all_siglip_feats[i]): + if sig_item is not None: + sig_H, sig_W, sig_C = sig_item.size() + sig_H_tokens, sig_W_tokens, sig_F_tokens = sig_H, sig_W, 1 + + sig_item = sig_item.view(sig_C, sig_F_tokens, 1, sig_H_tokens, 1, sig_W_tokens, 1) + sig_item = sig_item.permute(1, 3, 5, 2, 4, 6, 0).reshape( + sig_F_tokens * sig_H_tokens * sig_W_tokens, sig_C + ) + + sig_item_ori_len = len(sig_item) + sig_item_padding_len = (-sig_item_ori_len) % SEQ_MULTI_OF + sig_len.append(sig_item_ori_len + sig_item_padding_len) + + sig_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(sig_item_padding_len, 1) + ) + sig_item_ori_pos_ids = self.create_coordinate_grid( + size=(sig_F_tokens, sig_H_tokens, sig_W_tokens), + start=(cap_end_pos[j] + 1, 0, 0), + device=device, + ) + # 
Scale position IDs to match x resolution + sig_item_ori_pos_ids[..., 1] = ( + sig_item_ori_pos_ids[..., 1] / (sig_H_tokens - 1) * (x_size[j][1] - 1) + ) + sig_item_ori_pos_ids[..., 2] = ( + sig_item_ori_pos_ids[..., 2] / (sig_W_tokens - 1) * (x_size[j][2] - 1) + ) + sig_item_ori_pos_ids = sig_item_ori_pos_ids.flatten(0, 2) + sig_padded_pos_ids.append(sig_item_ori_pos_ids) + sig_padded_pos_ids.append(sig_item_padding_pos_ids) + + sig_pad_mask.append(torch.zeros((sig_item_ori_len,), dtype=torch.bool, device=device)) + sig_pad_mask.append(torch.ones((sig_item_padding_len,), dtype=torch.bool, device=device)) + sig_item_padded_feat = torch.cat( + [sig_item, sig_item[-1:].repeat(sig_item_padding_len, 1)], dim=0 + ) + sig_padded_feats.append(sig_item_padded_feat) + sig_noise_mask.extend([images_noise_mask[i][j]] * (sig_item_ori_len + sig_item_padding_len)) + else: + sig_pad_dim = self.config.siglip_feat_dim or 1152 + sig_item_padding_len = SEQ_MULTI_OF + sig_item_padding_pos_ids = ( + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) + .flatten(0, 2) + .repeat(sig_item_padding_len, 1) + ) + sig_padded_pos_ids.append(sig_item_padding_pos_ids) + sig_pad_mask.append(torch.ones((sig_item_padding_len,), dtype=torch.bool, device=device)) + sig_padded_feats.append( + torch.zeros((sig_item_padding_len, sig_pad_dim), dtype=dtype, device=device) + ) + sig_noise_mask.extend([images_noise_mask[i][j]] * sig_item_padding_len) + + all_siglip_noise_mask.append(sig_noise_mask) + sig_padded_pos_ids = torch.cat(sig_padded_pos_ids, dim=0) + all_siglip_pos_ids.append(sig_padded_pos_ids) + sig_pad_mask = torch.cat(sig_pad_mask, dim=0) + all_siglip_pad_mask.append(sig_pad_mask) + all_siglip_padded_feats.append(torch.cat(sig_padded_feats, dim=0)) + all_siglip_len.append(sig_len) + + # Compute x position offsets + all_x_pos_offsets = [] + for i in range(bsz): + start = sum(all_cap_len[i]) + end = start + sum(all_x_len[i]) + all_x_pos_offsets.append((start, end)) + + 
        # (tail of patchify_and_embed — kept verbatim)
        return (
            all_x_padded,
            all_cap_padded_feats,
            all_siglip_padded_feats,
            all_x_size,
            all_x_pos_ids,
            all_cap_pos_ids,
            all_siglip_pos_ids,
            all_x_pad_mask,
            all_cap_pad_mask,
            all_siglip_pad_mask,
            all_x_pos_offsets,
            all_x_noise_mask,
            all_cap_noise_mask,
            all_siglip_noise_mask,
        )

    def forward(
        self,
        x: List[List[torch.Tensor]],
        t: torch.Tensor,
        cap_feats: List[List[torch.Tensor]],
        siglip_feats: List[List[torch.Tensor]],
        image_noise_mask: List[List[int]],
        patch_size=2,
        f_patch_size=1,
        return_dict: bool = True,
    ):
        """Omni mode forward pass with image conditioning.

        Args:
            x: Per-sample lists of latent tensors; the last entry of each inner
                list is the target (noisy) latent, earlier entries are clean
                condition latents (see how the pipeline builds `x_combined`).
            t: Timesteps, one per sample in the batch.
            cap_feats: Per-sample lists of caption (text) token embeddings.
            siglip_feats: Per-sample lists of SigLIP feature maps (may be
                `None`-containing when no image conditions are present).
            image_noise_mask: Per-sample lists of 0/1 flags, one per image in
                `x`: 1 marks the noisy target, 0 marks a clean condition image.
            patch_size / f_patch_size: Spatial / frame patch sizes used to pick
                the matching embedder and final layer.
            return_dict: If False, return a plain `(sample,)` tuple.

        Returns:
            `Transformer2DModelOutput` (or a tuple) holding the stacked
            per-sample predictions.
        """
        bsz = len(x)
        device = x[0][-1].device  # From target latent

        # Create dual timestep embeddings: one for noisy tokens (t), one for clean tokens (t=1)
        t_noisy = self.t_embedder(t * self.t_scale)
        t_clean = self.t_embedder(torch.ones_like(t) * self.t_scale)

        # Patchify and embed for Omni mode
        (
            x,
            cap_feats,
            siglip_feats,
            x_size,
            x_pos_ids,
            cap_pos_ids,
            siglip_pos_ids,
            x_inner_pad_mask,
            cap_inner_pad_mask,
            siglip_inner_pad_mask,
            x_pos_offsets,
            x_noise_mask,
            cap_noise_mask,
            siglip_noise_mask,
        ) = self.patchify_and_embed(x, cap_feats, siglip_feats, patch_size, f_patch_size, image_noise_mask)

        # ---- x (latent) embed & refine ----
        x_item_seqlens = [len(_) for _ in x]
        # Every per-item sequence was padded to a multiple of SEQ_MULTI_OF in patchify_and_embed.
        assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
        x_max_item_seqlen = max(x_item_seqlens)

        x = torch.cat(x, dim=0)
        # Embedder is keyed by patch configuration, e.g. "2-1".
        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)

        # Overwrite intra-item padding positions with the learned pad token.
        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
        x = list(x.split(x_item_seqlens, dim=0))
        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))

        # Pad across the batch, then trim freqs to the padded sequence length.
        x = pad_sequence(x, batch_first=True, padding_value=0.0)
        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
        x_freqs_cis = x_freqs_cis[:, : x.shape[1]]

        # Create x_noise_mask tensor (1 = noisy token, 0 = clean/condition token)
        x_noise_mask_tensor = []
        for i in range(bsz):
            x_mask = torch.tensor(x_noise_mask[i], dtype=torch.long, device=device)
            x_noise_mask_tensor.append(x_mask)
        x_noise_mask_tensor = pad_sequence(x_noise_mask_tensor, batch_first=True, padding_value=0)
        x_noise_mask_tensor = x_noise_mask_tensor[:, : x.shape[1]]

        # Match t_embedder output dtype to x
        t_noisy_x = t_noisy.type_as(x)
        t_clean_x = t_clean.type_as(x)

        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(x_item_seqlens):
            x_attn_mask[i, :seq_len] = 1

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer in self.noise_refiner:
                x = self._gradient_checkpointing_func(
                    layer, x,
                    x_attn_mask, x_freqs_cis, None,
                    x_noise_mask_tensor,
                    t_noisy_x, t_clean_x
                )
        else:
            for layer in self.noise_refiner:
                x = layer(
                    x,
                    x_attn_mask,
                    x_freqs_cis,
                    noise_mask=x_noise_mask_tensor,
                    adaln_noisy=t_noisy_x,
                    adaln_clean=t_clean_x,
                )

        # ---- caption embed & refine (no timestep modulation) ----
        cap_item_seqlens = [len(_) for _ in cap_feats]
        cap_max_item_seqlen = max(cap_item_seqlens)

        cap_feats = torch.cat(cap_feats, dim=0)
        cap_feats = self.cap_embedder(cap_feats)
        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
        cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
        cap_freqs_cis = list(
            self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
        )

        cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
        cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
        cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]

        cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(cap_item_seqlens):
            cap_attn_mask[i, :seq_len] = 1

        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer in self.context_refiner:
                cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
        else:
            for layer in self.context_refiner:
                cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)

        # ---- SigLIP embed & refine (only when image conditions exist) ----
        siglip_item_seqlens = None
        if siglip_feats[0] is not None and self.siglip_embedder is not None:
            siglip_item_seqlens = [len(_) for _ in siglip_feats]
            siglip_max_item_seqlen = max(siglip_item_seqlens)

            siglip_feats = torch.cat(siglip_feats, dim=0)
            siglip_feats = self.siglip_embedder(siglip_feats)
            siglip_feats[torch.cat(siglip_inner_pad_mask)] = self.siglip_pad_token
            siglip_feats = list(siglip_feats.split(siglip_item_seqlens, dim=0))
            siglip_freqs_cis = list(
                self.rope_embedder(torch.cat(siglip_pos_ids, dim=0)).split([len(_) for _ in siglip_pos_ids], dim=0)
            )

            siglip_feats = pad_sequence(siglip_feats, batch_first=True, padding_value=0.0)
            siglip_freqs_cis = pad_sequence(siglip_freqs_cis, batch_first=True, padding_value=0.0)
            siglip_freqs_cis = siglip_freqs_cis[:, : siglip_feats.shape[1]]

            siglip_attn_mask = torch.zeros((bsz, siglip_max_item_seqlen), dtype=torch.bool, device=device)
            for i, seq_len in enumerate(siglip_item_seqlens):
                siglip_attn_mask[i, :seq_len] = 1

            if torch.is_grad_enabled() and self.gradient_checkpointing:
                for layer in self.siglip_refiner:
                    siglip_feats = self._gradient_checkpointing_func(
                        layer, siglip_feats, siglip_attn_mask, siglip_freqs_cis
                    )
            else:
                for layer in self.siglip_refiner:
                    siglip_feats = layer(siglip_feats, siglip_attn_mask, siglip_freqs_cis)

        # ---- Build unified sequence: [caption | latents | siglip] per item ----
        unified = []
        unified_freqs_cis = []
        unified_noise_mask = []

        if siglip_item_seqlens is not None:
            for i in range(bsz):
                x_len = x_item_seqlens[i]
                cap_len = cap_item_seqlens[i]
                siglip_len = siglip_item_seqlens[i]
                # Slice off the batch padding before concatenating per item.
                unified.append(torch.cat([cap_feats[i][:cap_len], x[i][:x_len], siglip_feats[i][:siglip_len]]))
                unified_freqs_cis.append(
                    torch.cat([cap_freqs_cis[i][:cap_len], x_freqs_cis[i][:x_len], siglip_freqs_cis[i][:siglip_len]])
                )
                unified_noise_mask.append(
                    torch.tensor(
                        cap_noise_mask[i] + x_noise_mask[i] + siglip_noise_mask[i],
                        dtype=torch.long,
                        device=device,
                    )
                )
            unified_item_seqlens = [
                a + b + c for a, b, c in zip(cap_item_seqlens, x_item_seqlens, siglip_item_seqlens)
            ]
        else:
            for i in range(bsz):
                x_len = x_item_seqlens[i]
                cap_len = cap_item_seqlens[i]
                unified.append(torch.cat([cap_feats[i][:cap_len], x[i][:x_len]]))
                unified_freqs_cis.append(torch.cat([cap_freqs_cis[i][:cap_len], x_freqs_cis[i][:x_len]]))
                unified_noise_mask.append(
                    torch.tensor(cap_noise_mask[i] + x_noise_mask[i], dtype=torch.long, device=device)
                )
            unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]

        unified_max_item_seqlen = max(unified_item_seqlens)

        unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
        unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
        unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(unified_item_seqlens):
            unified_attn_mask[i, :seq_len] = 1

        # Create unified_noise_mask tensor
        unified_noise_mask_tensor = pad_sequence(unified_noise_mask, batch_first=True, padding_value=0)
        unified_noise_mask_tensor = unified_noise_mask_tensor[:, : unified.shape[1]]

        # ---- Main transformer stack over the unified sequence ----
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for layer_idx, layer in enumerate(self.layers):
                unified = self._gradient_checkpointing_func(
                    layer,
                    unified, unified_attn_mask, unified_freqs_cis, None,
                    unified_noise_mask_tensor,
                    t_noisy_x, t_clean_x
                )
        else:
            for layer_idx, layer in enumerate(self.layers):
                unified = layer(
                    unified,
                    unified_attn_mask,
                    unified_freqs_cis,
                    noise_mask=unified_noise_mask_tensor,
                    adaln_noisy=t_noisy_x,
                    adaln_clean=t_clean_x,
                )

        # Final projection, then recover per-sample latents from their offsets.
        unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
            unified, noise_mask=unified_noise_mask_tensor, c_noisy=t_noisy_x, c_clean=t_clean_x
        )

        x = self.unpatchify(unified, x_size, patch_size, f_patch_size, x_pos_offsets)
        x = torch.stack(x)
        if not return_dict:
            return (x,)

        return Transformer2DModelOutput(sample=x)
+ +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin +from diffusers.models.autoencoders import AutoencoderKL +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor +from transformers import (AutoProcessor, AutoTokenizer, PreTrainedModel, + Siglip2VisionModel) + +from ..models import AutoencoderKL, ZImageOmniTransformer2DModel +from ..models.flux2_image_processor import Flux2ImageProcessor + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImageOmniPipeline + + >>> pipe = ZImageOmniPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> image = pipe( + ... prompt, + ... height=1024, + ... width=1024, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... 
).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. 
+ """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class ZImagePipelineOutput(BaseOutput): + """ + Output class for Z-Image image generation pipelines. + + Args: + images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class ZImageOmniPipeline(DiffusionPipeline, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageOmniTransformer2DModel, + siglip: Siglip2VisionModel, + siglip_processor: AutoProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + siglip=siglip, + siglip_processor=siglip_processor, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + # self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + num_condition_images: int = 0, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + 
negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + num_condition_images: int = 0, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + if num_condition_images == 0: + prompt[i] = ["<|im_start|>user\n" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n"] + elif num_condition_images > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (num_condition_images - 1) + prompt_list += ["<|vision_end|>" + prompt_item + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|im_end|>"] + prompt[i] = prompt_list + + flattened_prompt = [] + prompt_list_lengths = [] + + for i in range(len(prompt)): + prompt_list_lengths.append(len(prompt[i])) + flattened_prompt.extend(prompt[i]) + + text_inputs = self.tokenizer( + flattened_prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + start_idx = 0 + for i in range(len(prompt_list_lengths)): + batch_embeddings = [] + end_idx = start_idx + prompt_list_lengths[i] + for j in 
range(start_idx, end_idx): + batch_embeddings.append(prompt_embeds[j][prompt_masks[j]]) + embeddings_list.append(batch_embeddings) + start_idx = end_idx + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + def prepare_image_latents( + self, + images: List[torch.Tensor], + batch_size, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + image_latent = ( + self.vae.encode(image.bfloat16()).latent_dist.mode()[0] - self.vae.config.shift_factor + ) * self.vae.config.scaling_factor + image_latent = image_latent.unsqueeze(1).to(dtype) + image_latents.append(image_latent) # (16, 128, 128) + + # image_latents = [image_latents] * batch_size + image_latents = [image_latents.copy() for _ in range(batch_size)] + + return image_latents + + def prepare_siglip_embeds( + self, + images: List[torch.Tensor], + batch_size, + device, + dtype, + ): + siglip_embeds = [] + for image in images: + siglip_inputs = self.siglip_processor(images=[image], return_tensors="pt").to(device) + shape = siglip_inputs.spatial_shapes[0] + hidden_state = self.siglip(**siglip_inputs).last_hidden_state + B, N, C = hidden_state.shape + hidden_state = hidden_state[:, : shape[0] * shape[1]] + hidden_state = hidden_state.view(shape[0], shape[1], C) + siglip_embeds.append(hidden_state.to(dtype)) + + # siglip_embeds = [siglip_embeds] * batch_size + siglip_embeds = 
[siglip_embeds.copy() for _ in range(batch_size)] + + return siglip_embeds + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. 
For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. 
If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
+ callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + if image is not None and not isinstance(image, list): + image = [image] + num_condition_images = len(image) if image is not None else 0 + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + num_condition_images=num_condition_images, + ) + + # 3. Process condition images. Copied from diffusers.pipelines.flux2.pipeline_flux2 + condition_images = [] + resized_images = [] + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + for img in image: + image_width, image_height = img.size + # if image_width * image_height > 1024 * 1024: + if height is not None and width is not None: + img = self.image_processor._resize_to_target_area(img, height * width) + else: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + resized_images.append(img) + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + + if len(condition_images) > 0: + height = height or image_height + width = width or image_width + + else: + height = height or 1024 + width 
= width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + condition_latents = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + device=device, + dtype=torch.float32, + ) + condition_latents = [[lat.to(self.transformer.dtype) for lat in lats] for lats in condition_latents] + if self.do_classifier_free_guidance: + negative_condition_latents = [[lat.clone() for lat in batch] for batch in condition_latents] + + condition_siglip_embeds = self.prepare_siglip_embeds( + images=resized_images, + batch_size=batch_size * num_images_per_prompt, + device=device, + dtype=torch.float32, + ) + condition_siglip_embeds = [[se.to(self.transformer.dtype) for se in sels] for sels in condition_siglip_embeds] + if self.do_classifier_free_guidance: + negative_condition_siglip_embeds = [[se.clone() for se in batch] for batch in condition_siglip_embeds] + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + condition_siglip_embeds = [None if sels == [] else sels + [None] for sels in condition_siglip_embeds] + negative_condition_siglip_embeds = [ + 
None if sels == [] else sels + [None] for sels in negative_condition_siglip_embeds + ] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + condition_latents_model_input = condition_latents + negative_condition_latents + condition_siglip_embeds_model_input = condition_siglip_embeds + 
negative_condition_siglip_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + condition_latents_model_input = condition_latents + condition_siglip_embeds_model_input = condition_siglip_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + # Combine condition latents with target latent + current_batch_size = len(latent_model_input_list) + x_combined = [ + condition_latents_model_input[i] + [latent_model_input_list[i]] for i in range(current_batch_size) + ] + # Create noise mask: 0 for condition images (clean), 1 for target image (noisy) + image_noise_mask = [ + [0] * len(condition_latents_model_input[i]) + [1] for i in range(current_batch_size) + ] + + model_out_list = self.transformer( + x=x_combined, + t=timestep_model_input, + cap_feats=prompt_embeds_model_input, + siglip_feats=condition_siglip_embeds_model_input, + image_noise_mask=image_noise_mask, + return_dict=False, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy 
sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) diff --git a/videox_fun/utils/lora_utils.py b/videox_fun/utils/lora_utils.py index f94341b4..f25164e5 100755 --- a/videox_fun/utils/lora_utils.py +++ b/videox_fun/utils/lora_utils.py @@ -157,7 +157,7 @@ class LoRANetwork(torch.nn.Module): "Wan2_2Transformer3DModel", "FluxTransformer2DModel", "QwenImageTransformer2DModel", \ "Wan2_2Transformer3DModel_Animate", "Wan2_2Transformer3DModel_S2V", "FantasyTalkingTransformer3DModel", \ "HunyuanVideoTransformer3DModel", "Flux2Transformer2DModel", "ZImageTransformer2DModel", \ - "LongCatVideoTransformer3DModel", + "LongCatVideoTransformer3DModel", "ZImageOmniTransformer2DModel", ] TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder", "T5SelfAttention", "T5CrossAttention"] 
LORA_PREFIX_TRANSFORMER = "lora_unet"