From dceb05ce20f9adf185e4eb984ccd986f06f4e21c Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Mon, 29 Aug 2022 14:20:06 +0000 Subject: [PATCH 01/53] initial commit --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f0b353d931d4..f2faf009896e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -80,7 +80,7 @@ def __call__( truncation=True, return_tensors="pt", ) - text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device, non_blocking=True))[0] # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -92,7 +92,7 @@ def __call__( uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device, non_blocking=True))[0] # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch @@ -110,7 +110,7 @@ def __call__( else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) + latents = latents.to(self.device, non_blocking=True) # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) From 45696fd0b871dbff0a0458d6df4458128dd51c01 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Mon, 29 Aug 2022 21:40:50 +0000 Subject: [PATCH 02/53] make UNet stream capturable --- src/diffusers/models/embeddings.py | 2 +- src/diffusers/models/unet_2d_condition.py | 2 +- .../pipeline_stable_diffusion.py | 31 ++++++++++++++++--- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 8d1052173e66..e3144649baf3 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -35,7 +35,7 @@ def get_timestep_embedding( exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32) exponent = exponent / (half_dim - downscale_freq_shift) - emb = torch.exp(exponent).to(device=timesteps.device) + emb = torch.exp(exponent).to(device=timesteps.device, non_blocking=True) emb = timesteps[:, None].float() * emb[None, :] # scale embeddings diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 25c4e37d8a6d..ced46569f7e5 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -131,7 +131,7 @@ def forward( if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) + timesteps = timesteps[None].to(sample.device, non_blocking=True) # broadcast to batch dimension timesteps = timesteps.broadcast_to(sample.shape[0]) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f2faf009896e..3f6fd2f3cc1a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -133,6 +133,29 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta + + # warmup + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + t = self.scheduler.timesteps[1] + i=1 + if isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[i] + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for i in range(1): + with torch.no_grad(): + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] + torch.cuda.current_stream().wait_stream(s) + + # capture + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + with torch.no_grad(): + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] + + for i, t in tqdm(enumerate(self.scheduler.timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents @@ -141,18 +164,18 @@ def __call__( latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] + g.replay() # replay the graph and updates outputs # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred_cond = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"] + latents = self.scheduler.step(noise_pred_cond, i, latents, **extra_step_kwargs)["prev_sample"] else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] + latents = self.scheduler.step(noise_pred_cond, t, latents, **extra_step_kwargs)["prev_sample"] # scale and decode the image latents with vae latents = 1 / 0.18215 * latents From d5edb1916ef6745822f8f1902a6e80521093fcd8 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Mon, 29 Aug 2022 22:05:46 +0000 Subject: [PATCH 03/53] try to fix noise_pred value --- .../stable_diffusion/pipeline_stable_diffusion.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 3f6fd2f3cc1a..0945cc189f59 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -144,7 +144,7 @@ def __call__( s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): - for i in range(1): + for i in range(11): with torch.no_grad(): noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] torch.cuda.current_stream().wait_stream(s) @@ -155,16 +155,22 @@ def __call__( with torch.no_grad(): noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] + print(noise_pred.mean()) + g.replay() + print(noise_pred.mean()) - for i, t in tqdm(enumerate(self.scheduler.timesteps)): + for i, timestep in tqdm(enumerate(self.scheduler.timesteps)): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input.copy_(torch.cat([latents] * 2) if do_classifier_free_guidance else latents) if isinstance(self.scheduler, LMSDiscreteScheduler): sigma = self.scheduler.sigmas[i] - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) + latent_model_input.copy_(latent_model_input / ((sigma**2 + 1) ** 0.5)) + + t.copy_(timestep) # predict the noise residual g.replay() # replay the graph and updates outputs + print(noise_pred.mean()) # perform guidance if do_classifier_free_guidance: From 89143b176050ce70c3e59d7d2f2291b18a9f4481 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Tue, 30 Aug 2022 09:22:37 +0000 Subject: [PATCH 04/53] remove cuda graph and keep NB --- .../pipeline_stable_diffusion.py | 43 +++---------------- 1 file changed, 7 insertions(+), 36 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 0945cc189f59..f2faf009896e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -133,55 +133,26 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta - - # warmup - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - t = self.scheduler.timesteps[1] - i=1 - if isinstance(self.scheduler, LMSDiscreteScheduler): - sigma = self.scheduler.sigmas[i] - latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - for i in range(11): - with torch.no_grad(): - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] - torch.cuda.current_stream().wait_stream(s) - - # capture - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - with torch.no_grad(): - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] - - print(noise_pred.mean()) - g.replay() - print(noise_pred.mean()) - - for i, timestep in tqdm(enumerate(self.scheduler.timesteps)): + for i, t in tqdm(enumerate(self.scheduler.timesteps)): # expand the latents if we are doing classifier free guidance - latent_model_input.copy_(torch.cat([latents] * 2) if do_classifier_free_guidance else latents) + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents if isinstance(self.scheduler, LMSDiscreteScheduler): sigma = self.scheduler.sigmas[i] - latent_model_input.copy_(latent_model_input / ((sigma**2 + 1) ** 0.5)) - - t.copy_(timestep) + latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) # predict the noise residual - g.replay() # replay the graph and updates outputs - print(noise_pred.mean()) + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred_cond = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred_cond, i, latents, **extra_step_kwargs)["prev_sample"] + latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"] else: - latents = self.scheduler.step(noise_pred_cond, t, latents, **extra_step_kwargs)["prev_sample"] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] # scale and decode the image latents with vae latents = 1 / 0.18215 * latents From 2f4a34b5da9afb2bf91aff7c5ff7fe8eb675f5d3 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Tue, 30 Aug 2022 10:23:22 +0000 Subject: [PATCH 05/53] non blocking unet with PNDMScheduler --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f2faf009896e..296c13c54cf9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -119,7 +119,8 @@ def __call__( extra_set_kwargs["offset"] = 1 self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) - + self.scheduler.timesteps = torch.tensor(self.scheduler.timesteps, device=self.device) + # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): latents = latents * self.scheduler.sigmas[0] From e6e41ae4dbf7dca4812c09f9fa1f2ba5bff4574b Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Tue, 30 Aug 2022 13:24:43 +0000 Subject: [PATCH 06/53] make timesteps np arrays for pndm scheduler because lists don't get formatted to tensors in `self.set_format` --- src/diffusers/schedulers/scheduling_pndm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index cd1d2bb2a701..7ad810ed0f4a 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -109,16 +109,16 @@ def set_timesteps(self, num_inference_steps, offset=0): # for some models like stable diffusion the prk steps can/should be skipped to # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 - self.prk_timesteps = [] - self.plms_timesteps = list(reversed(self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:])) + self.prk_timesteps = np.array([]) + self.plms_timesteps = (self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:])[::-1].copy() else: prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order ) - self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1])) - self.plms_timesteps = list(reversed(self._timesteps[:-3])) + self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() + self.plms_timesteps = self._timesteps[:-3][::-1].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy - self.timesteps = self.prk_timesteps + self.plms_timesteps + self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]) self.ets = [] self.counter = 0 From 95051ae34f94b7dd04fcad2c6805c66531773e96 Mon Sep 17 00:00:00 2001 From: NouamaneTazi Date: Tue, 30 Aug 2022 13:24:59 +0000 Subject: [PATCH 07/53] make max async in pndm --- src/diffusers/schedulers/scheduling_pndm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 7ad810ed0f4a..9312dae9e076 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -145,8 +145,8 @@ def step_prk( Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the solution to the differential equation. """ - diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 - prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) + diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 #TODO: check if we still have condition in cuda graph + prev_timestep = torch.max(torch.tensor(timestep - diff_to_prev), self.prk_timesteps[-1]) timestep = self.prk_timesteps[self.counter // 4 * 4] if self.counter % 4 == 0: @@ -187,7 +187,7 @@ def step_plms( "for more information." ) - prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0) + prev_timestep = torch.max(timestep - self.config.num_train_timesteps // self.num_inference_steps, torch.tensor(0)) if self.counter != 1: self.ets.append(model_output) From 0b85b6f5b904f531edea00e031ac5a0e77130892 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 6 Sep 2022 10:13:36 +0000 Subject: [PATCH 08/53] use channel last format in unet --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 36ef4a195ff3..b701db82b33f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -30,7 +30,7 @@ def __init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, - unet=unet, + unet=unet.to(memory_format=torch.channels_last), scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, @@ -137,6 +137,7 @@ def __call__( for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(memory_format=torch.channels_last) if isinstance(self.scheduler, LMSDiscreteScheduler): sigma = self.scheduler.sigmas[i] # the model input needs to be scaled to match the continuous ODE formulation in K-LMS From 101b8b09c51ae273018e0997360ca236d203ccaf Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 6 Sep 2022 10:29:41 +0000 Subject: [PATCH 09/53] avoid moving timesteps device in each unet call --- src/diffusers/models/unet_2d_condition.py | 2 +- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index e9aeb5eb0649..80d87ad7a738 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -149,7 +149,7 @@ def forward( if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) + timesteps = timesteps[None] # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index b701db82b33f..3f05f6390bfe 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -120,6 +120,7 @@ def __call__( extra_set_kwargs["offset"] = 1 self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + self.scheduler.timesteps = self.scheduler.timesteps.to(self.device) # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): From 493a64a0a609738abd7b3f5faa92f64d595506b1 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 6 Sep 2022 10:38:07 +0000 Subject: [PATCH 10/53] avoid memcpy op in `get_timestep_embedding` --- src/diffusers/models/embeddings.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 8d1052173e66..99bfa96f0d92 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -32,10 +32,10 @@ def get_timestep_embedding( assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32) + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) exponent = exponent / (half_dim - downscale_freq_shift) - emb = torch.exp(exponent).to(device=timesteps.device) + emb = torch.exp(exponent) emb = timesteps[:, None].float() * emb[None, :] # scale embeddings From 75b4f0c07133582e726c4ad2ae7a4188c9d82b51 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 6 Sep 2022 19:09:01 +0000 Subject: [PATCH 11/53] add `channels_last` kwarg to `DiffusionPipeline.from_pretrained` --- src/diffusers/pipeline_utils.py | 6 ++++++ .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 3 +-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index fc2bc7bcf414..c4ed955afb3c 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -165,6 +165,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) + channels_last = kwargs.pop("channels_last", False) # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained @@ -267,6 +268,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # else load from the root directory loaded_sub_model = load_method(cached_folder, **loading_kwargs) + if channels_last and issubclass(class_obj, torch.nn.Module) and name == "unet": + #TODO(nouamane): it seems we don't need to specify inputs' memory format for + # se we only apply this to the model + loaded_sub_model = loaded_sub_model.to(memory_format=torch.channels_last) + init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) # 4. Instantiate the pipeline diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 3f05f6390bfe..98dd34f2ba02 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -30,7 +30,7 @@ def __init__( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, - unet=unet.to(memory_format=torch.channels_last), + unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, @@ -138,7 +138,6 @@ def __call__( for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = latent_model_input.to(memory_format=torch.channels_last) if isinstance(self.scheduler, LMSDiscreteScheduler): sigma = self.scheduler.sigmas[i] # the model input needs to be scaled to match the continuous ODE formulation in K-LMS From 1639f69d5ae5b158a0477e6cd4f5c9cbde7fb19b Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 6 Sep 2022 19:12:38 +0000 Subject: [PATCH 12/53] update TODO --- src/diffusers/pipeline_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index c4ed955afb3c..58b94a9e34f4 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -269,8 +269,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P loaded_sub_model = load_method(cached_folder, **loading_kwargs) if channels_last and issubclass(class_obj, torch.nn.Module) and name == "unet": - #TODO(nouamane): it seems we don't need to specify inputs' memory format for + #TODO(nouamane): it seems we don't need to specify memory format for inputs # se we only apply this to the model + #TODO(nouamane): check which models benefit from channels last loaded_sub_model = loaded_sub_model.to(memory_format=torch.channels_last) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) From f2176e94ff9772d197979a06997378a6bb11f5ef Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 6 Sep 2022 19:37:12 +0000 Subject: [PATCH 13/53] replace `channels_last` kwarg with `memory_format` for more generality --- src/diffusers/pipeline_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 58b94a9e34f4..a01e8d5cce7e 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -165,7 +165,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) - channels_last = kwargs.pop("channels_last", False) + memory_format = kwargs.pop("memory_format", torch.preserve_format) # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained @@ -268,11 +268,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # else load from the root directory loaded_sub_model = load_method(cached_folder, **loading_kwargs) - if channels_last and issubclass(class_obj, torch.nn.Module) and name == "unet": + if issubclass(class_obj, torch.nn.Module) and name == "unet": #TODO(nouamane): it seems we don't need to specify memory format for inputs # se we only apply this to the model #TODO(nouamane): check which models benefit from channels last - loaded_sub_model = loaded_sub_model.to(memory_format=torch.channels_last) + loaded_sub_model = loaded_sub_model.to(memory_format=memory_format) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) From 08db0c3f1bcca37fb3a5fa3939ff68f14ea3a268 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 8 Sep 2022 07:23:17 +0000 Subject: [PATCH 14/53] revert the channels_last changes to leave it for another PR --- src/diffusers/pipeline_utils.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index a01e8d5cce7e..fc2bc7bcf414 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -165,7 +165,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) - memory_format = kwargs.pop("memory_format", torch.preserve_format) # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained @@ -268,12 +267,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # else load from the root directory loaded_sub_model = load_method(cached_folder, **loading_kwargs) - if issubclass(class_obj, torch.nn.Module) and name == "unet": - #TODO(nouamane): it seems we don't need to specify memory format for inputs - # se we only apply this to the model - #TODO(nouamane): check which models benefit from channels last - loaded_sub_model = loaded_sub_model.to(memory_format=memory_format) - init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) # 4. Instantiate the pipeline From 96477527eb52572db2359048d67cc6ed03a5a0a9 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 13 Sep 2022 18:00:24 +0000 Subject: [PATCH 15/53] remove non_blocking when moving input ids to device --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 77661db1ca7b..451171aeeea6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -178,7 +178,7 @@ def __call__( truncation=True, return_tensors="pt", ) - text_embeddings = self.text_encoder(text_input.input_ids.to(self.device, non_blocking=True))[0] + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` From cab7b285bb6c709c7e8c5e09718950c8575fb7d3 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 13 Sep 2022 18:06:11 +0000 Subject: [PATCH 16/53] remove blocking from all .to() operations at beginning of pipeline --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 451171aeeea6..37a2737c959e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -190,7 +190,7 @@ def __call__( uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device, non_blocking=True))[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch @@ -208,7 +208,7 @@ def __call__( else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device, non_blocking=True) + latents = latents.to(self.device) # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) From acb8397e228497dd45a61f74d242e7a09836f58b Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 13 Sep 2022 18:11:48 +0000 Subject: [PATCH 17/53] fix merging --- src/diffusers/schedulers/scheduling_pndm.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index be9ef4687e22..171b50989818 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -147,15 +147,19 @@ def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.Floa # produce better results. When using PNDM with `self.config.skip_prk_steps` the implementation # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51 self.prk_timesteps = np.array([]) - self.plms_timesteps = (self._timesteps[:-1] + self._timesteps[-2:-1] + self._timesteps[-1:])[::-1].copy() + self.plms_timesteps = np.concatenate([self._timesteps[:-1], self._timesteps[-2:-1], self._timesteps[-1:]])[ + ::-1 + ].copy() else: prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order ) self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy() - self.plms_timesteps = self._timesteps[:-3][::-1].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy + self.plms_timesteps = self._timesteps[:-3][ + ::-1 + ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy - self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]) + self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64) self.ets = [] self.counter = 0 @@ -212,8 +216,13 @@ def step_prk( prev_sample (`SchedulerOutput` or `Tuple`): updated sample in the diffusion chain. """ - diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 #TODO: check if we still have condition in cuda graph - prev_timestep = torch.max(torch.tensor(timestep - diff_to_prev), self.prk_timesteps[-1]) + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 + prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) timestep = self.prk_timesteps[self.counter // 4 * 4] if self.counter % 4 == 0: @@ -274,7 +283,7 @@ def step_plms( "for more information." ) - prev_timestep = torch.max(timestep - self.config.num_train_timesteps // self.num_inference_steps, torch.tensor(0)) + prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0) if self.counter != 1: self.ets.append(model_output) From 39994ccf2637e4579b138550375bded04dc6babf Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 13 Sep 2022 20:58:02 +0000 Subject: [PATCH 18/53] fix merging --- src/diffusers/models/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 3918a616f370..99bfa96f0d92 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -35,7 +35,7 @@ def get_timestep_embedding( exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) exponent = exponent / (half_dim - downscale_freq_shift) - emb = torch.exp(exponent).to(device=timesteps.device, non_blocking=True) + emb = torch.exp(exponent) emb = timesteps[:, None].float() * emb[None, :] # scale embeddings From d30f9682226b07f3e6e6234b442622a8921bd108 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 13 Sep 2022 22:27:19 +0000 Subject: [PATCH 19/53] model can run in other precisions without autocast --- src/diffusers/models/resnet.py | 4 ++-- src/diffusers/models/unet_2d_condition.py | 4 ++-- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 51efea9ee423..f98d11e4171e 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -333,7 +333,7 @@ def forward(self, x, temb): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: @@ -351,7 +351,7 @@ def forward(self, x, temb): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 5b02fd926081..10e2ece54634 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -177,7 +177,7 @@ def forward( timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) - emb = self.time_embedding(t_emb) + emb = self.time_embedding(t_emb.to(self.dtype)) # 2. pre-process sample = self.conv_in(sample) @@ -215,7 +215,7 @@ def forward( # 6. post-process # make sure hidden states is in float32 # when running in half-precision - sample = self.conv_norm_out(sample.float()).type(sample.dtype) + sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype) sample = self.conv_act(sample) sample = self.conv_out(sample) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 37a2737c959e..ba3e55dfa6eb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -204,6 +204,7 @@ def __call__( latents_shape, generator=generator, device=self.device, + dtype=text_embeddings.dtype, ) else: if latents.shape != latents_shape: @@ -263,7 +264,7 @@ def __call__( # run safety checker safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values.to(text_embeddings.dtype)) if output_type == "pil": image = self.numpy_to_pil(image) From 0c70c0e189cd2c4d8768274c9fcf5b940ee310fb Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Wed, 14 Sep 2022 12:36:38 +0000 Subject: [PATCH 20/53] attn refactoring --- src/diffusers/models/attention.py | 213 +++++++++++++++++++++--------- 1 file changed, 151 insertions(+), 62 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index a69d9014bdf6..f78777d4ac72 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1,4 +1,5 @@ import math +from typing import Optional import torch import torch.nn.functional as F @@ -10,16 +11,24 @@ class AttentionBlock(nn.Module): An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted to the N-d case. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention + Uses three q, k, v linear layers to compute attention. + + Parameters: + channels (:obj:`int`): The number of channels in the input and output. + num_head_channels (:obj:`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. """ def __init__( self, - channels, - num_head_channels=None, - num_groups=32, - rescale_output_factor=1.0, - eps=1e-5, + channels: int, + num_head_channels: Optional[int] = None, + num_groups: int = 32, + rescale_output_factor: float = 1.0, + eps: float = 1e-5, ): super().__init__() self.channels = channels @@ -86,10 +95,26 @@ def forward(self, hidden_states): class SpatialTransformer(nn.Module): """ Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image + standard transformer action. Finally, reshape to image. + + Parameters: + in_channels (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The number of context dimensions to use. """ - def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None): + def __init__( + self, + in_channels: int, + n_heads: int, + d_head: int, + depth: int = 1, + dropout: float = 0.0, + context_dim: Optional[int] = None, + ): super().__init__() self.n_heads = n_heads self.d_head = d_head @@ -112,22 +137,44 @@ def _set_attention_slice(self, slice_size): for block in self.transformer_blocks: block._set_attention_slice(slice_size) - def forward(self, x, context=None): + def forward(self, hidden_states, context=None): # note: if no context is given, cross-attention defaults to self-attention - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - x = self.proj_in(x) - x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_in(hidden_states) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel) for block in self.transformer_blocks: - x = block(x, context=context) - x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) - x = self.proj_out(x) - return x + x_in + hidden_states = block(hidden_states, context=context) + hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2) + hidden_states = self.proj_out(hidden_states) + return hidden_states + residual class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True): + r""" + A basic Transformer block. + + Parameters: + dim (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. + gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. + checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. + """ + + def __init__( + self, + dim: int, + n_heads: int, + d_head: int, + dropout=0.0, + context_dim: Optional[int] = None, + gated_ff: bool = True, + checkpoint: bool = True, + ): super().__init__() self.attn1 = CrossAttention( query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout @@ -145,15 +192,30 @@ def _set_attention_slice(self, slice_size): self.attn1._slice_size = slice_size self.attn2._slice_size = slice_size - def forward(self, x, context=None): - x = self.attn1(self.norm1(x)) + x - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x + def forward(self, hidden_states, context=None): + hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states + hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states + hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + return hidden_states class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + r""" + A cross attention layer. + + Parameters: + query_dim (:obj:`int`): The number of channels in the query. + context_dim (:obj:`int`, *optional*): + The number of channels in the context. If not given, defaults to `query_dim`. + heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + def __init__( + self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 + ): super().__init__() inner_dim = dim_head * heads context_dim = context_dim if context_dim is not None else query_dim @@ -174,52 +236,58 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor + tensor2 = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor3 def reshape_batch_dim_to_heads(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor + tensor2 = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor3 - def forward(self, x, context=None, mask=None): - batch_size, sequence_length, dim = x.shape + def forward(self, hidden_states, context=None, mask=None): + batch_size, sequence_length, dim = hidden_states.shape - q = self.to_q(x) - context = context if context is not None else x - k = self.to_k(context) - v = self.to_v(context) + query = self.to_q(hidden_states) + context = context if context is not None else hidden_states + key = self.to_k(context) + value = self.to_v(context) - q = self.reshape_heads_to_batch_dim(q) - k = self.reshape_heads_to_batch_dim(k) - v = self.reshape_heads_to_batch_dim(v) + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) # TODO(PVP) - mask is currently never used. Remember to re-implement when used # attention, what we cannot get enough of - hidden_states = self._attention(q, k, v, sequence_length, dim) + hidden_states = self._attention(query, key, value, sequence_length, dim) return self.to_out(hidden_states) def _attention(self, query, key, value, sequence_length, dim): batch_size_attention = query.shape[0] - hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype - ) - slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] - for i in range(hidden_states.shape[0] // slice_size): - start_idx = i * slice_size - end_idx = (i + 1) * slice_size - attn_slice = ( - torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale - ) - attn_slice = attn_slice.softmax(dim=-1) - attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice + # hidden_states = torch.zeros( + # (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + # ) + slice_size = self._slice_size if self._slice_size is not None else batch_size_attention + # for i in range(hidden_states.shape[0] // slice_size): + # start_idx = i * slice_size + # end_idx = (i + 1) * slice_size + # qslice = query[start_idx:end_idx] + qslice = query + # kslice = key[start_idx:end_idx].transpose(1, 2) + kslice = key.transpose(1, 2) + attn_slice = torch.matmul(qslice, kslice) * self.scale + attn_slice = attn_slice.softmax(dim=-1) + # vslice = value[start_idx:end_idx] + vslice = value + hidden_states = torch.matmul(attn_slice, vslice) + + + # hidden_states = torch.cat(attn_slices, dim=0) + # reshape hidden_states hidden_states = self.reshape_batch_dim_to_heads(hidden_states) @@ -227,7 +295,20 @@ def _attention(self, query, key, value, sequence_length, dim): class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + r""" + A feed-forward layer. + + Parameters: + dim (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + def __init__( + self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 + ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim @@ -235,16 +316,24 @@ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) - def forward(self, x): - return self.net(x) + def forward(self, hidden_states): + return self.net(hidden_states) # feedforward class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * F.gelu(gate) From e422eb3738a2592ab8e4f11c91caa50d6bee9d2e Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Wed, 14 Sep 2022 12:37:19 +0000 Subject: [PATCH 21/53] Revert "attn refactoring" This reverts commit 0c70c0e189cd2c4d8768274c9fcf5b940ee310fb. --- src/diffusers/models/attention.py | 213 +++++++++--------------------- 1 file changed, 62 insertions(+), 151 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index f78777d4ac72..a69d9014bdf6 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -1,5 +1,4 @@ import math -from typing import Optional import torch import torch.nn.functional as F @@ -11,24 +10,16 @@ class AttentionBlock(nn.Module): An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted to the N-d case. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention. - - Parameters: - channels (:obj:`int`): The number of channels in the input and output. - num_head_channels (:obj:`int`, *optional*): - The number of channels in each head. If None, then `num_heads` = 1. - num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. + Uses three q, k, v linear layers to compute attention """ def __init__( self, - channels: int, - num_head_channels: Optional[int] = None, - num_groups: int = 32, - rescale_output_factor: float = 1.0, - eps: float = 1e-5, + channels, + num_head_channels=None, + num_groups=32, + rescale_output_factor=1.0, + eps=1e-5, ): super().__init__() self.channels = channels @@ -95,26 +86,10 @@ def forward(self, hidden_states): class SpatialTransformer(nn.Module): """ Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image. - - Parameters: - in_channels (:obj:`int`): The number of channels in the input and output. - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. - depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The number of context dimensions to use. + standard transformer action. Finally, reshape to image """ - def __init__( - self, - in_channels: int, - n_heads: int, - d_head: int, - depth: int = 1, - dropout: float = 0.0, - context_dim: Optional[int] = None, - ): + def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None): super().__init__() self.n_heads = n_heads self.d_head = d_head @@ -137,44 +112,22 @@ def _set_attention_slice(self, slice_size): for block in self.transformer_blocks: block._set_attention_slice(slice_size) - def forward(self, hidden_states, context=None): + def forward(self, x, context=None): # note: if no context is given, cross-attention defaults to self-attention - batch, channel, height, weight = hidden_states.shape - residual = hidden_states - hidden_states = self.norm(hidden_states) - hidden_states = self.proj_in(hidden_states) - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel) + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=context) - hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2) - hidden_states = self.proj_out(hidden_states) - return hidden_states + residual + x = block(x, context=context) + x = x.reshape(b, h, w, c).permute(0, 3, 1, 2) + x = self.proj_out(x) + return x + x_in class BasicTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (:obj:`int`): The number of channels in the input and output. - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. - gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. - checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. - """ - - def __init__( - self, - dim: int, - n_heads: int, - d_head: int, - dropout=0.0, - context_dim: Optional[int] = None, - gated_ff: bool = True, - checkpoint: bool = True, - ): + def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff=True, checkpoint=True): super().__init__() self.attn1 = CrossAttention( query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout @@ -192,30 +145,15 @@ def _set_attention_slice(self, slice_size): self.attn1._slice_size = slice_size self.attn2._slice_size = slice_size - def forward(self, hidden_states, context=None): - hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states - hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states - hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states - hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states - return hidden_states + def forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x class CrossAttention(nn.Module): - r""" - A cross attention layer. - - Parameters: - query_dim (:obj:`int`): The number of channels in the query. - context_dim (:obj:`int`, *optional*): - The number of channels in the context. If not given, defaults to `query_dim`. - heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - """ - - def __init__( - self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 - ): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() inner_dim = dim_head * heads context_dim = context_dim if context_dim is not None else query_dim @@ -236,58 +174,52 @@ def __init__( def reshape_heads_to_batch_dim(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads - tensor2 = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor3 + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor def reshape_batch_dim_to_heads(self, tensor): batch_size, seq_len, dim = tensor.shape head_size = self.heads - tensor2 = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor3 = tensor2.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor3 + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor - def forward(self, hidden_states, context=None, mask=None): - batch_size, sequence_length, dim = hidden_states.shape + def forward(self, x, context=None, mask=None): + batch_size, sequence_length, dim = x.shape - query = self.to_q(hidden_states) - context = context if context is not None else hidden_states - key = self.to_k(context) - value = self.to_v(context) + q = self.to_q(x) + context = context if context is not None else x + k = self.to_k(context) + v = self.to_v(context) - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) + q = self.reshape_heads_to_batch_dim(q) + k = self.reshape_heads_to_batch_dim(k) + v = self.reshape_heads_to_batch_dim(v) # TODO(PVP) - mask is currently never used. Remember to re-implement when used # attention, what we cannot get enough of - hidden_states = self._attention(query, key, value, sequence_length, dim) + hidden_states = self._attention(q, k, v, sequence_length, dim) return self.to_out(hidden_states) def _attention(self, query, key, value, sequence_length, dim): batch_size_attention = query.shape[0] - # hidden_states = torch.zeros( - # (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype - # ) - slice_size = self._slice_size if self._slice_size is not None else batch_size_attention - # for i in range(hidden_states.shape[0] // slice_size): - # start_idx = i * slice_size - # end_idx = (i + 1) * slice_size - # qslice = query[start_idx:end_idx] - qslice = query - # kslice = key[start_idx:end_idx].transpose(1, 2) - kslice = key.transpose(1, 2) - attn_slice = torch.matmul(qslice, kslice) * self.scale - attn_slice = attn_slice.softmax(dim=-1) - # vslice = value[start_idx:end_idx] - vslice = value - hidden_states = torch.matmul(attn_slice, vslice) - - - # hidden_states = torch.cat(attn_slices, dim=0) - + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + attn_slice = ( + torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale + ) + attn_slice = attn_slice.softmax(dim=-1) + attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice # reshape hidden_states hidden_states = self.reshape_batch_dim_to_heads(hidden_states) @@ -295,20 +227,7 @@ def _attention(self, query, key, value, sequence_length, dim): class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - """ - - def __init__( - self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 - ): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim @@ -316,24 +235,16 @@ def __init__( self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) - def forward(self, hidden_states): - return self.net(hidden_states) + def forward(self, x): + return self.net(x) # feedforward class GEGLU(nn.Module): - r""" - A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. - - Parameters: - dim_in (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`): The number of channels in the output. - """ - - def __init__(self, dim_in: int, dim_out: int): + def __init__(self, dim_in, dim_out): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) - def forward(self, hidden_states): - hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) - return hidden_states * F.gelu(gate) + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) From cec592890c32da3d1b78d38b49e4307aedf459b9 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Wed, 21 Sep 2022 13:38:41 +0000 Subject: [PATCH 22/53] remove restriction to run conv_norm in fp32 --- src/diffusers/models/resnet.py | 6 ++++-- src/diffusers/models/unet_2d_condition.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 785a4b91353e..36e9dd611e63 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -332,7 +332,8 @@ def forward(self, x, temb): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm1(hidden_states).type(hidden_states.dtype) + # hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: @@ -350,7 +351,8 @@ def forward(self, x, temb): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm2(hidden_states).type(hidden_states.dtype) + # hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 42b54657d290..e584520e94a8 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -261,7 +261,8 @@ def forward( # 6. post-process # make sure hidden states is in float32 # when running in half-precision - sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype) + # sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype) + sample = self.conv_norm_out(sample).type(sample.dtype) sample = self.conv_act(sample) sample = self.conv_out(sample) From c0dd0e90e8de519bff281d5396d6fd50dae49a0d Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Wed, 21 Sep 2022 14:11:34 +0000 Subject: [PATCH 23/53] use `baddbmm` instead of `matmul`for better in attention for better perf --- src/diffusers/models/attention.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index e4cedbff8c9a..02454d603632 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -73,7 +73,13 @@ def forward(self, hidden_states): # get scores scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) - attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) + attention_scores = torch.baddbmm( + torch.empty(query_states.shape[0], query_states.shape[1], key_states.shape[1], dtype=query_states.dtype, device=query_states.device), + query_states, + key_states.transpose(-1, -2), + beta=0, + alpha=scale, + ) attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) # compute attention output @@ -272,7 +278,14 @@ def forward(self, hidden_states, context=None, mask=None): return self.to_out(hidden_states) def _attention(self, query, key, value): - attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale + # attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) attention_probs = attention_scores.softmax(dim=-1) # compute attention output hidden_states = torch.matmul(attention_probs, value) @@ -289,7 +302,7 @@ def _sliced_attention(self, query, key, value, sequence_length, dim): for i in range(hidden_states.shape[0] // slice_size): start_idx = i * slice_size end_idx = (i + 1) * slice_size - attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale + attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale # TODO: use baddbmm for better performance attn_slice = attn_slice.softmax(dim=-1) attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) From 006ccb8a8c6bc7eb7e512392e692a29d9b1553cd Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Wed, 21 Sep 2022 15:22:59 +0000 Subject: [PATCH 24/53] removing all reshapes to test perf --- src/diffusers/models/attention.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 02454d603632..f3623c6e7ed5 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -91,7 +91,7 @@ def forward(self, hidden_states): # compute next hidden_states hidden_states = self.proj_attn(hidden_states) - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + hidden_states = hidden_states.reshape(batch, channel, height, width) # res connect and rescale hidden_states = (hidden_states + residual) / self.rescale_output_factor @@ -150,10 +150,10 @@ def forward(self, hidden_states, context=None): residual = hidden_states hidden_states = self.norm(hidden_states) hidden_states = self.proj_in(hidden_states) - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel) + hidden_states = hidden_states.reshape(batch, height * weight, channel) for block in self.transformer_blocks: hidden_states = block(hidden_states, context=context) - hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2) + hidden_states = hidden_states.reshape(batch, channel, height, weight) hidden_states = self.proj_out(hidden_states) return hidden_states + residual @@ -262,9 +262,9 @@ def forward(self, hidden_states, context=None, mask=None): key = self.to_k(context) value = self.to_v(context) - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) + # query = self.reshape_heads_to_batch_dim(query) + # key = self.reshape_heads_to_batch_dim(key) + # value = self.reshape_heads_to_batch_dim(value) # TODO(PVP) - mask is currently never used. Remember to re-implement when used @@ -290,7 +290,7 @@ def _attention(self, query, key, value): # compute attention output hidden_states = torch.matmul(attention_probs, value) # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + # hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states def _sliced_attention(self, query, key, value, sequence_length, dim): @@ -309,7 +309,7 @@ def _sliced_attention(self, query, key, value, sequence_length, dim): hidden_states[start_idx:end_idx] = attn_slice # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + # hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states From 75fa0297595efffdc021855c40a04ff81e521dea Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Wed, 21 Sep 2022 16:31:14 +0000 Subject: [PATCH 25/53] Revert "removing all reshapes to test perf" This reverts commit 006ccb8a8c6bc7eb7e512392e692a29d9b1553cd. --- src/diffusers/models/attention.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index f3623c6e7ed5..02454d603632 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -91,7 +91,7 @@ def forward(self, hidden_states): # compute next hidden_states hidden_states = self.proj_attn(hidden_states) - hidden_states = hidden_states.reshape(batch, channel, height, width) + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) # res connect and rescale hidden_states = (hidden_states + residual) / self.rescale_output_factor @@ -150,10 +150,10 @@ def forward(self, hidden_states, context=None): residual = hidden_states hidden_states = self.norm(hidden_states) hidden_states = self.proj_in(hidden_states) - hidden_states = hidden_states.reshape(batch, height * weight, channel) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel) for block in self.transformer_blocks: hidden_states = block(hidden_states, context=context) - hidden_states = hidden_states.reshape(batch, channel, height, weight) + hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2) hidden_states = self.proj_out(hidden_states) return hidden_states + residual @@ -262,9 +262,9 @@ def forward(self, hidden_states, context=None, mask=None): key = self.to_k(context) value = self.to_v(context) - # query = self.reshape_heads_to_batch_dim(query) - # key = self.reshape_heads_to_batch_dim(key) - # value = self.reshape_heads_to_batch_dim(value) + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) # TODO(PVP) - mask is currently never used. Remember to re-implement when used @@ -290,7 +290,7 @@ def _attention(self, query, key, value): # compute attention output hidden_states = torch.matmul(attention_probs, value) # reshape hidden_states - # hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states def _sliced_attention(self, query, key, value, sequence_length, dim): @@ -309,7 +309,7 @@ def _sliced_attention(self, query, key, value, sequence_length, dim): hidden_states[start_idx:end_idx] = attn_slice # reshape hidden_states - # hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states From 31c58eadb8892f95478cdf05229adf678678c5f4 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Wed, 21 Sep 2022 16:47:42 +0000 Subject: [PATCH 26/53] add shapes comments --- src/diffusers/models/attention.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 02454d603632..22faaed35fa6 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -148,11 +148,11 @@ def forward(self, hidden_states, context=None): # note: if no context is given, cross-attention defaults to self-attention batch, channel, height, weight = hidden_states.shape residual = hidden_states - hidden_states = self.norm(hidden_states) - hidden_states = self.proj_in(hidden_states) + hidden_states = self.norm(hidden_states) # 2, 320, 64, 64 + hidden_states = self.proj_in(hidden_states) # 2, 320, 64, 64 hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel) for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=context) + hidden_states = block(hidden_states, context=context) # 2, 4096, 320 hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2) hidden_states = self.proj_out(hidden_states) return hidden_states + residual @@ -241,10 +241,10 @@ def __init__( self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape + batch_size, seq_len, dim = tensor.shape # 2, 4096, 320 head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) # 2, 4096, 8, 40 + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) # 16, 4096, 40 return tensor def reshape_batch_dim_to_heads(self, tensor): @@ -271,7 +271,7 @@ def forward(self, hidden_states, context=None, mask=None): # attention, what we cannot get enough of if self._slice_size is None or query.shape[0] // self._slice_size == 1: - hidden_states = self._attention(query, key, value) + hidden_states = self._attention(query, key, value) # 2, 4096, 320 else: hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) @@ -286,11 +286,11 @@ def _attention(self, query, key, value): beta=0, alpha=self.scale, ) - attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_scores.softmax(dim=-1) # 16, 4096, 77 # compute attention output - hidden_states = torch.matmul(attention_probs, value) + hidden_states = torch.matmul(attention_probs, value) # 16, 4096, 40 # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) # 2, 4096, 320 return hidden_states def _sliced_attention(self, query, key, value, sequence_length, dim): From 2fa9c698eae2890ac5f8e367ca80532ecf94df9a Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 22 Sep 2022 14:44:11 +0000 Subject: [PATCH 27/53] hardcore whats needed for jitting --- .../stable_diffusion/pipeline_stable_diffusion.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 1272fe64e74a..b6034bed8ac0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -195,7 +195,7 @@ def __call__( truncation=True, return_tensors="pt", ) - text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] + text_embeddings = self.text_encoder(text_input.input_ids.to("cuda"))[0] # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -207,7 +207,7 @@ def __call__( uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to("cuda"))[0] # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch @@ -219,8 +219,8 @@ def __call__( # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. - latents_device = "cpu" if self.device.type == "mps" else self.device - latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + latents_device = "cuda" + latents_shape = (batch_size, 4, height // 8, width // 8) if latents is None: latents = torch.randn( latents_shape, @@ -259,7 +259,7 @@ def __call__( latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + noise_pred = self.unet(latent_model_input, t, text_embeddings)[0] # TODO: fix for return_dict case # perform guidance if do_classifier_free_guidance: @@ -280,9 +280,9 @@ def __call__( image = image.cpu().permute(0, 2, 3, 1).numpy() # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to("cuda") image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values.to(text_embeddings.dtype)) - + if output_type == "pil": image = self.numpy_to_pil(image) From 47c668c569a32cd9c440c487418bc5eba9c8fa0e Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 22 Sep 2022 14:44:27 +0000 Subject: [PATCH 28/53] Revert "hardcore whats needed for jitting" This reverts commit 2fa9c698eae2890ac5f8e367ca80532ecf94df9a. --- .../stable_diffusion/pipeline_stable_diffusion.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index b6034bed8ac0..1272fe64e74a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -195,7 +195,7 @@ def __call__( truncation=True, return_tensors="pt", ) - text_embeddings = self.text_encoder(text_input.input_ids.to("cuda"))[0] + text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0] # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -207,7 +207,7 @@ def __call__( uncond_input = self.tokenizer( [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to("cuda"))[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch @@ -219,8 +219,8 @@ def __call__( # Unlike in other pipelines, latents need to be generated in the target device # for 1-to-1 results reproducibility with the CompVis implementation. # However this currently doesn't work in `mps`. - latents_device = "cuda" - latents_shape = (batch_size, 4, height // 8, width // 8) + latents_device = "cpu" if self.device.type == "mps" else self.device + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) if latents is None: latents = torch.randn( latents_shape, @@ -259,7 +259,7 @@ def __call__( latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, text_embeddings)[0] # TODO: fix for return_dict case + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: @@ -280,9 +280,9 @@ def __call__( image = image.cpu().permute(0, 2, 3, 1).numpy() # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to("cuda") + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values.to(text_embeddings.dtype)) - + if output_type == "pil": image = self.numpy_to_pil(image) From cc9bc1339c998ebe9e7d733f910c6d72d9792213 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 22 Sep 2022 15:28:28 +0000 Subject: [PATCH 29/53] Revert "remove restriction to run conv_norm in fp32" This reverts commit cec592890c32da3d1b78d38b49e4307aedf459b9. --- src/diffusers/models/resnet.py | 6 ++---- src/diffusers/models/unet_2d_condition.py | 3 +-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 36e9dd611e63..785a4b91353e 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -332,8 +332,7 @@ def forward(self, x, temb): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm1(hidden_states).type(hidden_states.dtype) - # hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: @@ -351,8 +350,7 @@ def forward(self, x, temb): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm2(hidden_states).type(hidden_states.dtype) - # hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index e584520e94a8..42b54657d290 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -261,8 +261,7 @@ def forward( # 6. post-process # make sure hidden states is in float32 # when running in half-precision - # sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype) - sample = self.conv_norm_out(sample).type(sample.dtype) + sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype) sample = self.conv_act(sample) sample = self.conv_out(sample) From 419fde37d105aa65ed11f1b323c89bc54c6b2101 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 22 Sep 2022 15:30:27 +0000 Subject: [PATCH 30/53] revert using baddmm in attention's forward --- src/diffusers/models/attention.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 22faaed35fa6..5143eb7f5f40 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -72,14 +72,7 @@ def forward(self, hidden_states): # get scores scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) - - attention_scores = torch.baddbmm( - torch.empty(query_states.shape[0], query_states.shape[1], key_states.shape[1], dtype=query_states.dtype, device=query_states.device), - query_states, - key_states.transpose(-1, -2), - beta=0, - alpha=scale, - ) + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) #TODO: use baddmm attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) # compute attention output From 9312809630026e854a5d277db90a3b9cab49997c Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 22 Sep 2022 16:28:50 +0000 Subject: [PATCH 31/53] cleanup comment --- src/diffusers/models/attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 5143eb7f5f40..f147534f7553 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -271,7 +271,6 @@ def forward(self, hidden_states, context=None, mask=None): return self.to_out(hidden_states) def _attention(self, query, key, value): - # attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale attention_scores = torch.baddbmm( torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), query, From 03a2ee78dee01311195ea9c075c039d502563594 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 22 Sep 2022 16:35:31 +0000 Subject: [PATCH 32/53] remove restriction to run conv_norm in fp32. no quality loss was noticed This reverts commit cc9bc1339c998ebe9e7d733f910c6d72d9792213. --- src/diffusers/models/resnet.py | 6 ++++-- src/diffusers/models/unet_2d_condition.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 785a4b91353e..36e9dd611e63 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -332,7 +332,8 @@ def forward(self, x, temb): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm1(hidden_states).type(hidden_states.dtype) + # hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: @@ -350,7 +351,8 @@ def forward(self, x, temb): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm2(hidden_states).type(hidden_states.dtype) + # hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 42b54657d290..e584520e94a8 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -261,7 +261,8 @@ def forward( # 6. post-process # make sure hidden states is in float32 # when running in half-precision - sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype) + # sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype) + sample = self.conv_norm_out(sample).type(sample.dtype) sample = self.conv_act(sample) sample = self.conv_out(sample) From d0b55790c0f0a240ce089f2cfae25d7fb6475f7c Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 22 Sep 2022 17:43:43 +0000 Subject: [PATCH 33/53] add more optimizations techniques to docs --- docs/source/optimization/fp16.mdx | 186 +++++++++++++++++++++++++++++- 1 file changed, 183 insertions(+), 3 deletions(-) diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx index 064bc58f8c2b..d2e65405f26d 100644 --- a/docs/source/optimization/fp16.mdx +++ b/docs/source/optimization/fp16.mdx @@ -14,7 +14,62 @@ specific language governing permissions and limitations under the License. We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for memory or speed. -## CUDA `autocast` + + + + + + + + + + + +
+ Latency + Speedup +
original + 9.50s + x1 +
cuDNN auto-tuner + 9.37s + x1.01 +
autocast (fp16) + 5.47s + x1.91 +
fp16 + 3.61s + x2.91 +
channels last + 3.30s + x2.87 +
traced UNet + 3.21s + x2.96 +
+obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from the prompt "a photo of an astronaut riding a horse on mars" with 50 DDIM steps. + +## Enable cuDNN auto-tuner + +[NVIDIA cuDNN](https://developer.nvidia.com/cudnn) supports many algorithms to compute a convolution. Autotuner runs a short benchmark and selects the kernel with the best performance on a given hardware for a given input size. + +Since we’re using **convolutional networks** (other types currently not supported), we can enable cuDNN autotuner before launching the inference by setting: + +```python +import torch +torch.backends.cudnn.benchmark = True +``` + +### Use tf32 instead of fp32 (on Ampere and later CUDA devices) + +On Ampere and later CUDA devices matrix multiplications and convolutions can use the TensorFloat32 (TF32) mode for faster but slightly less accurate computations. By default PyTorch enables TF32 mode for convolutions but not matrix multiplications, and unless a network requires full float32 precision we recommend enabling this setting for matrix multiplications, too. It can significantly speed up computations with typically negligible loss of numerical accuracy. You can read more about it [here](https://huggingface.co/docs/transformers/v4.18.0/en/performance#tf32). All you need to do is to add this before your inference: + +```python +import torch +torch.backends.cuda.matmul.allow_tf32 = True +``` + +## Automatic mixed precision (AMP) If you use a CUDA GPU, you can take advantage of `torch.autocast` to perform inference roughly twice as fast at the cost of slightly lower precision. All you need to do is put your inference call inside an `autocast` context manager. The following example shows how to do it using Stable Diffusion text-to-image generation as an example: @@ -47,7 +102,7 @@ pipe = StableDiffusionPipeline.from_pretrained( ## Sliced attention for additional memory savings -For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once. +For even additional memory savings, you can use a sliced version of attention that performs the computation in steps instead of all at once. Attention slicing is useful even if a batch size of just 1 is used - as long as the model uses more than one attention head. If there is more than one attention head the *QK^T* attention matrix can be computed sequentially for each head which can save a significant amount of memory. @@ -73,4 +128,129 @@ with torch.autocast("cuda"): image = pipe(prompt).images[0] ``` -There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM! +There's a small performance penalty of about 10% slower inference times, but this method allows you to use Stable Diffusion in as little as 3.2 GB of VRAM! + +## Using Channels Last memory format + +Channels last memory format is an alternative way of ordering NCHW tensors in memory preserving dimensions ordering. Channels last tensors ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel). Since not all operators currently support channels last format it may result in a worst performance, so it's better to try it and see if it works for your model. + +For example, in order to set the UNet model in our pipeline to use channels last format, we can use the following: + +```python +print(pipe.unet.conv_out.state_dict()['weight'].stride()) # (2880, 9, 3, 1) +pipe.unet.to(memory_format=torch.channels_last) # in-place operation +print(pipe.unet.conv_out.state_dict()['weight'].stride()) # (2880, 1, 960, 320) haveing a stride of 1 for the 2nd dimension proves that it works +``` + +## Tracing + +Tracing runs an example input tensor through your model, and captures the operations that are invoked as that input makes its way through the model's layers so that an executable or `ScriptFunction` is returned that will be optimized using just-in-time compilation. + +To trace our UNet model, we can use the following: + +```python +import time +import torch +from diffusers import StableDiffusionPipeline +import functools + +# torch disable grad +torch.set_grad_enabled(False) + +# load inputs +def generate_inputs(): + sample = torch.randn(2, 4, 64, 64).half().cuda() + timestep = torch.rand(1).half().cuda() * 999 + encoder_hidden_states = torch.randn(2, 77, 768).half().cuda() + return sample, timestep, encoder_hidden_states + + +pipe = StableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + # scheduler=scheduler, + use_auth_token=True, + revision="fp16", + torch_dtype=torch.float16 +).to("cuda") +unet = pipe.unet +unet.eval() +unet.to(memory_format=torch.channels_last) # use channels_last memory format +unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default + +# warmup +for _ in range(3): + with torch.inference_mode(): + inputs = generate_inputs() + orig_output = unet(*inputs) + +# trace +print("tracing..") +unet_traced = torch.jit.trace(unet, inputs) +unet_traced.eval() +print("done tracing") + + +# warmup and optimize graph +for _ in range(5): + with torch.inference_mode(): + inputs = generate_inputs() + orig_output = unet_traced(*inputs) + + +# benchmarking +with torch.inference_mode(): + for _ in range(2): + torch.cuda.synchronize() + start_time = time.time() + for _ in range(50): + orig_output = unet_traced(*inputs) + torch.cuda.synchronize() + print(f"unet traced inference took {time.time() - start_time:.2f} seconds") + for _ in range(2): + torch.cuda.synchronize() + start_time = time.time() + for _ in range(50): + orig_output = unet(*inputs) + torch.cuda.synchronize() + print(f"unet inference took {time.time() - start_time:.2f} seconds") + +# save the model +unet_traced.save("unet_traced.pt") +``` + +Then we can replace the `unet` attribute of the pipeline with the traced model like the following + +```python +from diffusers import StableDiffusionPipeline +import torch +from dataclasses import dataclass + +@dataclass +class UNet2DConditionOutput(): + sample: torch.FloatTensor + +pipe = StableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + # scheduler=scheduler, + use_auth_token=True, + revision="fp16", + torch_dtype=torch.float16 +).to("cuda") + +# use jitted unet +unet_traced = torch.jit.load("unet_traced.pt") +# del pipe.unet +class TracedUNet(torch.nn.Module): + def __init__(self): + super().__init__() + self.in_channels = pipe.unet.in_channels + self.device = pipe.unet.device + def forward(self, latent_model_input, t, encoder_hidden_states): + sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0] + return UNet2DConditionOutput(sample=sample) + +pipe.unet = TracedUNet() + +with torch.inference_mode(): + image = pipe([prompt]*1, num_inference_steps=5).images[0] +``` From 3bdf1ed8c2bfd488de8afbd5dceb7642a4baaf41 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 22 Sep 2022 17:46:59 +0000 Subject: [PATCH 34/53] Revert "add shapes comments" This reverts commit 31c58eadb8892f95478cdf05229adf678678c5f4. --- src/diffusers/models/attention.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index f147534f7553..20624003afe1 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -141,11 +141,11 @@ def forward(self, hidden_states, context=None): # note: if no context is given, cross-attention defaults to self-attention batch, channel, height, weight = hidden_states.shape residual = hidden_states - hidden_states = self.norm(hidden_states) # 2, 320, 64, 64 - hidden_states = self.proj_in(hidden_states) # 2, 320, 64, 64 + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_in(hidden_states) hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel) for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=context) # 2, 4096, 320 + hidden_states = block(hidden_states, context=context) hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2) hidden_states = self.proj_out(hidden_states) return hidden_states + residual @@ -234,10 +234,10 @@ def __init__( self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape # 2, 4096, 320 + batch_size, seq_len, dim = tensor.shape head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) # 2, 4096, 8, 40 - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) # 16, 4096, 40 + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) return tensor def reshape_batch_dim_to_heads(self, tensor): @@ -264,7 +264,7 @@ def forward(self, hidden_states, context=None, mask=None): # attention, what we cannot get enough of if self._slice_size is None or query.shape[0] // self._slice_size == 1: - hidden_states = self._attention(query, key, value) # 2, 4096, 320 + hidden_states = self._attention(query, key, value) else: hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) @@ -278,11 +278,11 @@ def _attention(self, query, key, value): beta=0, alpha=self.scale, ) - attention_probs = attention_scores.softmax(dim=-1) # 16, 4096, 77 + attention_probs = attention_scores.softmax(dim=-1) # compute attention output - hidden_states = torch.matmul(attention_probs, value) # 16, 4096, 40 + hidden_states = torch.matmul(attention_probs, value) # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) # 2, 4096, 320 + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states def _sliced_attention(self, query, key, value, sequence_length, dim): From aeddb45475c53eeb869670a9c52aaaa34151f00e Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Sat, 24 Sep 2022 17:44:00 +0000 Subject: [PATCH 35/53] apply suggestions --- docs/source/optimization/fp16.mdx | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx index d2e65405f26d..c847eeef0557 100644 --- a/docs/source/optimization/fp16.mdx +++ b/docs/source/optimization/fp16.mdx @@ -157,6 +157,10 @@ import functools # torch disable grad torch.set_grad_enabled(False) +# set variables +n_experiments = 2 +unet_runs_per_experiment = 50 + # load inputs def generate_inputs(): sample = torch.randn(2, 4, 64, 64).half().cuda() @@ -199,17 +203,17 @@ for _ in range(5): # benchmarking with torch.inference_mode(): - for _ in range(2): + for _ in range(n_experiments): torch.cuda.synchronize() start_time = time.time() - for _ in range(50): + for _ in range(unet_runs_per_experiment): orig_output = unet_traced(*inputs) torch.cuda.synchronize() print(f"unet traced inference took {time.time() - start_time:.2f} seconds") - for _ in range(2): + for _ in range(n_experiments): torch.cuda.synchronize() start_time = time.time() - for _ in range(50): + for _ in range(unet_runs_per_experiment): orig_output = unet(*inputs) torch.cuda.synchronize() print(f"unet inference took {time.time() - start_time:.2f} seconds") @@ -252,5 +256,5 @@ class TracedUNet(torch.nn.Module): pipe.unet = TracedUNet() with torch.inference_mode(): - image = pipe([prompt]*1, num_inference_steps=5).images[0] + image = pipe([prompt]*1, num_inference_steps=50).images[0] ``` From f40917287b161236199de9f6e5db0f8eac1ac082 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Sat, 24 Sep 2022 17:55:40 +0000 Subject: [PATCH 36/53] make quality --- docs/source/optimization/fp16.mdx | 34 ++++++++++++------- src/diffusers/dependency_versions_table.py | 1 - src/diffusers/models/attention.py | 6 ++-- src/diffusers/models/embeddings.py | 4 ++- .../pipeline_stable_diffusion.py | 4 ++- 5 files changed, 31 insertions(+), 18 deletions(-) diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx index c847eeef0557..bb58c19b49a4 100644 --- a/docs/source/optimization/fp16.mdx +++ b/docs/source/optimization/fp16.mdx @@ -57,6 +57,7 @@ Since we’re using **convolutional networks** (other types currently not suppor ```python import torch + torch.backends.cudnn.benchmark = True ``` @@ -66,6 +67,7 @@ On Ampere and later CUDA devices matrix multiplications and convolutions can use ```python import torch + torch.backends.cuda.matmul.allow_tf32 = True ``` @@ -137,9 +139,11 @@ Channels last memory format is an alternative way of ordering NCHW tensors in me For example, in order to set the UNet model in our pipeline to use channels last format, we can use the following: ```python -print(pipe.unet.conv_out.state_dict()['weight'].stride()) # (2880, 9, 3, 1) -pipe.unet.to(memory_format=torch.channels_last) # in-place operation -print(pipe.unet.conv_out.state_dict()['weight'].stride()) # (2880, 1, 960, 320) haveing a stride of 1 for the 2nd dimension proves that it works +print(pipe.unet.conv_out.state_dict()["weight"].stride()) # (2880, 9, 3, 1) +pipe.unet.to(memory_format=torch.channels_last) # in-place operation +print( + pipe.unet.conv_out.state_dict()["weight"].stride() +) # (2880, 1, 960, 320) haveing a stride of 1 for the 2nd dimension proves that it works ``` ## Tracing @@ -170,16 +174,16 @@ def generate_inputs(): pipe = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "CompVis/stable-diffusion-v1-4", # scheduler=scheduler, use_auth_token=True, revision="fp16", - torch_dtype=torch.float16 + torch_dtype=torch.float16, ).to("cuda") unet = pipe.unet unet.eval() -unet.to(memory_format=torch.channels_last) # use channels_last memory format -unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default +unet.to(memory_format=torch.channels_last) # use channels_last memory format +unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default # warmup for _ in range(3): @@ -187,7 +191,7 @@ for _ in range(3): inputs = generate_inputs() orig_output = unet(*inputs) -# trace +# trace print("tracing..") unet_traced = torch.jit.trace(unet, inputs) unet_traced.eval() @@ -198,7 +202,7 @@ print("done tracing") for _ in range(5): with torch.inference_mode(): inputs = generate_inputs() - orig_output = unet_traced(*inputs) + orig_output = unet_traced(*inputs) # benchmarking @@ -229,16 +233,18 @@ from diffusers import StableDiffusionPipeline import torch from dataclasses import dataclass + @dataclass -class UNet2DConditionOutput(): +class UNet2DConditionOutput: sample: torch.FloatTensor + pipe = StableDiffusionPipeline.from_pretrained( - "CompVis/stable-diffusion-v1-4", + "CompVis/stable-diffusion-v1-4", # scheduler=scheduler, use_auth_token=True, revision="fp16", - torch_dtype=torch.float16 + torch_dtype=torch.float16, ).to("cuda") # use jitted unet @@ -249,12 +255,14 @@ class TracedUNet(torch.nn.Module): super().__init__() self.in_channels = pipe.unet.in_channels self.device = pipe.unet.device + def forward(self, latent_model_input, t, encoder_hidden_states): sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0] return UNet2DConditionOutput(sample=sample) + pipe.unet = TracedUNet() with torch.inference_mode(): - image = pipe([prompt]*1, num_inference_steps=50).images[0] + image = pipe([prompt] * 1, num_inference_steps=50).images[0] ``` diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 09a7baad560d..82ca5dbb6f56 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -17,7 +17,6 @@ "jaxlib": "jaxlib>=0.1.65,<=0.3.6", "modelcards": "modelcards>=0.1.4", "numpy": "numpy", - "onnxruntime": "onnxruntime", "onnxruntime-gpu": "onnxruntime-gpu", "pytest": "pytest", "pytest-timeout": "pytest-timeout", diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index e3797134de54..04ea339f1c40 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -72,7 +72,7 @@ def forward(self, hidden_states): # get scores scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) - attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) #TODO: use baddmm + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) # compute attention output @@ -296,7 +296,9 @@ def _sliced_attention(self, query, key, value, sequence_length, dim): for i in range(hidden_states.shape[0] // slice_size): start_idx = i * slice_size end_idx = (i + 1) * slice_size - attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale # TODO: use baddbmm for better performance + attn_slice = ( + torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale + ) # TODO: use baddbmm for better performance attn_slice = attn_slice.softmax(dim=-1) attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index d8a6cf105a59..06b814e2bbcd 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -37,7 +37,9 @@ def get_timestep_embedding( assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 6ce4978f9eb0..a6aae982ea15 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -281,7 +281,9 @@ def __call__( # run safety checker safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values.to(text_embeddings.dtype)) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_cheker_input.pixel_values.to(text_embeddings.dtype) + ) if output_type == "pil": image = self.numpy_to_pil(image) From 76dda3edc25b1760088f5065a0f823d4ca41cf3b Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 27 Sep 2022 12:31:36 +0000 Subject: [PATCH 37/53] apply suggestions --- src/diffusers/dependency_versions_table.py | 3 ++- src/diffusers/models/resnet.py | 6 ++---- src/diffusers/models/unet_2d_condition.py | 1 - 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 82ca5dbb6f56..5ea41c4aa246 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -16,7 +16,8 @@ "jax": "jax>=0.2.8,!=0.3.2,<=0.3.6", "jaxlib": "jaxlib>=0.1.65,<=0.3.6", "modelcards": "modelcards>=0.1.4", - "numpy": "numpy", + "numpy": "numpy", + "onnxruntime": "onnxruntime", "onnxruntime-gpu": "onnxruntime-gpu", "pytest": "pytest", "pytest-timeout": "pytest-timeout", diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index a893a99b3cef..fe11b0faeaec 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -331,8 +331,7 @@ def forward(self, x, temb): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm1(hidden_states).type(hidden_states.dtype) - # hidden_states = self.norm1.float()(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm1(hidden_states) hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: @@ -350,8 +349,7 @@ def forward(self, x, temb): # make sure hidden states is in float32 # when running in half-precision - hidden_states = self.norm2(hidden_states).type(hidden_states.dtype) - # hidden_states = self.norm2.float()(hidden_states.float()).type(hidden_states.dtype) + hidden_states = self.norm2(hidden_states) hidden_states = self.nonlinearity(hidden_states) hidden_states = self.dropout(hidden_states) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 111a441780f5..cfc3a8c79462 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -278,7 +278,6 @@ def forward( # 6. post-process # make sure hidden states is in float32 # when running in half-precision - # sample = self.conv_norm_out.float()(sample.float()).type(sample.dtype) sample = self.conv_norm_out(sample).type(sample.dtype) sample = self.conv_act(sample) sample = self.conv_out(sample) From 8929d76e14750df040af21d65bf024b680f3a966 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 27 Sep 2022 12:33:27 +0000 Subject: [PATCH 38/53] styling --- src/diffusers/dependency_versions_table.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 5ea41c4aa246..09a7baad560d 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -16,7 +16,7 @@ "jax": "jax>=0.2.8,!=0.3.2,<=0.3.6", "jaxlib": "jaxlib>=0.1.65,<=0.3.6", "modelcards": "modelcards>=0.1.4", - "numpy": "numpy", + "numpy": "numpy", "onnxruntime": "onnxruntime", "onnxruntime-gpu": "onnxruntime-gpu", "pytest": "pytest", From 98e80da169ecee208b93b91b203f8945cdda81da Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Tue, 27 Sep 2022 23:41:08 +0000 Subject: [PATCH 39/53] `scheduler.timesteps` are now arrays so we dont need .to() --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 7028120ffe59..393536550236 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -234,7 +234,6 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) - self.scheduler.timesteps = self.scheduler.timesteps.to(latents_device) # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): From 9b1ec08c65a8b35da957b37d91d1efd1e3ebb2f6 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Wed, 28 Sep 2022 00:21:13 +0000 Subject: [PATCH 40/53] remove useless .type() --- src/diffusers/models/unet_2d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index cfc3a8c79462..b859909bee69 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -278,7 +278,7 @@ def forward( # 6. post-process # make sure hidden states is in float32 # when running in half-precision - sample = self.conv_norm_out(sample).type(sample.dtype) + sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) From 61ed4c359077ea5e02ff556eb5f0786c7b2e90fa Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Wed, 28 Sep 2022 09:21:24 +0000 Subject: [PATCH 41/53] use mean instead of max in `test_stable_diffusion_inpaint_pipeline_k_lms` --- tests/test_pipelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 61d5ac3a4e28..671063b5a795 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -1368,7 +1368,7 @@ def test_stable_diffusion_inpaint_pipeline_k_lms(self): image = output.images[0] assert image.shape == (512, 512, 3) - assert np.abs(expected_image - image).max() < 1e-2 + assert np.abs(expected_image - image).mean() < 1e-2 @slow def test_stable_diffusion_onnx(self): From e3c38e8e8a87526ca802cc3fd6007741d602733c Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Wed, 28 Sep 2022 12:30:33 +0000 Subject: [PATCH 42/53] move scheduler timestamps to correct device if tensors --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 393536550236..d2f952eb2499 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -234,6 +234,8 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) + if isinstance(self.scheduler.timesteps, torch.Tensor): + self.scheduler.timesteps = self.scheduler.timesteps.to(self.device) # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): From f25f1c1f4f13f29529055e55a54e7c21fd5acd3f Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 29 Sep 2022 09:25:07 +0000 Subject: [PATCH 43/53] add device to `set_timesteps` in LMSD scheduler --- src/diffusers/schedulers/scheduling_lms_discrete.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 4595b2fe5aaf..43a8fcb7e4f2 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -131,13 +131,15 @@ def lms_derivative(tau): return integrated_coeff - def set_timesteps(self, num_inference_steps: int): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. Args: num_inference_steps (`int`): the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ self.num_inference_steps = num_inference_steps @@ -145,8 +147,8 @@ def set_timesteps(self, num_inference_steps: int): sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas) - self.timesteps = torch.from_numpy(timesteps) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + self.timesteps = torch.from_numpy(timesteps).to(device=device) self.derivatives = [] From 00d5a51e5c20d8d445c8664407ef29608106d899 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 29 Sep 2022 09:27:27 +0000 Subject: [PATCH 44/53] `self.scheduler.set_timesteps` now uses device arg for schedulers that accept it --- examples/community/clip_guided_stable_diffusion.py | 2 +- src/diffusers/pipelines/ddim/pipeline_ddim.py | 2 +- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 2 +- .../pipelines/latent_diffusion/pipeline_latent_diffusion.py | 2 +- .../pipeline_latent_diffusion_uncond.py | 2 +- src/diffusers/pipelines/pndm/pipeline_pndm.py | 3 ++- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 4 +--- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_onnx.py | 2 +- src/diffusers/schedulers/scheduling_ddpm.py | 2 +- 11 files changed, 12 insertions(+), 13 deletions(-) diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py index a34e8ab7edfc..b26f6b13a2ac 100644 --- a/examples/community/clip_guided_stable_diffusion.py +++ b/examples/community/clip_guided_stable_diffusion.py @@ -259,7 +259,7 @@ def __call__( if accepts_offset: extra_set_kwargs["offset"] = 1 - self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + self.scheduler.set_timesteps(num_inference_steps, device=self.device, **extra_set_kwargs) # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 74607fe87a3d..98a14509e157 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -80,7 +80,7 @@ def __call__( image = image.to(self.device) # set step values - self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_timesteps(num_inference_steps, device=self.device) for t in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index aae29737aae3..cfd07aead99c 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -73,7 +73,7 @@ def __call__( image = image.to(self.device) # set step values - self.scheduler.set_timesteps(1000) + self.scheduler.set_timesteps(1000, device=self.device) for t in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 556e4211892b..ef1b93229943 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -118,7 +118,7 @@ def __call__( ) latents = latents.to(self.device) - self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_timesteps(num_inference_steps, device=self.device) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index ef82abb7e6cb..db0b846b4e77 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -64,7 +64,7 @@ def __call__( ) latents = latents.to(self.device) - self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_timesteps(num_inference_steps, device=self.device) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index f360da09ac94..d0d249993b72 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -14,6 +14,7 @@ # limitations under the License. +from os import device_encoding from typing import Optional, Tuple, Union import torch @@ -80,7 +81,7 @@ def __call__( ) image = image.to(self.device) - self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_timesteps(num_inference_steps, device=self.device) for t in self.progress_bar(self.scheduler.timesteps): model_output = self.unet(image, t).sample diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index d2f952eb2499..f43d2814b4f4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -233,9 +233,7 @@ def __call__( latents = latents.to(latents_device) # set timesteps - self.scheduler.set_timesteps(num_inference_steps) - if isinstance(self.scheduler.timesteps, torch.Tensor): - self.scheduler.timesteps = self.scheduler.timesteps.to(self.device) + self.scheduler.set_timesteps(num_inference_steps, device=self.device) # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index f2ccee71c024..43c0994e6b40 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -189,7 +189,7 @@ def __call__( raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") # set timesteps - self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_timesteps(num_inference_steps, device=self.device) if isinstance(init_image, PIL.Image.Image): init_image = preprocess(init_image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index a95f9152279a..f433c176405a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -209,7 +209,7 @@ def __call__( raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") # set timesteps - self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_timesteps(num_inference_steps, device=self.device) # preprocess image if not isinstance(init_image, torch.FloatTensor): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 07e9c1d9acd6..5a868c91c771 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -111,7 +111,7 @@ def __call__( raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") # set timesteps - self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.set_timesteps(num_inference_steps, device=self.device) # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index cc17cee4c810..ec63df4cf506 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -147,7 +147,7 @@ def __init__( self.variance_type = variance_type - def set_timesteps(self, num_inference_steps: int): + def set_timesteps(self, num_inference_steps: int, **kwargs): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. From 0cd46139991abc9640f3e27e24f2cfa16f3595f1 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 29 Sep 2022 09:35:17 +0000 Subject: [PATCH 45/53] quick fix --- src/diffusers/pipelines/pndm/pipeline_pndm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index d0d249993b72..d1ab35401153 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -14,7 +14,6 @@ # limitations under the License. -from os import device_encoding from typing import Optional, Tuple, Union import torch From 0fb42d4c6515fd4080503e4a48c9035f66061f4f Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 29 Sep 2022 09:38:10 +0000 Subject: [PATCH 46/53] styling --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index f43d2814b4f4..44b663fef64d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -278,9 +278,9 @@ def __call__( image = image.cpu().permute(0, 2, 3, 1).numpy() # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_cheker_input.pixel_values.to(text_embeddings.dtype) + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) ) if output_type == "pil": From e6969edb594568b22751052a9d5e70bd2de1f961 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 29 Sep 2022 19:20:17 +0000 Subject: [PATCH 47/53] remove kwargs from schedulers `set_timesteps` --- src/diffusers/schedulers/scheduling_ddpm.py | 2 +- src/diffusers/schedulers/scheduling_lms_discrete.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index ec63df4cf506..cc17cee4c810 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -147,7 +147,7 @@ def __init__( self.variance_type = variance_type - def set_timesteps(self, num_inference_steps: int, **kwargs): + def set_timesteps(self, num_inference_steps: int): """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 43a8fcb7e4f2..30bf8a574638 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -131,7 +131,7 @@ def lms_derivative(tau): return integrated_coeff - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs): + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. From 2ad335384ce5a999afa05500192dd5913d3fe657 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 29 Sep 2022 19:24:52 +0000 Subject: [PATCH 48/53] revert to using max in K-LMS inpaint pipeline test --- tests/test_pipelines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 671063b5a795..61d5ac3a4e28 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -1368,7 +1368,7 @@ def test_stable_diffusion_inpaint_pipeline_k_lms(self): image = output.images[0] assert image.shape == (512, 512, 3) - assert np.abs(expected_image - image).mean() < 1e-2 + assert np.abs(expected_image - image).max() < 1e-2 @slow def test_stable_diffusion_onnx(self): From 7183202d81838fa92233f7e79e43dabe239b951a Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 29 Sep 2022 19:34:07 +0000 Subject: [PATCH 49/53] Revert "`self.scheduler.set_timesteps` now uses device arg for schedulers that accept it" This reverts commit 00d5a51e5c20d8d445c8664407ef29608106d899. --- examples/community/clip_guided_stable_diffusion.py | 2 +- src/diffusers/pipelines/ddim/pipeline_ddim.py | 2 +- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 2 +- .../pipelines/latent_diffusion/pipeline_latent_diffusion.py | 2 +- .../pipeline_latent_diffusion_uncond.py | 2 +- src/diffusers/pipelines/pndm/pipeline_pndm.py | 2 +- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 4 +++- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 2 +- .../stable_diffusion/pipeline_stable_diffusion_onnx.py | 2 +- 10 files changed, 12 insertions(+), 10 deletions(-) diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py index 3cc3c8b03855..1129e4b3bd94 100644 --- a/examples/community/clip_guided_stable_diffusion.py +++ b/examples/community/clip_guided_stable_diffusion.py @@ -259,7 +259,7 @@ def __call__( if accepts_offset: extra_set_kwargs["offset"] = 1 - self.scheduler.set_timesteps(num_inference_steps, device=self.device, **extra_set_kwargs) + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 98a14509e157..74607fe87a3d 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -80,7 +80,7 @@ def __call__( image = image.to(self.device) # set step values - self.scheduler.set_timesteps(num_inference_steps, device=self.device) + self.scheduler.set_timesteps(num_inference_steps) for t in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index cfd07aead99c..aae29737aae3 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -73,7 +73,7 @@ def __call__( image = image.to(self.device) # set step values - self.scheduler.set_timesteps(1000, device=self.device) + self.scheduler.set_timesteps(1000) for t in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index ef1b93229943..556e4211892b 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -118,7 +118,7 @@ def __call__( ) latents = latents.to(self.device) - self.scheduler.set_timesteps(num_inference_steps, device=self.device) + self.scheduler.set_timesteps(num_inference_steps) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index db0b846b4e77..ef82abb7e6cb 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -64,7 +64,7 @@ def __call__( ) latents = latents.to(self.device) - self.scheduler.set_timesteps(num_inference_steps, device=self.device) + self.scheduler.set_timesteps(num_inference_steps) # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index d1ab35401153..f360da09ac94 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -80,7 +80,7 @@ def __call__( ) image = image.to(self.device) - self.scheduler.set_timesteps(num_inference_steps, device=self.device) + self.scheduler.set_timesteps(num_inference_steps) for t in self.progress_bar(self.scheduler.timesteps): model_output = self.unet(image, t).sample diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 44b663fef64d..4a58d8d30936 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -233,7 +233,9 @@ def __call__( latents = latents.to(latents_device) # set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=self.device) + self.scheduler.set_timesteps(num_inference_steps) + if isinstance(self.scheduler.timesteps, torch.Tensor): + self.scheduler.timesteps = self.scheduler.timesteps.to(self.device) # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 43c0994e6b40..f2ccee71c024 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -189,7 +189,7 @@ def __call__( raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") # set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=self.device) + self.scheduler.set_timesteps(num_inference_steps) if isinstance(init_image, PIL.Image.Image): init_image = preprocess(init_image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index f433c176405a..a95f9152279a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -209,7 +209,7 @@ def __call__( raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") # set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=self.device) + self.scheduler.set_timesteps(num_inference_steps) # preprocess image if not isinstance(init_image, torch.FloatTensor): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py index 5a868c91c771..07e9c1d9acd6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py @@ -111,7 +111,7 @@ def __call__( raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") # set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=self.device) + self.scheduler.set_timesteps(num_inference_steps) # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): From da67fe600d4b73314a055219111d75787a21a733 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 29 Sep 2022 20:24:44 +0000 Subject: [PATCH 50/53] move timesteps to correct device before loop in SD pipeline --- src/diffusers/models/unet_2d_condition.py | 1 + .../stable_diffusion/pipeline_stable_diffusion.py | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index b859909bee69..5255f33ff97a 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -230,6 +230,7 @@ def forward( # 1. time timesteps = timestep if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None] diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 4a58d8d30936..6c338053fd95 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -234,8 +234,10 @@ def __call__( # set timesteps self.scheduler.set_timesteps(num_inference_steps) - if isinstance(self.scheduler.timesteps, torch.Tensor): - self.scheduler.timesteps = self.scheduler.timesteps.to(self.device) + + # Some schedulers like PNDM have timesteps as arrays + # It's more optimzed to move all timesteps to correct device beforehand + timesteps_tensor = torch.tensor(self.scheduler.timesteps, device=self.device) # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler): @@ -250,7 +252,7 @@ def __call__( if accepts_eta: extra_step_kwargs["eta"] = eta - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): + for i, t in enumerate(self.progress_bar(timesteps_tensor)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents if isinstance(self.scheduler, LMSDiscreteScheduler): From c8cc2bab67929f2553dfcaa99dd763b4649d140c Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 29 Sep 2022 20:31:12 +0000 Subject: [PATCH 51/53] apply previous fix to other SD pipelines --- .../stable_diffusion/pipeline_stable_diffusion_img2img.py | 6 +++++- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index f2ccee71c024..8f2bb67b8de1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -265,7 +265,11 @@ def __call__( latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) - for i, t in enumerate(self.progress_bar(self.scheduler.timesteps[t_start:])): + # Some schedulers like PNDM have timesteps as arrays + # It's more optimzed to move all timesteps to correct device beforehand + timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): t_index = t_start + i # expand the latents if we are doing classifier free guidance diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index a95f9152279a..2e792df1803e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -298,7 +298,11 @@ def __call__( latents = init_latents t_start = max(num_inference_steps - init_timestep + offset, 0) - for i, t in tqdm(enumerate(self.scheduler.timesteps[t_start:])): + # Some schedulers like PNDM have timesteps as arrays + # It's more optimzed to move all timesteps to correct device beforehand + timesteps_tensor = torch.tensor(self.scheduler.timesteps[t_start:], device=self.device) + + for i, t in tqdm(enumerate(timesteps_tensor)): t_index = t_start + i # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents From b6162dc923f095030d691a581ea722479a1df9cf Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 29 Sep 2022 20:36:42 +0000 Subject: [PATCH 52/53] UNet now accepts tensor timesteps even on wrong device, to avoid errors - it shouldnt affect performance if timesteps are alrdy on correct device - it does slow down performance if they're on the wrong device --- src/diffusers/models/unet_2d_condition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 5255f33ff97a..3ea8829b48e1 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -233,7 +233,7 @@ def forward( # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None] + timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) From 9a1fb0355e04b9ca606795908ab1eda096be90e6 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 29 Sep 2022 21:57:52 +0000 Subject: [PATCH 53/53] fix pipeline when timesteps are arrays with strides --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 6c338053fd95..5c6890db82fd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -237,7 +237,10 @@ def __call__( # Some schedulers like PNDM have timesteps as arrays # It's more optimzed to move all timesteps to correct device beforehand - timesteps_tensor = torch.tensor(self.scheduler.timesteps, device=self.device) + if torch.is_tensor(self.scheduler.timesteps): + timesteps_tensor = self.scheduler.timesteps.to(self.device) + else: + timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device) # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas if isinstance(self.scheduler, LMSDiscreteScheduler):