@@ -120,8 +120,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
120120 the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
121121 sampling, and `solver_order=3` for unconditional sampling.
122122 prediction_type (`str`, default `epsilon`):
123- indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`.
124- `v-prediction` is not supported for this scheduler .
123+ indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`,
124+ or `v-prediction`.
125125 thresholding (`bool`, default `False`):
126126 whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
127127 For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
@@ -252,7 +252,7 @@ def convert_model_output(
252252 """
253253 Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs.
254254
255- DPM-Solver is designed to discretize an integral of the noise prediciton model, and DPM-Solver++ is designed to
255+ DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to
256256 discretize an integral of the data prediction model. So we need to first convert the model output to the
257257 corresponding type to match the algorithm.
258258
@@ -275,10 +275,13 @@ def convert_model_output(
275275 x0_pred = (sample - sigma_t * model_output ) / alpha_t
276276 elif self .config .prediction_type == "sample" :
277277 x0_pred = model_output
278+ elif self .config .prediction_type == "v_prediction" :
279+ alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
280+ x0_pred = alpha_t * sample - sigma_t * model_output
278281 else :
279282 raise ValueError (
280- f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample` "
281- " for the FlaxDPMSolverMultistepScheduler."
283+ f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample`, "
284+ " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
282285 )
283286
284287 if self .config .thresholding :
@@ -299,10 +302,14 @@ def convert_model_output(
299302 alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
300303 epsilon = (sample - alpha_t * model_output ) / sigma_t
301304 return epsilon
305+ elif self .config .prediction_type == "v_prediction" :
306+ alpha_t , sigma_t = self .alpha_t [timestep ], self .sigma_t [timestep ]
307+ epsilon = alpha_t * model_output + sigma_t * sample
308+ return epsilon
302309 else :
303310 raise ValueError (
304- f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample` "
305- " for the FlaxDPMSolverMultistepScheduler."
311+ f"prediction_type given as { self .config .prediction_type } must be one of `epsilon`, `sample`, "
312+ " or `v_prediction` for the FlaxDPMSolverMultistepScheduler."
306313 )
307314
308315 def dpm_solver_first_order_update (
0 commit comments