-
Notifications
You must be signed in to change notification settings - Fork 12
Description
AssertionError Traceback (most recent call last)
in <cell line: 21>()
19
20 control_image.save("control.png")
---> 21 async_diff = AsyncDiff(pipe, model_n=2, stride=1, time_shift=False)
22
23 async_diff.reset_state(warm_up=1)
/content/AsyncDiff/src/async_sd.py in __init__(self, pipeline, model_n, stride, warm_up, time_shift)
46 def __init__(self, pipeline, model_n=2, stride=1, warm_up=1, time_shift=False):
47 dist.init_process_group("nccl")
---> 48 if not dist.get_rank(): assert model_n + stride - 1 == dist.get_world_size(), "[ERROR]: The strategy is not compatible with the number of devices. (model_n + stride - 1) should be equal to world_size."
49 assert stride==1 or stride==2, "[ERROR]: The stride should be set as 1 or 2"
50 self.model_n = model_n
AssertionError: [ERROR]: The strategy is not compatible with the number of devices. (model_n + stride - 1) should be equal to world_size.
import os
import time

from AsyncDiff.src.async_sd import AsyncDiff

# Generate an image with a ControlNet (Canny edges) + IP-Adapter pipeline wrapped
# by AsyncDiff, save it, and report the elapsed wall-clock time.
#
# Names used below but defined outside this snippet (earlier notebook cells or
# the script header): `pipe`, `generator`, `load_image`, `np`, `cv2`, `Image`.
#
# Launch requirement: AsyncDiff.__init__ calls dist.init_process_group("nccl")
# and requires model_n + stride - 1 == world_size. A plain notebook cell or
# `python script.py` is a single process (world_size 1), so this file has to be
# started with one process per device by a launcher, e.g.
#     torchrun --nproc_per_node=2 this_script.py
# for the settings below. torchrun exports WORLD_SIZE and RANK to each process.
MODEL_N = 2
STRIDE = 1
REQUIRED_WORLD_SIZE = MODEL_N + STRIDE - 1

# Fail fast on every rank with an actionable message. Without this check,
# AsyncDiff raises a bare AssertionError, only on rank 0, and only after the
# image downloads.
world_size = int(os.environ.get("WORLD_SIZE", "1"))
if world_size != REQUIRED_WORLD_SIZE:
    raise RuntimeError(
        f"AsyncDiff(model_n={MODEL_N}, stride={STRIDE}) needs "
        f"{REQUIRED_WORLD_SIZE} processes (one per device), but WORLD_SIZE is "
        f"{world_size}. Run this file as a script with "
        f"`torchrun --nproc_per_node={REQUIRED_WORLD_SIZE} <script.py>`; "
        "it cannot run inside a single notebook process."
    )
rank = int(os.environ.get("RANK", "0"))
# Only rank 0 writes files / prints, so processes don't race on the same outputs.
is_main_process = rank == 0

start_time = time.time()
image = load_image("https://img0.baidu.com/it/u=1483900031,3964539741&fm=253&fmt=auto&app=138&f=JPEG?w=751&h=500")
ip_image = load_image("https://replicate.delivery/pbxt/Kc2uOvUQ0sVmT6ALE7MIBwOAXlQ6kqBGc2XOz54oKVO2NxvF/style.jpg")
image_orig = image

# Canny edge map of the source image, replicated into 3 identical channels
# (H, W) -> (H, W, 3) to serve as the ControlNet conditioning image.
low_threshold = 100
high_threshold = 200
edges = cv2.Canny(np.array(image_orig), low_threshold, high_threshold)
control_image = Image.fromarray(np.stack([edges] * 3, axis=-1))
if is_main_process:
    control_image.save("control.png")

async_diff = AsyncDiff(pipe, model_n=MODEL_N, stride=STRIDE, time_shift=False)
async_diff.reset_state(warm_up=1)

# The pipeline call runs on every rank (as in the original script).
# NOTE(review): `generator` presumably needs the same seed on every rank — confirm.
image = pipe(
    prompt="a nice room, best quality, highres, wallpaper,",
    image=image_orig,
    control_image=control_image,
    negative_prompt="(low quality, bad quality, worst quality:1.2)",
    ip_adapter_image=ip_image,
    num_inference_steps=20,
    guidance_scale=8,
    controlnet_conditioning_scale=0.6,
    guess_mode=False,
    strength=1.0,
    generator=generator,
).images[0]

if is_main_process:
    # PIL's save() does not create missing parent directories.
    os.makedirs("./image", exist_ok=True)
    image.save("./image/dreamshaper8.png")
end_time = time.time()
if is_main_process:
    # "时间" means "time": total elapsed seconds, including the downloads above.
    print("时间", end_time - start_time)