@@ -511,25 +511,30 @@ def call(
511
511
# === Input extraction and validation ===
512
512
513
513
# Extract text part of the input.
514
- prompts , responses = x ["prompts" ], x ["responses" ]
515
- tf .debugging .assert_shapes ([(prompts , ("N" ,)), (responses , ("N" ,))])
514
+ if isinstance (x , dict ):
515
+ prompts , responses = x ["prompts" ], x ["responses" ]
516
+ tf .debugging .assert_shapes ([(prompts , ("N" ,)), (responses , ("N" ,))])
517
+ images = x .get ("images" , None )
518
+ has_prompt = True
519
+ else :
520
+ responses = tf .convert_to_tensor (x )
521
+ prompts = None
522
+ images = None
523
+ has_prompt = False
516
524
517
525
# Find out if the input is batched/not batched. Uprank if not batched.
518
526
# In other preprocessors, we don't have to do this, but here, all
519
527
# the following logic (indices, etc.) uses tensors with a batch dim.
520
528
# We will squeeze these back at the end.
521
529
batched = True
522
- if isinstance (prompts , str ):
530
+ if isinstance (responses , str ):
523
531
batched = False
524
- prompts = [prompts ]
532
+ prompts = [prompts ] if has_prompt else None
525
533
responses = [responses ]
526
- if isinstance (prompts , tf .Tensor ) and len (prompts .shape ) == 0 :
534
+ if isinstance (responses , tf .Tensor ) and len (responses .shape ) == 0 :
527
535
batched = False
528
- prompts = tf .expand_dims (prompts , axis = 0 )
529
- responses = tf .expand_dims (responses , axis = 0 )
530
-
531
- # Extract images from the input.
532
- images = x .get ("images" , None )
536
+ prompts = prompts [None ] if has_prompt else None
537
+ responses = responses [None ]
533
538
534
539
# There are 8 cases, based on values of
535
540
# a = `self.text_only_model`, b = `images` is `None`, and whether
@@ -563,18 +568,20 @@ def call(
563
568
# === Tokenization, padding, etc. ===
564
569
565
570
# Tokenise the inputs.
566
- prompts = self .tokenizer (prompts )
567
- responses = self .tokenizer (responses )
571
+ if has_prompt :
572
+ segments = (self .tokenizer (prompts ), self .tokenizer (responses ))
573
+ else :
574
+ segments = (self .tokenizer (responses ),)
568
575
569
576
# Padding.
570
577
token_ids , segment_ids = self .packer (
571
- ( prompts , responses ) ,
578
+ segments ,
572
579
sequence_length = sequence_length + 1 ,
573
580
add_start_value = self .add_start_token ,
574
581
add_end_value = self .add_end_token ,
575
582
)
576
- response_mask = segment_ids == 1
577
583
padding_mask = token_ids != self .tokenizer .pad_token_id
584
+ response_mask = segment_ids == 1 if has_prompt else padding_mask
578
585
579
586
# === Text Model ===
580
587
if self .text_only_model :
@@ -600,7 +607,7 @@ def call(
600
607
601
608
# === Vision processing ===
602
609
603
- batch_size = tf .shape (prompts )[0 ]
610
+ batch_size = tf .shape (responses )[0 ]
604
611
desired_height = self .image_converter .image_size [0 ]
605
612
desired_width = self .image_converter .image_size [1 ]
606
613
if images is None :
0 commit comments