2424import  re 
2525from  typing  import  Any , Protocol 
2626
27+ from  absl  import  flags 
2728from  absl  import  logging 
2829import  jax 
2930from  recml .core .utils  import  types 
@@ -162,12 +163,23 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
162163      Defaults to False. 
163164    seed: An optional seed to use for deterministic shuffling / preprocessing. 
164165      Defaults to None. 
165-     tf_data_service_address: An optional URI of a tf.data service to offload 
166-       preprocessing onto during training. The URI should be in the format 
167-       "protocol://address", e.g. "grpc://tf-data-service:5050". If `None` no 
168-       data service will be applied. 
166+     enable_tf_data_service: Whether to apply tf.data service for this dataset. 
167+       If True, flag `tf_data_service_address` must be set. 
169168    tf_data_service_policy: Sharding policy to use for tf.data service when it 
170169      is enabled. 
170+     tf_data_service_job_name: Job name to use for tf.data service. If None, the 
171+       default job name will be used. 
172+     offload_preprocessing_to_tf_data_service: Whether to offload preprocessing
173+       to tf.data service. If True, enable_tf_data_service must also be True, and
174+       the preprocessing transformation will be offloaded to tf.data service
175+       workers. Otherwise, the preprocessing transformation will be applied on
176+       the host CPU. If tf.data service is not enabled, this arg must be set
177+       to False. Defaults to False.
178+     tf_data_service_replicate_on_split: Whether to replicate the file dataset on 
179+       split when distributing data to tf.data service workers. Note: it could be 
180+       used in the case where multiple datasets are processed together under 
181+       `Dynamic` mode. The dataset with `tf_data_service_replicate_on_split` 
182+       enabled is equivalent to having that dataset processed as `Off` mode. 
171183    feature_spec: A mapping of feature keys to `FixedLenFeature`, 
172184      `VarLenFeature`, `SparseFeature`, or `RaggedFeature` values. This will be 
173185      used to parse the TF examples, or as context_features spec to parse TF 
@@ -208,7 +220,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
208220      tensorflow. 
209221    debug: An optional boolean indicating whether to debug input boundedness. If 
210222      `True`, the dataset will consist of a single batch that's cached and 
211-       infinitely repeated 
223+       infinitely repeated.  
212224  """ 
213225
214226  cache_reading : bool  =  False 
@@ -231,10 +243,12 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
231243  readahead : str  |  None  =  None 
232244  group_uris_by_dir : bool  =  False 
233245  seed : int  |  None  =  None 
234-   tf_data_service_address : str  |  None  =  None 
246+   enable_tf_data_service : bool  =  False 
247+   tf_data_service_job_name : str  |  None  =  None 
235248  tf_data_service_policy : tf .data .experimental .service .ShardingPolicy  =  (
236249      tf .data .experimental .service .ShardingPolicy .OFF 
237250  )
251+   offload_preprocessing_to_tf_data_service : bool  =  False 
238252  feature_spec : Mapping [str , IO_Feature ] |  None  =  None 
239253  sequence_feature_spec : Mapping [str , IO_Feature ] |  None  =  None 
240254  tf_transform_output : TFTransformOutput  |  None  =  None 
@@ -246,14 +260,26 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
246260  sharding_info : DatasetShardingInfo  =  dataclasses .field (
247261      default_factory = DatasetShardingInfo 
248262  )
263+   tf_data_service_replicate_on_split : bool  =  False 
249264  debug : bool  =  False 
250265
251266  def  __post_init__ (self ):
252-     if  self .tf_data_service_address  is  not None :
267+     if  self .enable_tf_data_service :
268+       if  flags .FLAGS .tf_data_service_address  is  None :
269+         raise  ValueError (
270+             "Flag `tf_data_service_address` must be set when" 
271+             " `enable_tf_data_service` is True." 
272+         )
253273      if  self .seed  is  not None :
254274        raise  ValueError ("`seed` must be None for data service." )
255275      if  self .sharding :
256276        raise  ValueError ("`sharding` must be set to False for data service." )
277+     else :
278+       if  self .offload_preprocessing_to_tf_data_service :
279+         raise  ValueError (
280+             "`offload_preprocessing_to_tf_data_service` must be False when" 
281+             " `enable_tf_data_service` is False." 
282+         )
257283
258284  @functools .cached_property  
259285  def  tfds_metadata (self ) ->  TFDSMetadata  |  None :
@@ -464,6 +490,9 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
464490    # Create a dataset of file / file group uris. 
465491    dataset  =  tf .data .Dataset .from_tensor_slices (uris )
466492
493+     if  self .tf_data_service_replicate_on_split :
494+       dataset  =  tf .data .apply_rewrite (dataset , rewrite = "replicate_on_split" )
495+ 
467496    # Repeat the dataset. We might need to repeat the dataset here in case the 
468497    # issue is encountered: internal screenshot link:6jAKKoEMT3afXRe 
469498    # even when we do have enough shards for the input data.
@@ -533,23 +562,26 @@ def _maybe_apply_tf_data_service(
533562      self , dataset : tf .data .Dataset 
534563  ) ->  tf .data .Dataset :
535564    """Applies the tf.data service to the dataset.""" 
536-     if  self .tf_data_service_address   is   None :
565+     if  not   self .enable_tf_data_service :
537566      return  dataset 
538567
568+     tf_data_service_address  =  flags .FLAGS .tf_data_service_address 
569+ 
539570    per_proc_batch_size  =  self .sharding_info .per_process_batch_size (
540571        self .global_batch_size 
541572    )
542573    logging .info (
543574        "Applying tf.data service with address %s and per replica batch" 
544575        " size %s" ,
545-         self . tf_data_service_address ,
576+         tf_data_service_address ,
546577        per_proc_batch_size ,
547578    )
548579    return  dataset .apply (
549580        tf .data .experimental .service .distribute (
550581            processing_mode = self .tf_data_service_policy ,
551-             service = self .tf_data_service_address ,
552-             job_name = f"bs_{ per_proc_batch_size }  ,
582+             service = tf_data_service_address ,
583+             job_name = self .tf_data_service_job_name 
584+             or  "tf_data_service_shared_job_name" ,
553585        )
554586    )
555587
@@ -566,12 +598,18 @@ def make(self) -> tf.data.Dataset:
566598    dataset  =  self ._parse_dataset (dataset )
567599    # Apply filters to the batched dataset. 
568600    dataset  =  self ._maybe_filter_dataset (dataset )
569-     # Apply data service. 
570-     dataset  =  self ._maybe_apply_tf_data_service (dataset )
601+     # Apply tf.data service before preprocessing so `map_fns` run on the host.
602+     if  not  self .offload_preprocessing_to_tf_data_service :
603+       dataset  =  self ._maybe_apply_tf_data_service (dataset )
604+ 
571605    # Apply transformations on the dataset. 
572606    for  fn  in  self .map_fns :
573607      dataset  =  dataset .map (fn , num_parallel_calls = self .num_parallel_threads )
574608
609+     # Apply tf.data service after preprocessing so `map_fns` run on workers.
610+     if  self .offload_preprocessing_to_tf_data_service :
611+       dataset  =  self ._maybe_apply_tf_data_service (dataset )
612+ 
575613    if  self .debug :
576614      dataset  =  dataset .take (1 ).cache ().repeat ()
577615
@@ -778,8 +816,7 @@ def _vectorized_filter(features: FeaturesDictType) -> FeaturesDictType:
778816      if  isinstance (features [name ], tf .SparseTensor ):
779817        outputs [name ] =  tf .sparse_boolean_mask (features [name ], mask )
780818      elif  isinstance (features [name ], tf .RaggedTensor ):
781-         # TODO(b/307323524): Support this when we start using Ragged tensors. 
782-         raise  ValueError ("Filtering ragged tensors is not supported." )
819+         outputs [name ] =  tf .ragged .boolean_mask (features [name ], mask )
783820      else :
784821        outputs [name ] =  tf .boolean_mask (features [name ], mask )
785822    return  outputs 
0 commit comments