diff --git a/optimizedSD/optimized_img2img.py b/optimizedSD/optimized_img2img.py index 24f3338f0..53e2fd3d8 100644 --- a/optimizedSD/optimized_img2img.py +++ b/optimizedSD/optimized_img2img.py @@ -54,7 +54,7 @@ def load_img(path, h0, w0): config = "optimizedSD/v1-inference.yaml" -ckpt = "models/ldm/stable-diffusion-v1/model.ckpt" +DEFAULT_CKPT = "models/ldm/stable-diffusion-v1/model.ckpt" parser = argparse.ArgumentParser() @@ -174,6 +174,12 @@ def load_img(path, h0, w0): choices=["ddim"], default="ddim", ) +parser.add_argument( + "--ckpt", + type=str, + help="path to checkpoint of model", + default=DEFAULT_CKPT +) opt = parser.parse_args() tic = time.time() @@ -188,7 +194,7 @@ def load_img(path, h0, w0): # Logging logger(vars(opt), log_csv = "logs/img2img_logs.csv") -sd = load_model_from_config(f"{ckpt}") +sd = load_model_from_config(f"{opt.ckpt}") li, lo = [], [] for key, value in sd.items(): sp = key.split(".")