@@ -1107,6 +1107,7 @@ def create_samplers(
11071107 discretization ,
11081108 num_frames : list [int ] | None ,
11091109 num_steps : int ,
1110+ cfg_min : float = 1.0 ,
11101111 device : str | torch .device = "cuda" ,
11111112 abort_event : threading .Event | None = None ,
11121113):
@@ -1124,11 +1125,14 @@ def create_samplers(
11241125 f"Invalid guider type { guider_type } . Must be one of { list (guider_mapping .keys ())} "
11251126 )
11261127 guider_cls = guider_mapping [guider_type ]
1127- if guider_type != 2 :
1128- guider = guider_cls ()
1129- else :
1130- assert num_frames is not None
1131- guider = guider_cls (num_frames [i ])
1128+ guider_args = ()
1129+ if guider_type > 0 :
1130+ guider_args += (cfg_min ,)
1131+ if guider_type == 2 :
1132+ assert num_frames is not None
1133+ guider_args += (num_frames [i ],)
1134+ guider = guider_cls (* guider_args )
1135+
11321136 if abort_event is not None :
11331137 sampler = GradioTrackedSampler (
11341138 abort_event ,
@@ -1569,8 +1573,8 @@ def run_one_scene(
15691573 )
15701574 ]
15711575 value_dict = get_value_dict (
1572- curr_imgs ,
1573- curr_imgs_clip ,
1576+ curr_imgs . to ( "cuda" ) ,
1577+ curr_imgs_clip . to ( "cuda" ) ,
15741578 curr_input_sels
15751579 + [
15761580 sel
@@ -1598,6 +1602,7 @@ def run_one_scene(
15981602 discretization ,
15991603 [len (curr_imgs )],
16001604 options ["num_steps" ],
1605+ options ["cfg_min" ],
16011606 abort_event = abort_event ,
16021607 )
16031608 assert len (samplers ) == 1
@@ -1756,8 +1761,8 @@ def run_one_scene(
17561761 )
17571762 ]
17581763 value_dict = get_value_dict (
1759- curr_imgs ,
1760- curr_imgs_clip ,
1764+ curr_imgs . to ( "cuda" ) ,
1765+ curr_imgs_clip . to ( "cuda" ) ,
17611766 curr_input_sels ,
17621767 curr_c2ws ,
17631768 curr_Ks ,
@@ -1769,6 +1774,7 @@ def run_one_scene(
17691774 discretization ,
17701775 [T_first_pass , T_second_pass ],
17711776 options ["num_steps" ],
1777+ options ["cfg_min" ],
17721778 abort_event = abort_event ,
17731779 )
17741780 samples = do_sample (
@@ -1912,8 +1918,8 @@ def run_one_scene(
19121918 )
19131919 ]
19141920 value_dict = get_value_dict (
1915- curr_imgs ,
1916- curr_imgs_clip ,
1921+ curr_imgs . to ( "cuda" ) ,
1922+ curr_imgs_clip . to ( "cuda" ) ,
19171923 curr_prior_sels ,
19181924 curr_c2ws ,
19191925 curr_Ks ,
0 commit comments