Skip to content

Commit 24c8d99

Browse files
committed
make cfg_min an option
1 parent 0deae16 commit 24c8d99

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

seva/sampling.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def __call__(self, scale: float | torch.Tensor) -> float | torch.Tensor:
158158

159159

160160
class MultiviewScaleRule(object):
161-
def __init__(self, min_scale: float = 1.2):
161+
def __init__(self, min_scale: float = 1.0):
162162
self.min_scale = min_scale
163163

164164
def __call__(
@@ -243,8 +243,9 @@ def prepare_inputs(
243243

244244

245245
class MultiviewCFG(VanillaCFG):
246-
def __init__(self):
247-
self.scale_rule = MultiviewScaleRule()
246+
def __init__(self, cfg_min: float = 1.0):
247+
self.scale_min = cfg_min
248+
self.scale_rule = MultiviewScaleRule(min_scale=cfg_min)
248249
self.scale_schedule = ConstantScaleSchedule()
249250
self.guidance = ConstantGuidance()
250251

@@ -265,8 +266,8 @@ def __call__( # type: ignore
265266

266267

267268
class MultiviewTemporalCFG(MultiviewCFG):
268-
def __init__(self, num_frames: int):
269-
super().__init__()
269+
def __init__(self, *args, num_frames: int, **kwargs):
270+
super().__init__(*args, **kwargs)
270271

271272
self.num_frames = num_frames
272273
distance_matrix = (
@@ -291,8 +292,7 @@ def __call__(
291292
+ (~input_frame_mask[:, None]) * self.num_frames
292293
).min(-1)[0]
293294
min_distance = min_distance / min_distance.max(-1, keepdim=True)[0].clamp(min=1)
294-
min_scale = self.scale_rule.min_scale
295-
scale = min_distance * (scale - min_scale) + min_scale
295+
scale = min_distance * (scale - self.scale_min) + self.scale_min
296296
scale = rearrange(scale, "b t ... -> (b t) ...")
297297
scale = append_dims(scale, x.ndim)
298298
return super().__call__(x, sigma, scale, c2w, K, input_frame_mask.flatten(0, 1))

0 commit comments

Comments
 (0)