@@ -104,6 +104,22 @@ def __call__(
104104 ):
105105 residual = hidden_states
106106
107+ # separate ip_hidden_states from encoder_hidden_states
108+ if encoder_hidden_states is not None :
109+ if isinstance (encoder_hidden_states , tuple ):
110+ encoder_hidden_states , ip_hidden_states = encoder_hidden_states
111+ else :
112+ deprecation_message = (
113+ "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
114+ " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
115+ )
116+ deprecate ("encoder_hidden_states not a tuple" , "1.0.0" , deprecation_message , standard_warn = False )
117+ end_pos = encoder_hidden_states .shape [1 ] - self .num_tokens [0 ]
118+ encoder_hidden_states , ip_hidden_states = (
119+ encoder_hidden_states [:, :end_pos , :],
120+ [encoder_hidden_states [:, end_pos :, :]],
121+ )
122+
107123 if attn .spatial_norm is not None :
108124 hidden_states = attn .spatial_norm (hidden_states , temb )
109125
@@ -125,15 +141,8 @@ def __call__(
125141
126142 if encoder_hidden_states is None :
127143 encoder_hidden_states = hidden_states
128- else :
129- # get encoder_hidden_states, ip_hidden_states
130- end_pos = encoder_hidden_states .shape [1 ] - self .num_tokens
131- encoder_hidden_states , ip_hidden_states = (
132- encoder_hidden_states [:, :end_pos , :],
133- encoder_hidden_states [:, end_pos :, :],
134- )
135- if attn .norm_cross :
136- encoder_hidden_states = attn .norm_encoder_hidden_states (encoder_hidden_states )
144+ elif attn .norm_cross :
145+ encoder_hidden_states = attn .norm_encoder_hidden_states (encoder_hidden_states )
137146
138147 key = attn .to_k (encoder_hidden_states ) + self .lora_scale * self .to_k_lora (encoder_hidden_states )
139148 value = attn .to_v (encoder_hidden_states ) + self .lora_scale * self .to_v_lora (encoder_hidden_states )
@@ -233,6 +242,22 @@ def __call__(
233242 ):
234243 residual = hidden_states
235244
245+ # separate ip_hidden_states from encoder_hidden_states
246+ if encoder_hidden_states is not None :
247+ if isinstance (encoder_hidden_states , tuple ):
248+ encoder_hidden_states , ip_hidden_states = encoder_hidden_states
249+ else :
250+ deprecation_message = (
251+ "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
252+ " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
253+ )
254+ deprecate ("encoder_hidden_states not a tuple" , "1.0.0" , deprecation_message , standard_warn = False )
255+ end_pos = encoder_hidden_states .shape [1 ] - self .num_tokens [0 ]
256+ encoder_hidden_states , ip_hidden_states = (
257+ encoder_hidden_states [:, :end_pos , :],
258+ [encoder_hidden_states [:, end_pos :, :]],
259+ )
260+
236261 if attn .spatial_norm is not None :
237262 hidden_states = attn .spatial_norm (hidden_states , temb )
238263
@@ -259,15 +284,8 @@ def __call__(
259284
260285 if encoder_hidden_states is None :
261286 encoder_hidden_states = hidden_states
262- else :
263- # get encoder_hidden_states, ip_hidden_states
264- end_pos = encoder_hidden_states .shape [1 ] - self .num_tokens
265- encoder_hidden_states , ip_hidden_states = (
266- encoder_hidden_states [:, :end_pos , :],
267- encoder_hidden_states [:, end_pos :, :],
268- )
269- if attn .norm_cross :
270- encoder_hidden_states = attn .norm_encoder_hidden_states (encoder_hidden_states )
287+ elif attn .norm_cross :
288+ encoder_hidden_states = attn .norm_encoder_hidden_states (encoder_hidden_states )
271289
272290 key = attn .to_k (encoder_hidden_states ) + self .lora_scale * self .to_k_lora (encoder_hidden_states )
273291 value = attn .to_v (encoder_hidden_states ) + self .lora_scale * self .to_v_lora (encoder_hidden_states )
@@ -951,30 +969,6 @@ def encode_prompt(
951969
952970 return prompt_embeds , negative_prompt_embeds
953971
954- def encode_image (self , image , device , num_images_per_prompt , output_hidden_states = None ):
955- dtype = next (self .image_encoder .parameters ()).dtype
956-
957- if not isinstance (image , torch .Tensor ):
958- image = self .feature_extractor (image , return_tensors = "pt" ).pixel_values
959-
960- image = image .to (device = device , dtype = dtype )
961- if output_hidden_states :
962- image_enc_hidden_states = self .image_encoder (image , output_hidden_states = True ).hidden_states [- 2 ]
963- image_enc_hidden_states = image_enc_hidden_states .repeat_interleave (num_images_per_prompt , dim = 0 )
964- uncond_image_enc_hidden_states = self .image_encoder (
965- torch .zeros_like (image ), output_hidden_states = True
966- ).hidden_states [- 2 ]
967- uncond_image_enc_hidden_states = uncond_image_enc_hidden_states .repeat_interleave (
968- num_images_per_prompt , dim = 0
969- )
970- return image_enc_hidden_states , uncond_image_enc_hidden_states
971- else :
972- image_embeds = self .image_encoder (image ).image_embeds
973- image_embeds = image_embeds .repeat_interleave (num_images_per_prompt , dim = 0 )
974- uncond_image_embeds = torch .zeros_like (image_embeds )
975-
976- return image_embeds , uncond_image_embeds
977-
978972 def run_safety_checker (self , image , device , dtype ):
979973 if self .safety_checker is None :
980974 has_nsfw_concept = None
@@ -1302,7 +1296,6 @@ def __call__(
13021296 not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
13031297 image_embeds (`torch.FloatTensor`, *optional*):
13041298 Pre-generated image embeddings.
1305- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
13061299 output_type (`str`, *optional*, defaults to `"pil"`):
13071300 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
13081301 return_dict (`bool`, *optional*, defaults to `True`):
@@ -1411,7 +1404,7 @@ def __call__(
14111404 prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ])
14121405
14131406 if image_embeds is not None :
1414- image_embeds = image_embeds . repeat_interleave ( num_images_per_prompt , dim = 0 ).to (
1407+ image_embeds = torch . stack ([ image_embeds ] * num_images_per_prompt , dim = 0 ).to (
14151408 device = device , dtype = prompt_embeds .dtype
14161409 )
14171410 negative_image_embeds = torch .zeros_like (image_embeds )
0 commit comments