From c5d4e87948a65b93a84f02a2c0a4dffb1cbefd56 Mon Sep 17 00:00:00 2001 From: dnwalkup Date: Tue, 6 Dec 2022 12:56:48 -0800 Subject: [PATCH 1/4] Use autocast to correct precision error --- examples/dreambooth/train_dreambooth.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 0522b3fb8f8a..bae372e291d4 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -11,6 +11,8 @@ import torch.utils.checkpoint from torch.utils.data import Dataset +from torch.cuda.amp import autocast #Autocast for proper type casting half vs full + from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import set_seed @@ -629,7 +631,8 @@ def collate_fn(examples): encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + with autocast(): + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": From 03acabd9844a8047252b7c9914213cf9113356b6 Mon Sep 17 00:00:00 2001 From: dnwalkup Date: Tue, 6 Dec 2022 13:12:15 -0800 Subject: [PATCH 2/4] Fixed code quality check --- examples/dreambooth/train_dreambooth.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index bae372e291d4..c08d8f91e7e3 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -11,7 +11,8 @@ import torch.utils.checkpoint from torch.utils.data import Dataset -from torch.cuda.amp import autocast #Autocast for proper type casting half vs full +# Autocast for proper type casting half vs full +from torch.cuda.amp import autocast from accelerate import Accelerator from accelerate.logging import get_logger From 2a0e74b6442d528a399874b1a535c2100680d193 Mon Sep 17 00:00:00 2001 From: dnwalkup Date: Tue, 6 Dec 2022 14:04:15 -0800 Subject: [PATCH 3/4] Remove comments --- examples/dreambooth/train_dreambooth.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index c08d8f91e7e3..0ace419a136a 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -11,7 +11,6 @@ import torch.utils.checkpoint from torch.utils.data import Dataset -# Autocast for proper type casting half vs full from torch.cuda.amp import autocast from accelerate import Accelerator From 49d8a1b72c1e7c000062025ee72be7dcdf01a1ab Mon Sep 17 00:00:00 2001 From: dnwalkup Date: Tue, 6 Dec 2022 14:07:27 -0800 Subject: [PATCH 4/4] Remove code comment --- examples/dreambooth/train_dreambooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 0ace419a136a..9bec871c31ad 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -11,7 +11,7 @@ import torch.utils.checkpoint from torch.utils.data import Dataset -from torch.cuda.amp import autocast +from torch.cuda.amp import autocast from accelerate import Accelerator from accelerate.logging import get_logger