@@ -158,7 +158,7 @@ def __call__(self, scale: float | torch.Tensor) -> float | torch.Tensor:
158158
159159
160160class MultiviewScaleRule (object ):
161- def __init__ (self , min_scale : float = 1.2 ):
161+ def __init__ (self , min_scale : float = 1.0 ):
162162 self .min_scale = min_scale
163163
164164 def __call__ (
@@ -243,8 +243,9 @@ def prepare_inputs(
243243
244244
245245class MultiviewCFG (VanillaCFG ):
246- def __init__ (self ):
247- self .scale_rule = MultiviewScaleRule ()
246+ def __init__ (self , cfg_min : float = 1.0 ):
247+ self .scale_min = cfg_min
248+ self .scale_rule = MultiviewScaleRule (min_scale = cfg_min )
248249 self .scale_schedule = ConstantScaleSchedule ()
249250 self .guidance = ConstantGuidance ()
250251
@@ -265,8 +266,8 @@ def __call__( # type: ignore
265266
266267
267268class MultiviewTemporalCFG (MultiviewCFG ):
268- def __init__ (self , num_frames : int ):
269- super ().__init__ ()
269+ def __init__ (self , * args , num_frames : int , ** kwargs ):
270+ super ().__init__ (* args , ** kwargs )
270271
271272 self .num_frames = num_frames
272273 distance_matrix = (
@@ -291,8 +292,7 @@ def __call__(
291292 + (~ input_frame_mask [:, None ]) * self .num_frames
292293 ).min (- 1 )[0 ]
293294 min_distance = min_distance / min_distance .max (- 1 , keepdim = True )[0 ].clamp (min = 1 )
294- min_scale = self .scale_rule .min_scale
295- scale = min_distance * (scale - min_scale ) + min_scale
295+ scale = min_distance * (scale - self .scale_min ) + self .scale_min
296296 scale = rearrange (scale , "b t ... -> (b t) ..." )
297297 scale = append_dims (scale , x .ndim )
298298 return super ().__call__ (x , sigma , scale , c2w , K , input_frame_mask .flatten (0 , 1 ))
0 commit comments