|
14 | 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
15 | 15 | # See the License for the specific language governing permissions and |
16 | 16 | # limitations under the License. |
|
17 | | -import contextlib |
18 | 17 | import gc |
19 | 18 | import itertools |
20 | 19 | import json |
|
26 | 25 | from pathlib import Path |
27 | 26 |
|
28 | 27 | import diffusers |
29 | | -import numpy as np |
30 | 28 | import torch |
31 | 29 | import torch.nn.functional as F |
32 | 30 | import torch.utils.checkpoint |
|
36 | 34 | from diffusers import ( |
37 | 35 | AutoencoderKL, |
38 | 36 | DDPMScheduler, |
39 | | - DPMSolverMultistepScheduler, |
40 | 37 | EDMEulerScheduler, |
41 | 38 | EulerDiscreteScheduler, |
42 | 39 | StableDiffusionXLPipeline, |
@@ -78,59 +75,6 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision): |
78 | 75 | return scheduler_type |
79 | 76 |
|
80 | 77 |
|
81 | | -def log_validation( |
82 | | - pipeline, |
83 | | - args, |
84 | | - accelerator, |
85 | | - pipeline_args, |
86 | | - epoch, |
87 | | - is_final_validation=False, |
88 | | -): |
89 | | - logger.info( |
90 | | - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" |
91 | | - f" {args.validation_prompt}." |
92 | | - ) |
93 | | - |
94 | | - # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it |
95 | | - scheduler_args = {} |
96 | | - |
97 | | - if not args.do_edm_style_training: |
98 | | - if "variance_type" in pipeline.scheduler.config: |
99 | | - variance_type = pipeline.scheduler.config.variance_type |
100 | | - |
101 | | - if variance_type in ["learned", "learned_range"]: |
102 | | - variance_type = "fixed_small" |
103 | | - |
104 | | - scheduler_args["variance_type"] = variance_type |
105 | | - |
106 | | - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) |
107 | | - |
108 | | - pipeline = pipeline.to(accelerator.device) |
109 | | - pipeline.set_progress_bar_config(disable=True) |
110 | | - |
111 | | - # run inference |
112 | | - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None |
113 | | - # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better |
114 | | - # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051 |
115 | | - inference_ctx = ( |
116 | | - contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast() |
117 | | - ) |
118 | | - |
119 | | - with inference_ctx: |
120 | | - images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] |
121 | | - |
122 | | - for tracker in accelerator.trackers: |
123 | | - phase_name = "test" if is_final_validation else "validation" |
124 | | - if tracker.name == "tensorboard": |
125 | | - np_images = np.stack([np.asarray(img) for img in images]) |
126 | | - tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") |
127 | | - |
128 | | - del pipeline |
129 | | - torch.cuda.empty_cache() |
130 | | - |
131 | | - return images |
132 | | - |
133 | | - |
134 | 78 | def import_model_class_from_model_name_or_path( |
135 | 79 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" |
136 | 80 | ): |
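
The removed `log_validation` helper centers on one non-obvious step: when the checkpoint's scheduler config carries a learned variance, inference must fall back to a fixed one, because training used the simplified objective. A minimal standalone sketch of that swap, reconstructed from the deleted lines; the checkpoint id here is illustrative (the script passes its own `args.pretrained_model_name_or_path`):

```python
import torch

from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline

# Illustrative checkpoint; the training script would use
# args.pretrained_model_name_or_path instead.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Training used the simplified objective, so a variance the model once
# learned ("learned"/"learned_range") must be ignored at inference time.
scheduler_args = {}
if "variance_type" in pipeline.scheduler.config:
    variance_type = pipeline.scheduler.config.variance_type
    if variance_type in ["learned", "learned_range"]:
        variance_type = "fixed_small"
    scheduler_args["variance_type"] = variance_type

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config, **scheduler_args
)
```
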
@@ -1239,42 +1183,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): |
1239 | 1183 | if global_step >= args.max_train_steps: |
1240 | 1184 | break |
1241 | 1185 |
|
1242 | | - if accelerator.is_main_process: |
1243 | | - if args.validation_prompt is not None and epoch % args.validation_epochs == 0: |
1244 | | - # create pipeline |
1245 | | - if not args.train_text_encoder: |
1246 | | - text_encoder_one = text_encoder_cls_one.from_pretrained( |
1247 | | - args.pretrained_model_name_or_path, |
1248 | | - subfolder="text_encoder", |
1249 | | - revision=args.revision, |
1250 | | - variant=args.variant, |
1251 | | - ) |
1252 | | - text_encoder_two = text_encoder_cls_two.from_pretrained( |
1253 | | - args.pretrained_model_name_or_path, |
1254 | | - subfolder="text_encoder_2", |
1255 | | - revision=args.revision, |
1256 | | - variant=args.variant, |
1257 | | - ) |
1258 | | - pipeline = StableDiffusionXLPipeline.from_pretrained( |
1259 | | - args.pretrained_model_name_or_path, |
1260 | | - vae=vae, |
1261 | | - text_encoder=accelerator.unwrap_model(text_encoder_one), |
1262 | | - text_encoder_2=accelerator.unwrap_model(text_encoder_two), |
1263 | | - unet=accelerator.unwrap_model(unet), |
1264 | | - revision=args.revision, |
1265 | | - variant=args.variant, |
1266 | | - torch_dtype=weight_dtype, |
1267 | | - ) |
1268 | | - pipeline_args = {"prompt": args.validation_prompt} |
1269 | | - |
1270 | | - images = log_validation( |
1271 | | - pipeline, |
1272 | | - args, |
1273 | | - accelerator, |
1274 | | - pipeline_args, |
1275 | | - epoch, |
1276 | | - ) |
1277 | | - |
1278 | 1186 | # Save the lora layers |
1279 | 1187 | accelerator.wait_for_everyone() |
1280 | 1188 | if accelerator.is_main_process: |
|
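
For reference, the per-epoch validation this hunk removes amounted to rebuilding an inference pipeline around the in-training modules and handing it to `log_validation`. A sketch reconstructed from the deleted block, not new behavior; `vae`, `unet`, the two text encoders, `accelerator`, `weight_dtype`, `epoch`, and `args` are assumed to exist in the surrounding training loop:

```python
from diffusers import StableDiffusionXLPipeline

# All names below (vae, unet, text_encoder_one/two, accelerator,
# weight_dtype, epoch, args) come from the training loop that used to
# surround this code.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    vae=vae,
    text_encoder=accelerator.unwrap_model(text_encoder_one),
    text_encoder_2=accelerator.unwrap_model(text_encoder_two),
    unet=accelerator.unwrap_model(unet),
    revision=args.revision,
    variant=args.variant,
    torch_dtype=weight_dtype,
)

images = log_validation(
    pipeline,
    args,
    accelerator,
    {"prompt": args.validation_prompt},
    epoch,
)
```
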