Skip to content

Commit 29261df

Browse files
zucchini-nlpLysandreJik
authored andcommitted
Processor load with multi-processing (#40786)
push
1 parent 694410d commit 29261df

File tree

3 files changed

+66
-45
lines changed

3 files changed

+66
-45
lines changed

src/transformers/feature_extraction_utils.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
logging,
4545
requires_backends,
4646
)
47-
from .utils.hub import cached_files
47+
from .utils.hub import cached_file
4848

4949

5050
if TYPE_CHECKING:
@@ -506,20 +506,27 @@ def get_feature_extractor_dict(
506506
feature_extractor_file = FEATURE_EXTRACTOR_NAME
507507
try:
508508
# Load from local folder or from cache or download from model Hub and cache
509-
resolved_feature_extractor_files = cached_files(
510-
pretrained_model_name_or_path,
511-
filenames=[feature_extractor_file, PROCESSOR_NAME],
512-
cache_dir=cache_dir,
513-
force_download=force_download,
514-
proxies=proxies,
515-
resume_download=resume_download,
516-
local_files_only=local_files_only,
517-
subfolder=subfolder,
518-
token=token,
519-
user_agent=user_agent,
520-
revision=revision,
521-
_raise_exceptions_for_missing_entries=False,
522-
)
509+
resolved_feature_extractor_files = [
510+
resolved_file
511+
for filename in [feature_extractor_file, PROCESSOR_NAME]
512+
if (
513+
resolved_file := cached_file(
514+
pretrained_model_name_or_path,
515+
filename=filename,
516+
cache_dir=cache_dir,
517+
force_download=force_download,
518+
proxies=proxies,
519+
resume_download=resume_download,
520+
local_files_only=local_files_only,
521+
subfolder=subfolder,
522+
token=token,
523+
user_agent=user_agent,
524+
revision=revision,
525+
_raise_exceptions_for_missing_entries=False,
526+
)
527+
)
528+
is not None
529+
]
523530
resolved_feature_extractor_file = resolved_feature_extractor_files[0]
524531
except OSError:
525532
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to

src/transformers/image_processing_base.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
is_remote_url,
3535
logging,
3636
)
37-
from .utils.hub import cached_files
37+
from .utils.hub import cached_file
3838

3939

4040
ImageProcessorType = TypeVar("ImageProcessorType", bound="ImageProcessingMixin")
@@ -330,20 +330,27 @@ def get_image_processor_dict(
330330
image_processor_file = image_processor_filename
331331
try:
332332
# Load from local folder or from cache or download from model Hub and cache
333-
resolved_image_processor_files = cached_files(
334-
pretrained_model_name_or_path,
335-
filenames=[image_processor_file, PROCESSOR_NAME],
336-
cache_dir=cache_dir,
337-
force_download=force_download,
338-
proxies=proxies,
339-
resume_download=resume_download,
340-
local_files_only=local_files_only,
341-
token=token,
342-
user_agent=user_agent,
343-
revision=revision,
344-
subfolder=subfolder,
345-
_raise_exceptions_for_missing_entries=False,
346-
)
333+
resolved_image_processor_files = [
334+
resolved_file
335+
for filename in [image_processor_file, PROCESSOR_NAME]
336+
if (
337+
resolved_file := cached_file(
338+
pretrained_model_name_or_path,
339+
filename=filename,
340+
cache_dir=cache_dir,
341+
force_download=force_download,
342+
proxies=proxies,
343+
resume_download=resume_download,
344+
local_files_only=local_files_only,
345+
token=token,
346+
user_agent=user_agent,
347+
revision=revision,
348+
subfolder=subfolder,
349+
_raise_exceptions_for_missing_entries=False,
350+
)
351+
)
352+
is not None
353+
]
347354
resolved_image_processor_file = resolved_image_processor_files[0]
348355
except OSError:
349356
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to

src/transformers/video_processing_utils.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
is_torchvision_v2_available,
5151
logging,
5252
)
53-
from .utils.hub import cached_files
53+
from .utils.hub import cached_file
5454
from .utils.import_utils import requires
5555
from .video_utils import (
5656
VideoInput,
@@ -683,20 +683,27 @@ def get_video_processor_dict(
683683
try:
684684
# Try to load with a new config name first and if not successfull try with the old file name
685685
# NOTE: we will gradually change to saving all processor configs as nested dict in PROCESSOR_NAME
686-
resolved_video_processor_files = cached_files(
687-
pretrained_model_name_or_path,
688-
filenames=[VIDEO_PROCESSOR_NAME, IMAGE_PROCESSOR_NAME, PROCESSOR_NAME],
689-
cache_dir=cache_dir,
690-
force_download=force_download,
691-
proxies=proxies,
692-
resume_download=resume_download,
693-
local_files_only=local_files_only,
694-
token=token,
695-
user_agent=user_agent,
696-
revision=revision,
697-
subfolder=subfolder,
698-
_raise_exceptions_for_missing_entries=False,
699-
)
686+
resolved_video_processor_files = [
687+
resolved_file
688+
for filename in [VIDEO_PROCESSOR_NAME, IMAGE_PROCESSOR_NAME, PROCESSOR_NAME]
689+
if (
690+
resolved_file := cached_file(
691+
pretrained_model_name_or_path,
692+
filename=filename,
693+
cache_dir=cache_dir,
694+
force_download=force_download,
695+
proxies=proxies,
696+
resume_download=resume_download,
697+
local_files_only=local_files_only,
698+
token=token,
699+
user_agent=user_agent,
700+
revision=revision,
701+
subfolder=subfolder,
702+
_raise_exceptions_for_missing_entries=False,
703+
)
704+
)
705+
is not None
706+
]
700707
resolved_video_processor_file = resolved_video_processor_files[0]
701708
except EnvironmentError:
702709
# Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to

0 commit comments

Comments
 (0)