Skip to content

Commit 1a65c49

Browse files
committed
Do not allow parameters for documents.map, to simplify the pipeline
1 parent e918514 commit 1a65c49

File tree

1 file changed

+1
-11
lines changed

1 file changed

+1
-11
lines changed

src/pytorch_ie/pipeline.py

Lines changed: 1 addition & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,6 @@ def __init__(
6969
self._dataloader_params,
7070
self._forward_params,
7171
self._postprocess_params,
72-
self._dataset_map_params,
7372
) = self._sanitize_parameters(**kwargs)
7473

7574
def save_pretrained(self, save_directory: str):
@@ -163,7 +162,7 @@ def _ensure_tensor_on_device(self, inputs, device):
163162

164163
def _sanitize_parameters(
165164
self, **pipeline_parameters
166-
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
165+
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
167166
"""
168167
_sanitize_parameters will be called with any excessive named arguments from either `__init__` or `__call__`
169168
methods. It should return 4 dictionaries of the resolved parameters used by the various `preprocess`,
@@ -177,7 +176,6 @@ def _sanitize_parameters(
177176
dataloader_params = {}
178177
forward_parameters = {}
179178
postprocess_parameters: Dict[str, Any] = {}
180-
dataset_map_parameters = {}
181179

182180
# set preprocess parameters
183181
field = pipeline_parameters.get("predict_field")
@@ -199,16 +197,11 @@ def _sanitize_parameters(
199197
if p_name in pipeline_parameters:
200198
postprocess_parameters[p_name] = pipeline_parameters[p_name]
201199

202-
for p_name in ["document_batch_size"]:
203-
if p_name in pipeline_parameters:
204-
dataset_map_parameters["batch_size"] = pipeline_parameters[p_name]
205-
206200
return (
207201
preprocess_parameters,
208202
dataloader_params,
209203
forward_parameters,
210204
postprocess_parameters,
211-
dataset_map_parameters,
212205
)
213206

214207
def preprocess(
@@ -342,7 +335,6 @@ def __call__(
342335
dataloader_params,
343336
forward_params,
344337
postprocess_params,
345-
dataset_map_params,
346338
) = self._sanitize_parameters(**kwargs)
347339

348340
if "TOKENIZERS_PARALLELISM" not in os.environ:
@@ -356,7 +348,6 @@ def __call__(
356348
dataloader_params = {**self._dataloader_params, **dataloader_params}
357349
forward_params = {**self._forward_params, **forward_params}
358350
postprocess_params = {**self._postprocess_params, **postprocess_params}
359-
dataset_map_params = {**self._dataset_map_params, **dataset_map_params}
360351

361352
self.call_count += 1
362353
if self.call_count > 10 and self.device.type == "cuda":
@@ -394,7 +385,6 @@ def __call__(
394385
postprocess_params=postprocess_params,
395386
),
396387
batched=True,
397-
**dataset_map_params,
398388
)
399389
finally:
400390
if was_caching_enabled:

0 commit comments

Comments
 (0)