Skip to content

Commit 34133d0

Browse files
authored
Fix placeholders replacement logic in auto_docstring (#39433)
Fix and simplify placeholders replacement logic
1 parent 433d2a2 commit 34133d0

File tree

1 file changed

+18
-22
lines changed

1 file changed

+18
-22
lines changed

src/transformers/utils/auto_docstring.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,37 +1145,34 @@ def get_placeholders_dict(placeholders: list, model_name: str) -> dict:
11451145
place_holder_value = getattr(
11461146
getattr(auto_module, PLACEHOLDER_TO_AUTO_MODULE[placeholder][0]),
11471147
PLACEHOLDER_TO_AUTO_MODULE[placeholder][1],
1148-
)[model_name]
1149-
if isinstance(place_holder_value, (list, tuple)):
1150-
place_holder_value = place_holder_value[0]
1151-
placeholders_dict[placeholder] = place_holder_value
1148+
).get(model_name, None)
1149+
if place_holder_value is not None:
1150+
if isinstance(place_holder_value, (list, tuple)):
1151+
place_holder_value = place_holder_value[0]
1152+
placeholders_dict[placeholder] = place_holder_value
1153+
else:
1154+
placeholders_dict[placeholder] = placeholder
11521155

11531156
return placeholders_dict
11541157

11551158

1156-
def format_args_docstring(args, model_name):
1159+
def format_args_docstring(docstring, model_name):
11571160
"""
11581161
Replaces placeholders such as {image_processor_class} in the docstring with the actual values,
11591162
deducted from the model name and the auto modules.
11601163
"""
1161-
# first check if there are any placeholders in the args, if not return them as is
1162-
placeholders = set(re.findall(r"{(.*?)}", "".join(args[arg]["description"] for arg in args)))
1164+
# first check if there are any placeholders in the docstring, if not return it as is
1165+
placeholders = set(re.findall(r"{(.*?)}", docstring))
11631166
if not placeholders:
1164-
return args
1167+
return docstring
11651168

11661169
# get the placeholders dictionary for the given model name
11671170
placeholders_dict = get_placeholders_dict(placeholders, model_name)
1171+
# replace the placeholders in the docstring with the values from the placeholders_dict
1172+
for placeholder, value in placeholders_dict.items():
1173+
docstring = docstring.replace(f"{{{placeholder}}}", value)
11681174

1169-
# replace the placeholders in the args with the values from the placeholders_dict
1170-
for arg in args:
1171-
new_arg = args[arg]["description"]
1172-
placeholders = re.findall(r"{(.*?)}", new_arg)
1173-
placeholders = [placeholder for placeholder in placeholders if placeholder in placeholders_dict]
1174-
if placeholders:
1175-
new_arg = new_arg.format(**{placeholder: placeholders_dict[placeholder] for placeholder in placeholders})
1176-
args[arg]["description"] = new_arg
1177-
1178-
return args
1175+
return docstring
11791176

11801177

11811178
def get_args_doc_from_source(args_classes: Union[object, list[object]]) -> dict:
@@ -1494,8 +1491,6 @@ def _process_kwargs_parameters(
14941491
kwargs_documentation = kwarg_param.annotation.__args__[0].__doc__
14951492
if kwargs_documentation is not None:
14961493
documented_kwargs, _ = parse_docstring(kwargs_documentation)
1497-
if model_name_lowercase is not None:
1498-
documented_kwargs = format_args_docstring(documented_kwargs, model_name_lowercase)
14991494

15001495
# Process each kwarg parameter
15011496
for param_name, param_type_annotation in kwarg_param.annotation.__args__[0].__annotations__.items():
@@ -1573,8 +1568,6 @@ def _process_parameters_section(
15731568
# Parse existing docstring if available
15741569
if func_documentation is not None:
15751570
documented_params, func_documentation = parse_docstring(func_documentation)
1576-
if model_name_lowercase is not None:
1577-
documented_params = format_args_docstring(documented_params, model_name_lowercase)
15781571

15791572
# Process regular parameters
15801573
param_docstring, missing_args = _process_regular_parameters(
@@ -1772,6 +1765,9 @@ def auto_method_docstring(
17721765
)
17731766
docstring += example_docstring
17741767

1768+
# Format the docstring with the placeholders
1769+
docstring = format_args_docstring(docstring, model_name_lowercase)
1770+
17751771
# Assign the dynamically generated docstring to the wrapper function
17761772
func.__doc__ = docstring
17771773
return func

0 commit comments

Comments
 (0)