@@ -19,6 +19,7 @@
 import logging
 import math
 import os
+import random
 import shutil
 import warnings
 from pathlib import Path
@@ -40,6 +41,7 @@
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
 from torchvision import transforms
+from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
 
@@ -304,18 +306,6 @@ def parse_args(input_args=None):
             " resolution"
         ),
     )
-    parser.add_argument(
-        "--crops_coords_top_left_h",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
-    parser.add_argument(
-        "--crops_coords_top_left_w",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -325,6 +315,11 @@ def parse_args(input_args=None):
             " cropped. The images will be resized to the resolution first before cropping."
         ),
     )
+    parser.add_argument(
+        "--random_flip",
+        action="store_true",
+        help="whether to randomly flip images horizontally",
+    )
     parser.add_argument(
         "--train_text_encoder",
         action="store_true",
@@ -669,6 +664,41 @@ def __init__(
         self.instance_images = []
         for img in instance_images:
             self.instance_images.extend(itertools.repeat(img, repeats))
+
+        # image processing to prepare for using SD-XL micro-conditioning
669+ self .original_sizes = []
670+ self .crop_top_lefts = []
671+ self .pixel_values = []
672+ train_resize = transforms .Resize (size , interpolation = transforms .InterpolationMode .BILINEAR )
673+ train_crop = transforms .CenterCrop (size ) if center_crop else transforms .RandomCrop (size )
674+ train_flip = transforms .RandomHorizontalFlip (p = 1.0 )
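+        # p=1.0: the coin flip happens explicitly in the loop below (random.random() < 0.5),
+        # so this transform must flip unconditionally whenever it is applied.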
+        train_transforms = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+        for image in self.instance_images:
+            image = exif_transpose(image)
+            if not image.mode == "RGB":
+                image = image.convert("RGB")
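+            # record the pre-resize (height, width) for SDXL's original-size embedding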
+            self.original_sizes.append((image.height, image.width))
+            image = train_resize(image)
+            if args.random_flip and random.random() < 0.5:
+                # flip
+                image = train_flip(image)
+            if args.center_crop:
+                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+                image = train_crop(image)
+            else:
+                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+                image = crop(image, y1, x1, h, w)
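+            # (y1, x1) is the crop's top-left corner, computed for both the center and random cases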
+            crop_top_left = (y1, x1)
+            self.crop_top_lefts.append(crop_top_left)
+            image = train_transforms(image)
+            self.pixel_values.append(image)
+
         self.num_instance_images = len(self.instance_images)
         self._length = self.num_instance_images
 
@@ -698,12 +728,12 @@ def __len__(self):
 
     def __getitem__(self, index):
         example = {}
-        instance_image = self.instance_images[index % self.num_instance_images]
-        instance_image = exif_transpose(instance_image)
-
-        if not instance_image.mode == "RGB":
-            instance_image = instance_image.convert("RGB")
-        example["instance_images"] = self.image_transforms(instance_image)
+        instance_image = self.pixel_values[index % self.num_instance_images]
+        original_size = self.original_sizes[index % self.num_instance_images]
+        crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
+        example["instance_images"] = instance_image
+        example["original_size"] = original_size
+        example["crop_top_left"] = crop_top_left
 
         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -730,6 +760,8 @@ def __getitem__(self, index):
 def collate_fn(examples, with_prior_preservation=False):
     pixel_values = [example["instance_images"] for example in examples]
     prompts = [example["instance_prompt"] for example in examples]
+    original_sizes = [example["original_size"] for example in examples]
+    crop_top_lefts = [example["crop_top_left"] for example in examples]
 
     # Concat class and instance examples for prior preservation.
     # We do this to avoid doing two forward passes.
@@ -740,7 +772,12 @@ def collate_fn(examples, with_prior_preservation=False):
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
-    batch = {"pixel_values": pixel_values, "prompts": prompts}
+    batch = {
+        "pixel_values": pixel_values,
+        "prompts": prompts,
+        "original_sizes": original_sizes,
+        "crop_top_lefts": crop_top_lefts,
+    }
     return batch
 
 
@@ -1233,11 +1270,9 @@ def load_model_hook(models, input_dir):
     # pooled text embeddings
     # time ids
 
-    def compute_time_ids():
+    def compute_time_ids(original_size, crops_coords_top_left):
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-        original_size = (args.resolution, args.resolution)
         target_size = (args.resolution, args.resolution)
-        crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
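         # order: (orig_h, orig_w, crop_top, crop_left, target_h, target_w) -> tensor of shape (1, 6)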
         add_time_ids = torch.tensor([add_time_ids])
         add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
@@ -1254,9 +1289,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
         return prompt_embeds, pooled_prompt_embeds
 
-    # Handle instance prompt.
-    instance_time_ids = compute_time_ids()
-
     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
@@ -1267,7 +1299,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
 
     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        class_time_ids = compute_time_ids()
         if not args.train_text_encoder:
             class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
                 args.class_prompt, text_encoders, tokenizers
@@ -1282,9 +1313,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
-    add_time_ids = instance_time_ids
-    if args.with_prior_preservation:
-        add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
 
     if not train_dataset.custom_instance_prompts:
         if not args.train_text_encoder:
@@ -1436,18 +1464,24 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 # (this is the forward diffusion process)
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
 
+                # time ids
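+                # one (1, 6) row per example, built from the size/crop recorded by the dataset,
+                # replacing the single static row previously shared (and repeated) across the batch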
+                add_time_ids = torch.cat(
+                    [
+                        compute_time_ids(original_size=s, crops_coords_top_left=c)
+                        for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
+                    ]
+                )
+
                 # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
                 if not train_dataset.custom_instance_prompts:
                     elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
                 else:
                     elems_to_repeat_text_embeds = 1
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
 
                 # Predict the noise residual
                 if not args.train_text_encoder:
                     unet_added_conditions = {
-                        "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
+                        "time_ids": add_time_ids,
                         "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
                     }
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
@@ -1459,7 +1493,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                         return_dict=False,
                     )[0]
                 else:
-                    unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
+                    unet_added_conditions = {"time_ids": add_time_ids}
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two],
                         tokenizers=None,