Hi @Kinyugo ,
Great repo, thank you for the work!
I just wanted to clarify something. From my understanding, what you have implemented here is Consistency Distillation (CD), right? I had thought consistency models were trained in isolation, independently of a teacher/student framework, which is what led to my confusion. Your TODO list mentions implementing CD as work still to be done, but it seems to me that it's already done. Am I mistaken?
Thank you for the clarification.
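For context, the usual way to tell the two recipes apart is by where the teacher's input comes from. In Consistency Distillation, the point at the adjacent (lower) noise level is produced by running one ODE-solver step of a *pretrained* diffusion/denoiser model along the probability-flow ODE; in Consistency Training, it is simply `data + sigma * noise` with the same noise sample reused at both levels, and the teacher is an EMA copy of the student. A minimal sketch of the CD-style solver step (the `denoiser` here is a hypothetical stand-in, not anything from this repo):

```python
import numpy as np

def euler_ode_step(denoiser, x_next, sigma_next, sigma_current):
    """One Euler step of the probability-flow ODE
    dx/dsigma = (x - D(x, sigma)) / sigma, moving from the higher noise
    level sigma_next down to sigma_current. In CD this step (with a
    pretrained denoiser D) produces the teacher's input; CT replaces it
    with data + sigma_current * noise using shared noise."""
    drift = (x_next - denoiser(x_next, sigma_next)) / sigma_next
    return x_next + (sigma_current - sigma_next) * drift

# Toy check with a hypothetical denoiser that always predicts zero:
# the Euler step then just rescales x by sigma_current / sigma_next.
zero_denoiser = lambda x, sigma: np.zeros_like(x)
x = np.array([2.0, -4.0])
out = euler_ode_step(zero_denoiser, x, sigma_next=2.0, sigma_current=1.0)
print(out)  # [ 1. -2.]
```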
```python
# Consistency Training loop in pseudocode: the teacher is an EMA copy of
# the student, and both trajectory points reuse the same noise sample.
for current_training_step in range(total_training_steps):
    data = data_distribution()

    # The discretization grows over training; sigmas follow the Karras schedule.
    num_timesteps = timestep_schedule(
        current_training_step, total_training_steps, initial_timesteps, final_timesteps
    )
    sigmas = karras_schedule(num_timesteps, sigma_min, sigma_max)

    # Sample a batch of adjacent noise-level pairs (sigma_t, sigma_{t+1}).
    timesteps = uniform_distribution(batch_size, start=0, end=num_timesteps - 1)
    noise = standard_gaussian_noise()
    current_sigmas = sigmas[timesteps]
    next_sigmas = sigmas[timesteps + 1]

    # The same noise realization is used at both noise levels.
    current_noisy_data = data + current_sigmas * noise
    next_noisy_data = data + next_sigmas * noise

    # Boundary-condition parameterization:
    # f(x, sigma) = c_skip(sigma) * x + c_out(sigma) * F(x, sigma)
    student_model_prediction = (
        skip_scaling(next_sigmas, sigma_data, sigma_min) * next_noisy_data
        + output_scaling(next_sigmas, sigma_data, sigma_min)
        * student_model(next_noisy_data, next_sigmas)
    )
    with no_grad():
        # Note: the teacher must be evaluated at current_sigmas (the original
        # snippet passed next_sigmas here, which mismatches its input).
        teacher_model_prediction = (
            skip_scaling(current_sigmas, sigma_data, sigma_min) * current_noisy_data
            + output_scaling(current_sigmas, sigma_data, sigma_min)
            * teacher_model(current_noisy_data, current_sigmas)
        )

    loss = distance_metric(student_model_prediction, teacher_model_prediction)
    loss.backward()
    optimizer_step()  # update the student parameters

    # The teacher is an EMA of the student, with a scheduled decay rate.
    with no_grad():
        current_ema_decay_rate = ema_decay_schedule(
            current_training_step, initial_ema_decay_rate, initial_timesteps
        )
        teacher_model_params = (
            current_ema_decay_rate * teacher_model_params
            + (1 - current_ema_decay_rate) * student_model_params
        )
```
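The helper functions in the loop above are left abstract. A self-contained sketch of plausible definitions follows, based on the formulas in Karras et al. (2022) for the noise schedule and Song et al. (2023, "Consistency Models") for the boundary scalings, progressive discretization N(k), and EMA decay mu(k); the exact signatures and defaults here are assumptions, not necessarily what this repo uses:

```python
import math
import numpy as np

def karras_schedule(num_timesteps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Karras et al. noise levels, in ascending order so that
    sigmas[t + 1] > sigmas[t], as the training loop assumes."""
    steps = np.arange(num_timesteps) / max(num_timesteps - 1, 1)
    inv_rho = 1.0 / rho
    return (sigma_min**inv_rho + steps * (sigma_max**inv_rho - sigma_min**inv_rho)) ** rho

def skip_scaling(sigma, sigma_data=0.5, sigma_min=0.002):
    """c_skip(sigma): equals 1 at sigma = sigma_min, so f(x, sigma_min) = x."""
    return sigma_data**2 / ((sigma - sigma_min) ** 2 + sigma_data**2)

def output_scaling(sigma, sigma_data=0.5, sigma_min=0.002):
    """c_out(sigma): equals 0 at sigma = sigma_min."""
    return sigma_data * (sigma - sigma_min) / (sigma_data**2 + sigma**2) ** 0.5

def timestep_schedule(step, total_steps, initial_timesteps=2, final_timesteps=150):
    """Progressive discretization N(k): grows from s0 toward s1 over training."""
    progress = (step / total_steps) * ((final_timesteps + 1) ** 2 - initial_timesteps**2)
    return math.ceil(math.sqrt(progress + initial_timesteps**2) - 1) + 1

def ema_decay_schedule(num_timesteps, initial_ema_decay_rate=0.95, initial_timesteps=2):
    """EMA decay mu(k) = exp(s0 * log(mu0) / N(k)): decay shrinks as N grows."""
    return math.exp(initial_timesteps * math.log(initial_ema_decay_rate) / num_timesteps)
```

Note that `ema_decay_schedule` is written here as a function of the current discretization `N(k)` rather than of the raw training step, which is one reasonable reading of the pseudocode's arguments.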