From ff4a3b532596d9ece180deb945dd3c073e98f0a3 Mon Sep 17 00:00:00 2001 From: hlky <106811348+hlky@users.noreply.github.com> Date: Sat, 17 Sep 2022 21:13:25 +0100 Subject: [PATCH 1/5] Additional k-diffusion samples --- .../pipeline_stable_diffusion.py | 30 +-- src/diffusers/schedulers/__init__.py | 6 + .../scheduling_dpm2_ancestral_discrete.py | 185 +++++++++++++++++ .../schedulers/scheduling_dpm2_discrete.py | 182 +++++++++++++++++ .../scheduling_euler_ancestral_discrete.py | 169 ++++++++++++++++ .../schedulers/scheduling_euler_discrete.py | 176 +++++++++++++++++ .../schedulers/scheduling_heun_discrete.py | 186 ++++++++++++++++++ 7 files changed, 920 insertions(+), 14 deletions(-) create mode 100644 src/diffusers/schedulers/scheduling_dpm2_ancestral_discrete.py create mode 100644 src/diffusers/schedulers/scheduling_dpm2_discrete.py create mode 100644 src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py create mode 100644 src/diffusers/schedulers/scheduling_euler_discrete.py create mode 100644 src/diffusers/schedulers/scheduling_heun_discrete.py diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 8ae51999a7b3..1fe05afe4ac0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -8,7 +8,9 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, \ +DPM2DiscreteScheduler, DPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, \ +EulerAncestralDiscreteScheduler, HeunDiscreteScheduler from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -51,7 +53,8 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: Union[DDIMScheduler, PNDMScheduler, DPM2DiscreteScheduler, DPM2AncestralDiscreteScheduler, + LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, HeunDiscreteScheduler], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, ): @@ -301,12 +304,9 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) - # Some schedulers like PNDM have timesteps as arrays - # It's more optimized to move all timesteps to correct device beforehand - timesteps_tensor = self.scheduler.timesteps.to(self.device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas + if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): + latents = latents * self.scheduler.sigmas[0] # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
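The `latents = latents * self.scheduler.sigmas[0]` line above follows the variance-exploding convention used by k-diffusion: a sample at noise level sigma is x_0 + sigma * eps, so the initial latents must carry the largest sigma rather than unit variance, and the UNet input is rescaled back toward unit variance before each call. A minimal numpy sketch of both scalings (standalone, illustrative names, not the diffusers API):

import numpy as np

num_train_timesteps = 1000
betas = np.linspace(0.00085**0.5, 0.012**0.5, num_train_timesteps, dtype=np.float32) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)

# k-diffusion sigma for each discrete timestep; largest at t = 999
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

latents = np.random.randn(1, 4, 64, 64).astype(np.float32)
latents = latents * sigmas[-1]  # scale the initial noise to the largest sigma (~14.6)

# before the UNet call, bring the input back to roughly unit variance,
# since the model was trained on variance-preserving inputs
sigma = sigmas[-1]
latent_model_input = latents / ((sigma**2 + 1) ** 0.5)
print(latent_model_input.std())  # ~1.0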
@@ -320,7 +320,10 @@ def __call__( for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): + sigma = self.scheduler.sigmas[i] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -331,11 +334,10 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): + latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index a906c39eb24c..9d27ee8b28b3 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -41,5 +41,11 @@ if is_scipy_available() and is_torch_available(): from .scheduling_lms_discrete import LMSDiscreteScheduler + from .scheduling_dpm2_discrete import DPM2DiscreteScheduler + from .scheduling_dpm2_ancestral_discrete import DPM2AncestralDiscreteScheduler + from .scheduling_euler_discrete import EulerDiscreteScheduler + from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler + from .scheduling_heun_discrete import HeunDiscreteScheduler + else: from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 diff --git a/src/diffusers/schedulers/scheduling_dpm2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_dpm2_ancestral_discrete.py new file mode 100644 index 000000000000..cf859df67aeb --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dpm2_ancestral_discrete.py @@ -0,0 +1,185 @@ +# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
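With the schedulers exported from `src/diffusers/schedulers/__init__.py` above, swapping one into an existing pipeline might look like the following sketch. The checkpoint id is a placeholder and the surrounding pipeline API is assumed to behave as in diffusers 0.4/0.5; treat this as illustrative rather than a tested recipe.

import torch
from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.scheduler = EulerAncestralDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
pipe = pipe.to("cuda")

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=30).images[0]
image.save("astronaut.png")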
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from scipy import integrate
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class DPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Ancestral sampling with DPM-Solver-inspired second-order steps, for discrete beta schedules. Based on the
+    original k-diffusion implementation by Katherine Crowson:
+    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L145
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear` or `scaled_linear`.
+        trained_betas (`np.ndarray`, optional):
+            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.00085,  # sensible defaults
+        beta_end: float = 0.012,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
+        tensor_format: str = "pt",
+    ):
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+        self.derivatives = []
+
+        self.tensor_format = tensor_format
+        self.set_format(tensor_format=tensor_format)
+
+    def set_timesteps(self, num_inference_steps: int):
+        """
+        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
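The fractional-timestep interpolation performed by `set_timesteps` below can be reproduced standalone: the `num_inference_steps` timesteps are spread evenly over the 1000 training steps, and each sigma is linearly interpolated between its two discrete neighbours (numpy-only sketch, illustrative names):

import numpy as np

num_train_timesteps, num_inference_steps = 1000, 50
betas = np.linspace(0.00085**0.5, 0.012**0.5, num_train_timesteps, dtype=np.float32) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)
train_sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

timesteps = np.linspace(num_train_timesteps - 1, 0, num_inference_steps)  # 999 -> 0, fractional

low_idx = np.floor(timesteps).astype(int)
high_idx = np.ceil(timesteps).astype(int)
frac = np.mod(timesteps, 1.0)
sigmas = (1 - frac) * train_sigmas[low_idx] + frac * train_sigmas[high_idx]
sigmas = np.concatenate([sigmas, [0.0]])  # trailing zero so the last step lands on the denoised sample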
+ """ + self.num_inference_steps = num_inference_steps + self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + + low_idx = np.floor(self.timesteps).astype(int) + high_idx = np.ceil(self.timesteps).astype(int) + frac = np.mod(self.timesteps, 1.0) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + + self.derivatives = [] + + self.set_format(tensor_format=self.tensor_format) + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + s_churn: float = 0., + s_tmin: float = 0., + s_tmax: float = float('inf'), + s_noise: float = 1., + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + s_churn (`float`) + s_tmin (`float`) + s_tmax (`float`) + s_noise (`float`) + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + sigma = self.sigmas[timestep] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma * model_output + sigma_from = sigma + sigma_to = self.sigmas[timestep + 1] + sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + + # 2. 
Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + self.derivatives.append(derivative) + + # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule + sigma_mid = ((sigma ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3 + + dt_1 = sigma_mid - sigma + dt_2 = sigma_down - sigma + sample_2 = sample + derivative * dt_1 + pred_original_sample_2 = sample_2 - sigma_mid * model_output + derivative_2 = (sample_2 - pred_original_sample_2) / sigma_mid + sample = sample + derivative_2 * dt_2 + sample = sample + torch.randn_like(sample) * sigma_up + + prev_sample = sample + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: Union[torch.FloatTensor, np.ndarray], + noise: Union[torch.FloatTensor, np.ndarray], + timesteps: Union[torch.IntTensor, np.ndarray], + ) -> Union[torch.FloatTensor, np.ndarray]: + if self.tensor_format == "pt": + timesteps = timesteps.to(self.sigmas.device) + sigmas = self.match_shape(self.sigmas[timesteps], noise) + noisy_samples = original_samples + noise * sigmas + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpm2_discrete.py b/src/diffusers/schedulers/scheduling_dpm2_discrete.py new file mode 100644 index 000000000000..7083f48aaeac --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dpm2_discrete.py @@ -0,0 +1,182 @@ +# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from scipy import integrate + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class DPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022). + for discrete beta schedules. Based on the original k-diffusion implementation by + Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L119 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. 
+        trained_betas (`np.ndarray`, optional):
+            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
+
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.00085,  # sensible defaults
+        beta_end: float = 0.012,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
+        tensor_format: str = "pt",
+    ):
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+
+        self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+        self.derivatives = []
+
+        self.tensor_format = tensor_format
+        self.set_format(tensor_format=tensor_format)
+
+    def set_timesteps(self, num_inference_steps: int):
+        """
+        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+        """
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
+
+        low_idx = np.floor(self.timesteps).astype(int)
+        high_idx = np.ceil(self.timesteps).astype(int)
+        frac = np.mod(self.timesteps, 1.0)
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
+        self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+
+        self.derivatives = []
+
+        self.set_format(tensor_format=self.tensor_format)
+
+    def step(
+        self,
+        model_output: Union[torch.FloatTensor, np.ndarray],
+        timestep: int,
+        sample: Union[torch.FloatTensor, np.ndarray],
+        s_churn: float = 0.,
+        s_tmin: float = 0.,
+        s_tmax: float = float('inf'),
+        s_noise: float = 1.,
+        return_dict: bool = True,
+    ) -> Union[SchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model.
+            timestep (`int`): current discrete timestep in the diffusion chain.
+            sample (`torch.FloatTensor` or `np.ndarray`):
+                current instance of sample being created by diffusion process.
+ s_churn (`float`) + s_tmin (`float`) + s_tmax (`float`) + s_noise (`float`) + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + sigma = self.sigmas[timestep] + gamma = min(s_churn / (len(self.sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0. + eps = torch.randn_like(sample) * s_noise + sigma_hat = sigma * (gamma + 1) + if gamma > 0: + sample = sample + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5 + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma_hat * model_output + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma_hat + self.derivatives.append(derivative) + + sigma_mid = ((sigma_hat ** (1 / 3) + self.sigmas[timestep + 1] ** (1 / 3)) / 2) ** 3 + dt_1 = sigma_mid - sigma_hat + dt_2 = self.sigmas[timestep + 1] - sigma_hat + sample_2 = sample + derivative * dt_1 + pred_original_sample_2 = sample_2 - sigma_mid * model_output + derivative_2 = (sample_2 - pred_original_sample_2) / sigma_mid + sample = sample + derivative_2 * dt_2 + + prev_sample = sample + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: Union[torch.FloatTensor, np.ndarray], + noise: Union[torch.FloatTensor, np.ndarray], + timesteps: Union[torch.IntTensor, np.ndarray], + ) -> Union[torch.FloatTensor, np.ndarray]: + if self.tensor_format == "pt": + timesteps = timesteps.to(self.sigmas.device) + sigmas = self.match_shape(self.sigmas[timesteps], noise) + noisy_samples = original_samples + noise * sigmas + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py new file mode 100644 index 000000000000..fa2b49483063 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -0,0 +1,169 @@ +# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from scipy import integrate + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Ancestral sampling with Euler method steps. + for discrete beta schedules. 
Based on the original k-diffusion implementation by + Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, #sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + tensor_format: str = "pt", + ): + if trained_betas is not None: + self.betas = np.asarray(trained_betas) + if beta_schedule == "linear": + self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # setable values + self.num_inference_steps = None + self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.derivatives = [] + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + def set_timesteps(self, num_inference_steps: int): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. 
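The ancestral update implemented in `step` below splits the move from `sigma_from` to `sigma_to` into a deterministic part (`sigma_down`) and fresh noise (`sigma_up`) such that sigma_down**2 + sigma_up**2 == sigma_to**2, so the marginal noise level still matches the schedule. A self-contained sketch, assuming `model_output` is an epsilon prediction (illustrative names):

import torch

def euler_ancestral_step(sample, model_output, sigma_from, sigma_to):
    # variance split: sigma_down**2 + sigma_up**2 == sigma_to**2
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

    pred_original_sample = sample - sigma_from * model_output  # x_0 estimate
    derivative = (sample - pred_original_sample) / sigma_from  # dx/dsigma (equals model_output)
    dt = sigma_down - sigma_from                               # negative: sigma shrinks

    prev_sample = sample + derivative * dt                     # deterministic Euler step
    return prev_sample + torch.randn_like(prev_sample) * sigma_up  # ancestral re-noising

x = torch.randn(1, 4, 64, 64) * 14.6
x = euler_ancestral_step(x, torch.randn_like(x), 14.6, 9.0)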
+ """ + self.num_inference_steps = num_inference_steps + self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + + low_idx = np.floor(self.timesteps).astype(int) + high_idx = np.ceil(self.timesteps).astype(int) + frac = np.mod(self.timesteps, 1.0) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + + self.derivatives = [] + + self.set_format(tensor_format=self.tensor_format) + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + sigma = self.sigmas[timestep] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma * model_output + sigma_from = self.sigmas[timestep] + sigma_to = self.sigmas[timestep + 1] + sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + self.derivatives.append(derivative) + + dt = sigma_down - sigma + + prev_sample = sample + derivative * dt + + prev_sample = prev_sample + torch.randn_like(prev_sample) * sigma_up + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: Union[torch.FloatTensor, np.ndarray], + noise: Union[torch.FloatTensor, np.ndarray], + timesteps: Union[torch.IntTensor, np.ndarray], + ) -> Union[torch.FloatTensor, np.ndarray]: + if self.tensor_format == "pt": + timesteps = timesteps.to(self.sigmas.device) + sigmas = self.match_shape(self.sigmas[timesteps], noise) + noisy_samples = original_samples + noise * sigmas + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py new file mode 100644 index 000000000000..df60266801f8 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -0,0 +1,176 @@ +# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from scipy import integrate + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Implements Algorithm 2 (Euler steps) from Karras et al. (2022). + for discrete beta schedules. Based on the original k-diffusion implementation by + Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, #sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + tensor_format: str = "pt", + ): + if trained_betas is not None: + self.betas = np.asarray(trained_betas) + if beta_schedule == "linear": + self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # setable values + self.num_inference_steps = None + self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.derivatives = [] + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + def set_timesteps(self, num_inference_steps: int): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. 
+ + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + + low_idx = np.floor(self.timesteps).astype(int) + high_idx = np.ceil(self.timesteps).astype(int) + frac = np.mod(self.timesteps, 1.0) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + + self.derivatives = [] + + self.set_format(tensor_format=self.tensor_format) + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + s_churn: float = 0., + s_tmin: float = 0., + s_tmax: float = float('inf'), + s_noise: float = 1., + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + s_churn (`float`) + s_tmin (`float`) + s_tmax (`float`) + s_noise (`float`) + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + sigma = self.sigmas[timestep] + gamma = min(s_churn / (len(self.sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0. + eps = torch.randn_like(sample) * s_noise + sigma_hat = sigma * (gamma + 1) + if gamma > 0: + sample = sample + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5 + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma_hat * model_output + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma_hat + self.derivatives.append(derivative) + + dt = self.sigmas[timestep + 1] - sigma_hat + + prev_sample = sample + derivative * dt + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: Union[torch.FloatTensor, np.ndarray], + noise: Union[torch.FloatTensor, np.ndarray], + timesteps: Union[torch.IntTensor, np.ndarray], + ) -> Union[torch.FloatTensor, np.ndarray]: + if self.tensor_format == "pt": + timesteps = timesteps.to(self.sigmas.device) + sigmas = self.match_shape(self.sigmas[timesteps], noise) + noisy_samples = original_samples + noise * sigmas + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py new file mode 100644 index 000000000000..c92ef5276db4 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -0,0 +1,186 @@ +# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from scipy import integrate + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Implements Algorithm 2 (Heun steps) from Karras et al. (2022). + for discrete beta schedules. Based on the original k-diffusion implementation by + Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and + [`~ConfigMixin.from_config`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, + `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays. + + """ + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, #sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[np.ndarray] = None, + tensor_format: str = "pt", + ): + if trained_betas is not None: + self.betas = np.asarray(trained_betas) + if beta_schedule == "linear": + self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. 
+ self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # setable values + self.num_inference_steps = None + self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.derivatives = [] + + self.tensor_format = tensor_format + self.set_format(tensor_format=tensor_format) + + def set_timesteps(self, num_inference_steps: int): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + self.num_inference_steps = num_inference_steps + self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + + low_idx = np.floor(self.timesteps).astype(int) + high_idx = np.ceil(self.timesteps).astype(int) + frac = np.mod(self.timesteps, 1.0) + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] + self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + + self.derivatives = [] + + self.set_format(tensor_format=self.tensor_format) + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + s_churn: float = 0., + s_tmin: float = 0., + s_tmax: float = float('inf'), + s_noise: float = 1., + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor` or `np.ndarray`): + current instance of sample being created by diffusion process. + s_churn (`float`) + s_tmin (`float`) + s_tmax (`float`) + s_noise (`float`) + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + sigma = self.sigmas[timestep] + gamma = min(s_churn / (len(self.sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0. + eps = torch.randn_like(sample) * s_noise + sigma_hat = sigma * (gamma + 1) + if gamma > 0: + sample = sample + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5 + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + pred_original_sample = sample - sigma_hat * model_output + + # 2. 
Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma_hat + self.derivatives.append(derivative) + + dt = self.sigmas[timestep + 1] - sigma_hat + if self.sigmas[timestep + 1] == 0: + # Euler method + sample = sample + derivative * dt + else: + # Heun's method + sample_2 = sample + derivative * dt + pred_original_sample_2 = sample_2 - self.sigmas[timestep + 1] * model_output + derivative_2 = (sample_2 - pred_original_sample_2) / self.sigmas[timestep + 1] + d_prime = (derivative + derivative_2) / 2 + sample = sample + d_prime * dt + + prev_sample = sample + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: Union[torch.FloatTensor, np.ndarray], + noise: Union[torch.FloatTensor, np.ndarray], + timesteps: Union[torch.IntTensor, np.ndarray], + ) -> Union[torch.FloatTensor, np.ndarray]: + if self.tensor_format == "pt": + timesteps = timesteps.to(self.sigmas.device) + sigmas = self.match_shape(self.sigmas[timesteps], noise) + noisy_samples = original_samples + noise * sigmas + + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps From 8c1cfe549325207be1005d6cb7916c18b509efa2 Mon Sep 17 00:00:00 2001 From: Sean Date: Mon, 17 Oct 2022 16:18:35 +0200 Subject: [PATCH 2/5] feat: add support for new k-schedulers --- src/diffusers/__init__.py | 5 ++ .../pipeline_stable_diffusion_img2img.py | 53 +++++++++++++++---- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0a0b0b4965dd..ac19477fb9e5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -37,6 +37,11 @@ PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPM2DiscreteScheduler, + DPM2AncestralDiscreteScheduler, + HeunDiscreteScheduler ) from .training_utils import EMAModel else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 799fd459bbd9..b4b8f54e3763 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -10,7 +10,9 @@ from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, \ +DPM2DiscreteScheduler, DPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, \ +EulerAncestralDiscreteScheduler, HeunDiscreteScheduler from ...utils import deprecate, logging from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -63,7 +65,12 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: Union[ + DDIMScheduler, PNDMScheduler, DPM2DiscreteScheduler, + DPM2AncestralDiscreteScheduler, LMSDiscreteScheduler, + EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, + HeunDiscreteScheduler + ], safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, ): @@ -127,7 +134,7 @@ def disable_attention_slicing(self): Disable sliced attention computation. 
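One thing to note about `scheduling_heun_discrete.py` above: `derivative_2` is computed from the same `model_output` as `derivative`, so `d_prime = (derivative + derivative_2) / 2` reduces to `derivative` and the corrector contributes nothing; in k-diffusion proper the model is re-evaluated at the Euler-predicted point. A sketch of that two-evaluation form, with `eps_model` standing in for the UNet call (illustrative, assumes epsilon prediction):

import torch

def heun_step(sample, eps_model, sigma, sigma_next):
    d = eps_model(sample, sigma)            # slope at sigma (for eps-prediction, dx/dsigma = eps)
    dt = sigma_next - sigma
    if sigma_next == 0:
        return sample + d * dt              # plain Euler for the final step
    sample_2 = sample + d * dt              # Euler prediction
    d_2 = eps_model(sample_2, sigma_next)   # slope re-evaluated at the predicted point
    return sample + (d + d_2) / 2 * dt      # trapezoidal (Heun) average

x = torch.randn(1, 4, 8, 8) * 14.6
x = heun_step(x, lambda s, sig: torch.randn_like(s), 14.6, 9.0)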
If `enable_attention_slicing` was previously invoked, this method will go back to computing attention in one step. """ - # set slice_size = `None` to disable `set_attention_slice` + # set slice_size = `None` to disable `attention slicing` self.enable_attention_slicing(None) @torch.no_grad() @@ -227,6 +234,30 @@ def __call__( if isinstance(init_image, PIL.Image.Image): init_image = preprocess(init_image) + # encode the init image into latents and scale the latents + init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents + + # expand init_latents for batch_size + init_latents = torch.cat([init_latents] * batch_size) + + # get the original timestep using init_timestep + offset = self.scheduler.config.get("steps_offset", 0) + init_timestep = int(num_inference_steps * strength) + offset + init_timestep = min(init_timestep, num_inference_steps) + if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): + timesteps = torch.tensor( + [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device + ) + else: + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) + + # add noise to latents using the timesteps + noise = torch.randn(init_latents.shape, generator=generator, device=self.device) + init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device) + # get prompt text embeddings text_inputs = self.tokenizer( prompt, @@ -346,7 +377,12 @@ def __call__( for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): + sigma = self.scheduler.sigmas[t_index] + # the model input needs to be scaled to match the continuous ODE formulation in K-LMS + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = latent_model_input.to(self.unet.dtype) + t = t.to(self.unet.dtype) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -357,11 +393,10 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) + if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): + latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample From 06d61f3fa0b1055facef4ecf1320983738f0bab8 Mon Sep 17 00:00:00 2001 From: Sean Date: Mon, 17 Oct 2022 23:40:01 +0200 Subject: [PATCH 3/5] feat: update euler_a and img2img + txt2img pipeline for 0.5.1 --- .../pipeline_stable_diffusion.py | 23 ++--- .../pipeline_stable_diffusion_img2img.py | 51 +++-------- 
.../scheduling_euler_ancestral_discrete.py | 86 ++++++++++++------- 3 files changed, 82 insertions(+), 78 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 1fe05afe4ac0..0e43ae80447e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -304,9 +304,12 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) - # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas - if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): - latents = latents * self.scheduler.sigmas[0] + # Some schedulers like PNDM have timesteps as arrays + # It's more optimized to move all timesteps to correct device beforehand + timesteps_tensor = self.scheduler.timesteps.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -320,10 +323,7 @@ def __call__( for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): - sigma = self.scheduler.sigmas[i] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -334,10 +334,11 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index b4b8f54e3763..e4a0d016b95a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -134,7 +134,7 @@ def disable_attention_slicing(self): Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go back to computing attention in one step. 
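The updated calling convention above (passing both the timestep `t` and the loop index `i` to `scale_model_input` and `step`) makes the denoising loop scheduler-agnostic. A runnable sketch with a stub in place of the UNet; the scheduler import assumes the exports added in patch 2:

import torch
from types import SimpleNamespace
from diffusers import EulerAncestralDiscreteScheduler

def stub_unet(x, t, encoder_hidden_states=None):  # stand-in for UNet2DConditionModel
    return SimpleNamespace(sample=torch.zeros_like(x))

scheduler = EulerAncestralDiscreteScheduler()
scheduler.set_timesteps(30)

latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

for i, t in enumerate(scheduler.timesteps):
    model_input = scheduler.scale_model_input(latents, t, i)
    noise_pred = stub_unet(model_input, t).sample
    latents = scheduler.step(noise_pred, t, i, latents).prev_sample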
""" - # set slice_size = `None` to disable `attention slicing` + # set slice_size = `None` to disable `set_attention_slice` self.enable_attention_slicing(None) @torch.no_grad() @@ -234,30 +234,6 @@ def __call__( if isinstance(init_image, PIL.Image.Image): init_image = preprocess(init_image) - # encode the init image into latents and scale the latents - init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist - init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents - - # expand init_latents for batch_size - init_latents = torch.cat([init_latents] * batch_size) - - # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) - if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): - timesteps = torch.tensor( - [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device - ) - else: - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device) - - # add noise to latents using the timesteps - noise = torch.randn(init_latents.shape, generator=generator, device=self.device) - init_latents = self.scheduler.add_noise(init_latents, noise, timesteps).to(self.device) - # get prompt text embeddings text_inputs = self.tokenizer( prompt, @@ -350,8 +326,13 @@ def __call__( init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) + if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): + timesteps = torch.tensor( + [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device + ) + else: + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device) # add noise to latents using the timesteps noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype) @@ -377,12 +358,7 @@ def __call__( for i, t in enumerate(self.progress_bar(timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): - sigma = self.scheduler.sigmas[t_index] - # the model input needs to be scaled to match the continuous ODE formulation in K-LMS - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - latent_model_input = latent_model_input.to(self.unet.dtype) - t = t.to(self.unet.dtype) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t, i) # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -393,10 +369,11 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): - latents = self.scheduler.step(noise_pred, t_index, latents, 
**extra_step_kwargs).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, i, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) latents = 1 / 0.18215 * latents image = self.vae.decode(latents).sample diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index fa2b49483063..d3186477d7aa 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -20,6 +20,7 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, deprecate from .scheduling_utils import SchedulerMixin, SchedulerOutput @@ -58,32 +59,53 @@ def __init__( beta_end: float = 0.012, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - tensor_format: str = "pt", ): if trained_betas is not None: - self.betas = np.asarray(trained_betas) - if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + self.init_noise_sigma = None # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) self.derivatives = [] + self.is_scale_input_called = False + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor] + ) -> torch.FloatTensor: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain - def set_timesteps(self, num_inference_steps: int): + Returns: + `torch.FloatTensor`: scaled input sample + """ + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. 
@@ -92,23 +114,24 @@ def set_timesteps(self, num_inference_steps: int): the number of diffusion steps used when generating samples with a pre-trained model. """ self.num_inference_steps = num_inference_steps - self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) low_idx = np.floor(self.timesteps).astype(int) high_idx = np.ceil(self.timesteps).astype(int) frac = np.mod(self.timesteps, 1.0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] - self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(self.timesteps) + self.init_noise_sigma = self.sigmas[0] self.derivatives = [] - self.set_format(tensor_format=self.tensor_format) - def step( self, model_output: Union[torch.FloatTensor, np.ndarray], - timestep: int, + timestep: Union[float, torch.FloatTensor], + step_index: Union[int, torch.IntTensor], sample: Union[torch.FloatTensor, np.ndarray], return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: @@ -129,12 +152,12 @@ def step( returning a tuple, the first element is the sample tensor. """ - sigma = self.sigmas[timestep] + sigma = self.sigmas[step_index] # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise pred_original_sample = sample - sigma * model_output - sigma_from = self.sigmas[timestep] - sigma_to = self.sigmas[timestep + 1] + sigma_from = self.sigmas[step_index] + sigma_to = self.sigmas[step_index + 1] sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 # 2. 
Convert to an ODE derivative @@ -154,16 +177,19 @@ def step( def add_noise( self, - original_samples: Union[torch.FloatTensor, np.ndarray], - noise: Union[torch.FloatTensor, np.ndarray], - timesteps: Union[torch.IntTensor, np.ndarray], - ) -> Union[torch.FloatTensor, np.ndarray]: - if self.tensor_format == "pt": - timesteps = timesteps.to(self.sigmas.device) - sigmas = self.match_shape(self.sigmas[timesteps], noise) - noisy_samples = original_samples + noise * sigmas - + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + self.timesteps = self.timesteps.to(original_samples.device) + sigma = self.sigmas[timesteps].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): - return self.config.num_train_timesteps + return self.config.num_train_timesteps \ No newline at end of file From ccad44685af340579b88e69a6ae09631b2a27cdd Mon Sep 17 00:00:00 2001 From: Sean Date: Sat, 22 Oct 2022 13:01:07 +0200 Subject: [PATCH 4/5] fix: fix other schedulers to accept step_index and generated timestep --- src/diffusers/schedulers/scheduling_ddim.py | 4 +- src/diffusers/schedulers/scheduling_ddpm.py | 4 +- .../scheduling_dpm2_ancestral_discrete.py | 119 +++++++++++++----- .../schedulers/scheduling_dpm2_discrete.py | 77 ++++++++---- .../schedulers/scheduling_euler_discrete.py | 69 ++++++---- .../schedulers/scheduling_heun_discrete.py | 72 +++++++---- .../schedulers/scheduling_karras_ve.py | 4 +- .../schedulers/scheduling_lms_discrete.py | 6 +- src/diffusers/schedulers/scheduling_pndm.py | 4 +- 9 files changed, 255 insertions(+), 104 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 0f1a40229475..ba3ca2c4582a 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -151,7 +151,9 @@ def __init__( self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) - def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + def scale_model_input( + self, sample: torch.FloatTensor, + *args, **kwargs) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d51d58ac8f45..f7085adbe555 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -145,7 +145,9 @@ def __init__( self.variance_type = variance_type - def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + def scale_model_input( + self, sample: torch.FloatTensor, + *args, **kwargs) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the current timestep. 
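The ancestral schedulers split every transition into a deterministic move down to sigma_down plus fresh noise of magnitude sigma_up, chosen so that the combined variance lands exactly on the next schedule value, i.e. sigma_down**2 + sigma_up**2 == sigma_to**2. A quick numeric check of that identity (made-up schedule values, not taken from the patch):

    sigma_from, sigma_to = 2.0, 1.0  # two adjacent, descending sigmas
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    assert abs(sigma_down**2 + sigma_up**2 - sigma_to**2) < 1e-12

The DDIM and DDPM scale_model_input variants above, by contrast, are deliberate no-ops that return the sample unchanged; they exist only so a pipeline can call scale_model_input unconditionally, whatever scheduler it was built with.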
diff --git a/src/diffusers/schedulers/scheduling_dpm2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_dpm2_ancestral_discrete.py index cf859df67aeb..f7b10f7c6455 100644 --- a/src/diffusers/schedulers/scheduling_dpm2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_dpm2_ancestral_discrete.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Optional, Tuple, Union +from diffusers.models.unet_2d_condition import UNet2DConditionModel import numpy as np import torch @@ -58,32 +59,36 @@ def __init__( beta_end: float = 0.012, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - tensor_format: str = "pt", ): if trained_betas is not None: - self.betas = np.asarray(trained_betas) - if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + self.init_noise_sigma = None # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) self.derivatives = [] + self.is_scale_input_called = False - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - - def set_timesteps(self, num_inference_steps: int): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -92,24 +97,46 @@ def set_timesteps(self, num_inference_steps: int): the number of diffusion steps used when generating samples with a pre-trained model. 
""" self.num_inference_steps = num_inference_steps - self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) low_idx = np.floor(self.timesteps).astype(int) high_idx = np.ceil(self.timesteps).astype(int) frac = np.mod(self.timesteps, 1.0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] - self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(self.timesteps) + self.init_noise_sigma = self.sigmas[0] self.derivatives = [] - self.set_format(tensor_format=self.tensor_format) + def scale_model_input( + self, sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor], + **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample def step( self, model_output: Union[torch.FloatTensor, np.ndarray], - timestep: int, + timestep: Union[float, torch.FloatTensor], + step_index: Union[int, torch.IntTensor], sample: Union[torch.FloatTensor, np.ndarray], + # unet: UNet2DConditionModel = None, + # encoder_hidden_states: torch.Tensor, s_churn: float = 0., s_tmin: float = 0., s_tmax: float = float('inf'), @@ -137,19 +164,17 @@ def step( returning a tuple, the first element is the sample tensor. """ - sigma = self.sigmas[timestep] - + sigma = self.sigmas[step_index] + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise pred_original_sample = sample - sigma * model_output - sigma_from = sigma - sigma_to = self.sigmas[timestep + 1] + sigma_from = self.sigmas[step_index] + sigma_to = self.sigmas[step_index + 1] sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 - # 2. 
Convert to an ODE derivative
 derivative = (sample - pred_original_sample) / sigma
 self.derivatives.append(derivative)
-
 # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
 sigma_mid = ((sigma ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
@@ -162,6 +187,6 @@ def step(
 sample = sample + torch.randn_like(sample) * sigma_up
 prev_sample = sample

 if not return_dict:
 return (prev_sample,)
@@ -174,11 +199,14 @@ def add_noise(
 noise: Union[torch.FloatTensor, np.ndarray],
 timesteps: Union[torch.IntTensor, np.ndarray],
 ) -> Union[torch.FloatTensor, np.ndarray]:
- if self.tensor_format == "pt":
- timesteps = timesteps.to(self.sigmas.device)
- sigmas = self.match_shape(self.sigmas[timesteps], noise)
- noisy_samples = original_samples + noise * sigmas
-
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
+ self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+ self.timesteps = self.timesteps.to(original_samples.device)
+ sigma = self.sigmas[timesteps].flatten()
+ while len(sigma.shape) < len(original_samples.shape):
+ sigma = sigma.unsqueeze(-1)
+
+ noisy_samples = original_samples + noise * sigma
 return noisy_samples

 def __len__(self):
diff --git a/src/diffusers/schedulers/scheduling_dpm2_discrete.py b/src/diffusers/schedulers/scheduling_dpm2_discrete.py
index 7083f48aaeac..b10ae0bdb28d 100644
--- a/src/diffusers/schedulers/scheduling_dpm2_discrete.py
+++ b/src/diffusers/schedulers/scheduling_dpm2_discrete.py
@@ -58,32 +58,36 @@ def __init__(
 beta_end: float = 0.012,
 beta_schedule: str = "linear",
 trained_betas: Optional[np.ndarray] = None,
- tensor_format: str = "pt",
 ):
 if trained_betas is not None:
- self.betas = np.asarray(trained_betas)
- if beta_schedule == "linear":
- self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
+ self.betas = torch.from_numpy(trained_betas)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
 elif beta_schedule == "scaled_linear":
 # this schedule is very specific to the latent diffusion model.
- self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + self.init_noise_sigma = None # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) self.derivatives = [] + self.is_scale_input_called = False - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) - - def set_timesteps(self, num_inference_steps: int): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -92,23 +96,43 @@ def set_timesteps(self, num_inference_steps: int): the number of diffusion steps used when generating samples with a pre-trained model. """ self.num_inference_steps = num_inference_steps - self.timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) + self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float) low_idx = np.floor(self.timesteps).astype(int) high_idx = np.ceil(self.timesteps).astype(int) frac = np.mod(self.timesteps, 1.0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] - self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(self.timesteps) + self.init_noise_sigma = self.sigmas[0] self.derivatives = [] - self.set_format(tensor_format=self.tensor_format) + def scale_model_input( + self, sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor], + **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample def step( self, model_output: Union[torch.FloatTensor, np.ndarray], - timestep: int, + timestep: Union[float, torch.FloatTensor], + step_index: Union[int, torch.IntTensor], sample: Union[torch.FloatTensor, np.ndarray], s_churn: float = 0., s_tmin: float = 0., @@ -137,7 +161,7 @@ def step( returning a tuple, the first element is the sample tensor. """ - sigma = self.sigmas[timestep] + sigma = self.sigmas[step_index] gamma = min(s_churn / (len(self.sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0. 
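# Note on the churn parameters above: gamma implements the stochastic "churn" of
# Karras et al. (2022). When s_churn > 0 and sigma lies inside [s_tmin, s_tmax],
# the noise level is briefly inflated to
#     sigma_hat = sigma * (gamma + 1)
# (gamma is capped at 2**0.5 - 1, so sigma_hat never exceeds sigma * 2**0.5), and
# the eps drawn just below is what tops the sample up to that level, in lines
# unchanged by this hunk. With the default s_churn=0., gamma is always 0, so
# sigma_hat == sigma and the sampler runs deterministically.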
eps = torch.randn_like(sample) * s_noise sigma_hat = sigma * (gamma + 1) @@ -150,9 +174,9 @@ def step( derivative = (sample - pred_original_sample) / sigma_hat self.derivatives.append(derivative) - sigma_mid = ((sigma_hat ** (1 / 3) + self.sigmas[timestep + 1] ** (1 / 3)) / 2) ** 3 + sigma_mid = ((sigma_hat ** (1 / 3) + self.sigmas[step_index + 1] ** (1 / 3)) / 2) ** 3 dt_1 = sigma_mid - sigma_hat - dt_2 = self.sigmas[timestep + 1] - sigma_hat + dt_2 = self.sigmas[step_index + 1] - sigma_hat sample_2 = sample + derivative * dt_1 pred_original_sample_2 = sample_2 - sigma_mid * model_output derivative_2 = (sample_2 - pred_original_sample_2) / sigma_mid @@ -171,11 +195,14 @@ def add_noise( noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: - if self.tensor_format == "pt": - timesteps = timesteps.to(self.sigmas.device) - sigmas = self.match_shape(self.sigmas[timesteps], noise) - noisy_samples = original_samples + noise * sigmas - + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + self.timesteps = self.timesteps.to(original_samples.device) + sigma = self.sigmas[timesteps].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index df60266801f8..33cb192c554e 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -58,30 +58,51 @@ def __init__( beta_end: float = 0.012, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - tensor_format: str = "pt", ): if trained_betas is not None: - self.betas = np.asarray(trained_betas) - if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. 
- self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + self.init_noise_sigma = None # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) self.derivatives = [] + self.is_scale_input_called = False + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor] + ) -> torch.FloatTensor: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + + Returns: + `torch.FloatTensor`: scaled input sample + """ + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample def set_timesteps(self, num_inference_steps: int): """ @@ -99,16 +120,17 @@ def set_timesteps(self, num_inference_steps: int): frac = np.mod(self.timesteps, 1.0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] - self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(self.timesteps) + self.init_noise_sigma = self.sigmas[0] self.derivatives = [] - self.set_format(tensor_format=self.tensor_format) - def step( self, model_output: Union[torch.FloatTensor, np.ndarray], - timestep: int, + timestep: Union[float, torch.FloatTensor], + step_index: Union[int, torch.IntTensor], sample: Union[torch.FloatTensor, np.ndarray], s_churn: float = 0., s_tmin: float = 0., @@ -137,7 +159,7 @@ def step( returning a tuple, the first element is the sample tensor. """ - sigma = self.sigmas[timestep] + sigma = self.sigmas[step_index] gamma = min(s_churn / (len(self.sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0. 
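# For the Euler scheduler, the update further below is a single explicit Euler
# step of the probability-flow ODE in sigma:
#     derivative = (sample - pred_original_sample) / sigma_hat    # slope dx/dsigma
#     prev_sample = sample + derivative * (sigmas[step_index + 1] - sigma_hat)
# DPM2 and Heun refine exactly this step with a midpoint and a trapezoidal
# correction, respectively.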
eps = torch.randn_like(sample) * s_noise sigma_hat = sigma * (gamma + 1) @@ -150,7 +172,7 @@ def step( derivative = (sample - pred_original_sample) / sigma_hat self.derivatives.append(derivative) - dt = self.sigmas[timestep + 1] - sigma_hat + dt = self.sigmas[step_index + 1] - sigma_hat prev_sample = sample + derivative * dt @@ -165,11 +187,14 @@ def add_noise( noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: - if self.tensor_format == "pt": - timesteps = timesteps.to(self.sigmas.device) - sigmas = self.match_shape(self.sigmas[timesteps], noise) - noisy_samples = original_samples + noise * sigmas - + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + self.timesteps = self.timesteps.to(original_samples.device) + sigma = self.sigmas[timesteps].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index c92ef5276db4..25bdbfd79edf 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -58,30 +58,50 @@ def __init__( beta_end: float = 0.012, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, - tensor_format: str = "pt", ): if trained_betas is not None: - self.betas = np.asarray(trained_betas) - if beta_schedule == "linear": - self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32) + self.betas = torch.from_numpy(trained_betas) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2 + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.init_noise_sigma = None + timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) self.derivatives = [] - self.tensor_format = tensor_format - self.set_format(tensor_format=tensor_format) + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor], step_index: Union[int, torch.IntTensor] + ) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. 
+ + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + return sample def set_timesteps(self, num_inference_steps: int): """ @@ -99,16 +119,17 @@ def set_timesteps(self, num_inference_steps: int): frac = np.mod(self.timesteps, 1.0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] - self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(self.timesteps) + self.init_noise_sigma = self.sigmas[0] self.derivatives = [] - self.set_format(tensor_format=self.tensor_format) - def step( self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, + step_index: int, sample: Union[torch.FloatTensor, np.ndarray], s_churn: float = 0., s_tmin: float = 0., @@ -137,7 +158,7 @@ def step( returning a tuple, the first element is the sample tensor. """ - sigma = self.sigmas[timestep] + sigma = self.sigmas[step_index] gamma = min(s_churn / (len(self.sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0. eps = torch.randn_like(sample) * s_noise sigma_hat = sigma * (gamma + 1) @@ -150,15 +171,15 @@ def step( derivative = (sample - pred_original_sample) / sigma_hat self.derivatives.append(derivative) - dt = self.sigmas[timestep + 1] - sigma_hat - if self.sigmas[timestep + 1] == 0: + dt = self.sigmas[step_index + 1] - sigma_hat + if self.sigmas[step_index + 1] == 0: # Euler method sample = sample + derivative * dt else: # Heun's method sample_2 = sample + derivative * dt - pred_original_sample_2 = sample_2 - self.sigmas[timestep + 1] * model_output - derivative_2 = (sample_2 - pred_original_sample_2) / self.sigmas[timestep + 1] + pred_original_sample_2 = sample_2 - self.sigmas[step_index + 1] * model_output + derivative_2 = (sample_2 - pred_original_sample_2) / self.sigmas[step_index + 1] d_prime = (derivative + derivative_2) / 2 sample = sample + d_prime * dt @@ -175,11 +196,14 @@ def add_noise( noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: - if self.tensor_format == "pt": - timesteps = timesteps.to(self.sigmas.device) - sigmas = self.match_shape(self.sigmas[timesteps], noise) - noisy_samples = original_samples + noise * sigmas - + # Make sure sigmas and timesteps have the same device and dtype as original_samples + self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + self.timesteps = self.timesteps.to(original_samples.device) + sigma = self.sigmas[timesteps].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 743f2e061c53..1e4ee7623d6e 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -95,7 +95,9 @@ def __init__( self.timesteps: np.IntTensor = None self.schedule: torch.FloatTensor = None # sigma(t_i) - def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + def 
scale_model_input(
+ self, sample: torch.FloatTensor,
+ *args, **kwargs) -> torch.FloatTensor:
 """
 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
 current timestep.
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 1b8ca7c5df8d..9b86aa1b4550 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -105,7 +105,10 @@ def __init__(
 self.is_scale_input_called = False
 def scale_model_input(
- self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ step_index: Union[int, torch.IntTensor]
 ) -> torch.FloatTensor:
 """
 Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
@@ -172,6 +175,7 @@ def step(
 self,
 model_output: torch.FloatTensor,
 timestep: Union[float, torch.FloatTensor],
+ step_index: Union[int, torch.IntTensor],
 sample: torch.FloatTensor,
 order: int = 4,
 return_dict: bool = True,
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index 99ccc6c66f20..aaa97920b63f 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -337,7 +337,9 @@ def step_plms(
 return SchedulerOutput(prev_sample=prev_sample)
- def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ def scale_model_input(
+ self, sample: torch.FloatTensor,
+ *args, **kwargs) -> torch.FloatTensor:
 """
 Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
 current timestep.

From 1bbaa207f7745f43532f443a9177586043cc687d Mon Sep 17 00:00:00 2001
From: Sean
Date: Sat, 22 Oct 2022 13:46:57 +0200
Subject: [PATCH 5/5] fix: missing arg in pndm, ddim, ddpm

---
 src/diffusers/schedulers/scheduling_ddim.py | 1 +
 src/diffusers/schedulers/scheduling_ddpm.py | 1 +
 src/diffusers/schedulers/scheduling_pndm.py | 1 +
 3 files changed, 3 insertions(+)

diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index ba3ca2c4582a..b68f94c286d0 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -202,6 +202,7 @@ def step(
 self,
 model_output: torch.FloatTensor,
 timestep: int,
+ step_index: int,
 sample: torch.FloatTensor,
 eta: float = 0.0,
 use_clipped_model_output: bool = False,
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index f7085adbe555..1897c1af1d84 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -213,6 +213,7 @@ def step(
 self,
 model_output: torch.FloatTensor,
 timestep: int,
+ step_index: int,
 sample: torch.FloatTensor,
 predict_epsilon=True,
 generator=None,
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index aaa97920b63f..9d05f4c50e8b 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -189,6 +189,7 @@ def step(
 self,
 model_output: torch.FloatTensor,
 timestep: int,
+ step_index: int,
 sample: torch.FloatTensor,
 return_dict: bool = True,
 ) -> Union[SchedulerOutput, Tuple]:
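A closing note on why step_index has to travel alongside timestep at all: after set_timesteps(), the k-diffusion schedulers hold num_inference_steps generally fractional timesteps but num_inference_steps + 1 sigmas, so the timestep value is unusable as an index into self.sigmas. A self-contained illustration (the numbers are illustrative; the logic mirrors set_timesteps() above):

    import numpy as np

    num_train_timesteps, num_inference_steps = 1000, 50
    timesteps = np.linspace(num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
    print(timesteps[:3])  # [999.0, 978.61..., 958.22...], fractional, not valid indices
    sigmas = np.zeros(num_inference_steps + 1)  # stand-in: interpolated sigmas plus the trailing 0.0

    for i, t in enumerate(timesteps):
        sigma = sigmas[i]             # this lookup is what step_index exists for
        sigma_next = sigmas[i + 1]    # always safe, thanks to the appended 0.0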