@@ -69,7 +69,6 @@ def __init__(
6969 self ._dataloader_params ,
7070 self ._forward_params ,
7171 self ._postprocess_params ,
72- self ._dataset_map_params ,
7372 ) = self ._sanitize_parameters (** kwargs )
7473
7574 def save_pretrained (self , save_directory : str ):
@@ -163,7 +162,7 @@ def _ensure_tensor_on_device(self, inputs, device):
163162
164163 def _sanitize_parameters (
165164 self , ** pipeline_parameters
166- ) -> Tuple [Dict [str , Any ], Dict [str , Any ], Dict [str , Any ], Dict [str , Any ], Dict [ str , Any ] ]:
165+ ) -> Tuple [Dict [str , Any ], Dict [str , Any ], Dict [str , Any ], Dict [str , Any ]]:
167166 """
168167 _sanitize_parameters will be called with any excessive named arguments from either `__init__` or `__call__`
169168 methods. It should return 4 dictionaries of the resolved parameters used by the various `preprocess`,
@@ -177,7 +176,6 @@ def _sanitize_parameters(
177176 dataloader_params = {}
178177 forward_parameters = {}
179178 postprocess_parameters : Dict [str , Any ] = {}
180- dataset_map_parameters = {}
181179
182180 # set preprocess parameters
183181 field = pipeline_parameters .get ("predict_field" )
@@ -199,16 +197,11 @@ def _sanitize_parameters(
199197 if p_name in pipeline_parameters :
200198 postprocess_parameters [p_name ] = pipeline_parameters [p_name ]
201199
202- for p_name in ["document_batch_size" ]:
203- if p_name in pipeline_parameters :
204- dataset_map_parameters ["batch_size" ] = pipeline_parameters [p_name ]
205-
206200 return (
207201 preprocess_parameters ,
208202 dataloader_params ,
209203 forward_parameters ,
210204 postprocess_parameters ,
211- dataset_map_parameters ,
212205 )
213206
214207 def preprocess (
@@ -342,7 +335,6 @@ def __call__(
342335 dataloader_params ,
343336 forward_params ,
344337 postprocess_params ,
345- dataset_map_params ,
346338 ) = self ._sanitize_parameters (** kwargs )
347339
348340 if "TOKENIZERS_PARALLELISM" not in os .environ :
@@ -356,7 +348,6 @@ def __call__(
356348 dataloader_params = {** self ._dataloader_params , ** dataloader_params }
357349 forward_params = {** self ._forward_params , ** forward_params }
358350 postprocess_params = {** self ._postprocess_params , ** postprocess_params }
359- dataset_map_params = {** self ._dataset_map_params , ** dataset_map_params }
360351
361352 self .call_count += 1
362353 if self .call_count > 10 and self .device .type == "cuda" :
@@ -394,7 +385,6 @@ def __call__(
394385 postprocess_params = postprocess_params ,
395386 ),
396387 batched = True ,
397- ** dataset_map_params ,
398388 )
399389 finally :
400390 if was_caching_enabled :
0 commit comments