diff --git a/optimizedSD/optimized_img2img.py b/optimizedSD/optimized_img2img.py index 8bf8635d7..c0cbb9479 100644 --- a/optimizedSD/optimized_img2img.py +++ b/optimizedSD/optimized_img2img.py @@ -3,7 +3,7 @@ import numpy as np from random import randint from omegaconf import OmegaConf -from PIL import Image +from PIL import Image, PngImagePlugin from tqdm import tqdm, trange from itertools import islice from einops import rearrange @@ -314,8 +314,20 @@ def load_img(path, h0, w0): x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0)) x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") + info = PngImagePlugin.PngInfo() + info.add_text("Init Img", os.path.join(os.path.basename(os.path.dirname(opt.init_img)), os.path.basename(opt.init_img))) + info.add_text("Prompt", opt.prompt) + info.add_text("Seed", str(opt.seed)) + info.add_text("Scale", str(opt.scale)) + info.add_text("Steps", str(opt.ddim_steps)) + info.add_text("Strength", str(opt.strength)) + info.add_text("Precision", opt.precision) + info.add_text("Batch Size", str(opt.n_samples)) + info.add_text("Batch Index", str(i)) Image.fromarray(x_sample.astype(np.uint8)).save( - os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.png") + os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.png"), + "PNG", + pnginfo=info ) seeds += str(opt.seed) + "," opt.seed += 1 diff --git a/optimizedSD/optimized_txt2img.py b/optimizedSD/optimized_txt2img.py index ef4b814c3..d81d4285a 100644 --- a/optimizedSD/optimized_txt2img.py +++ b/optimizedSD/optimized_txt2img.py @@ -3,7 +3,7 @@ import numpy as np from random import randint from omegaconf import OmegaConf -from PIL import Image +from PIL import Image, PngImagePlugin from tqdm import tqdm, trange from itertools import islice from einops import rearrange @@ -287,8 +287,18 @@ def load_model_from_config(ckpt, verbose=False): x_samples_ddim = modelFS.decode_first_stage(samples_ddim[i].unsqueeze(0)) x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") + info = PngImagePlugin.PngInfo() + info.add_text("Prompt", opt.prompt) + info.add_text("Seed", str(opt.seed)) + info.add_text("Scale", str(opt.scale)) + info.add_text("Steps", str(opt.ddim_steps)) + info.add_text("Precision", opt.precision) + info.add_text("Batch Size", str(opt.n_samples)) + info.add_text("Batch Index", str(i)) Image.fromarray(x_sample.astype(np.uint8)).save( - os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.png") + os.path.join(sample_path, "seed_" + str(opt.seed) + "_" + f"{base_count:05}.png"), + "PNG", + pnginfo=info ) seeds += str(opt.seed) + "," opt.seed += 1