Skip to content

Commit 5321f3e

Browse files
authored
add add_noise method in LMSDiscreteScheduler, PNDMScheduler (#227)
add add_noise method in more schedulers
1 parent 3f1861e commit 5321f3e

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

src/diffusers/schedulers/scheduling_lms_discrete.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,5 +130,14 @@ def add_noise(self, original_samples, noise, timesteps):
130130
noisy_samples = (alpha_prod**0.5) * original_samples + ((1 - alpha_prod) ** 0.5) * noise
131131
return noisy_samples
132132

133+
def add_noise(self, original_samples, noise, timesteps):
134+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
135+
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
136+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
137+
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
138+
139+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
140+
return noisy_samples
141+
133142
def __len__(self):
134143
return self.config.num_train_timesteps

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,5 +250,14 @@ def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
250250

251251
return prev_sample
252252

253+
def add_noise(self, original_samples, noise, timesteps):
254+
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
255+
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
256+
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
257+
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
258+
259+
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
260+
return noisy_samples
261+
253262
def __len__(self):
254263
return self.config.num_train_timesteps

0 commit comments

Comments
 (0)