Commit 0c98fbd

[Serialization] Add is_main_process argument to save_torch_state_dict() (#2648)
* Add is_main_process flag
* Update tests comments
* Fix failing test
* fix typos
1 parent 0deb17f commit 0c98fbd

File tree: 2 files changed (+50, -11 lines)

src/huggingface_hub/serialization/_torch.py

Lines changed: 24 additions & 11 deletions
@@ -41,6 +41,7 @@ def save_torch_model(
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
     metadata: Optional[Dict[str, str]] = None,
     safe_serialization: bool = True,
+    is_main_process: bool = True,
 ):
     """
     Saves a given torch model to disk, handling sharding and shared tensors issues.
@@ -88,6 +89,10 @@ def save_torch_model(
             Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
             Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
             in a future version.
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful when in distributed training like
+            TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
+            the main process to avoid race conditions. Defaults to True.

     Example:

@@ -112,6 +117,7 @@ def save_torch_model(
         metadata=metadata,
         safe_serialization=safe_serialization,
         save_directory=save_directory,
+        is_main_process=is_main_process,
     )


@@ -124,6 +130,7 @@ def save_torch_state_dict(
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
     metadata: Optional[Dict[str, str]] = None,
     safe_serialization: bool = True,
+    is_main_process: bool = True,
 ) -> None:
     """
     Save a model state dictionary to the disk, handling sharding and shared tensors issues.
@@ -171,7 +178,10 @@ def save_torch_state_dict(
             Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
             Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
             in a future version.
-
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful when in distributed training like
+            TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
+            the main process to avoid race conditions. Defaults to True.
     Example:

     ```py
@@ -222,15 +232,18 @@ def save_torch_state_dict(
         state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
     )

-    # Clean the folder from previous save
-    existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
-    for filename in os.listdir(save_directory):
-        if existing_files_regex.match(filename):
-            try:
-                logger.debug(f"Removing existing file '{filename}' from folder.")
-                os.remove(os.path.join(save_directory, filename))
-            except Exception as e:
-                logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing...")
+    # Only main process should clean up existing files to avoid race conditions in distributed environment
+    if is_main_process:
+        existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
+        for filename in os.listdir(save_directory):
+            if existing_files_regex.match(filename):
+                try:
+                    logger.debug(f"Removing existing file '{filename}' from folder.")
+                    os.remove(os.path.join(save_directory, filename))
+                except Exception as e:
+                    logger.warning(
+                        f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
+                    )

     # Save each shard
     per_file_metadata = {"format": "pt"}
@@ -442,7 +455,7 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
         from torch.utils._python_dispatch import is_traceable_wrapper_subclass

         if is_traceable_wrapper_subclass(tensor):
-            return _get_unique_id(tensor)
+            return _get_unique_id(tensor)  # type: ignore
     except ImportError:
         # for torch version less than 2.1, we can fallback to original implementation
         pass
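
The gated cleanup above is what makes the flag useful when several processes write checkpoints to the same directory. A minimal usage sketch, assuming a `torch.distributed` setup where rank 0 plays the role of the main process (the rank check, directory name, and placeholder state dict are illustrative, not part of this commit):

```py
# Minimal sketch: every rank calls the save helper, but only rank 0 is flagged
# as the main process, so only rank 0 removes stale shard files before saving.
# Assumes torch.distributed has already been initialized by the launcher.
import torch
import torch.distributed as dist

from huggingface_hub import save_torch_state_dict

state_dict = {"weight": torch.zeros(16, 16)}  # placeholder state dict

is_main = not dist.is_initialized() or dist.get_rank() == 0
save_torch_state_dict(
    state_dict,
    save_directory="./checkpoint",
    is_main_process=is_main,
)
```

In a single-process run the argument keeps its default of `True`, so existing callers are unaffected.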

tests/test_serialization.py

Lines changed: 26 additions & 0 deletions
@@ -264,6 +264,7 @@ def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None:
         max_shard_size="3GB",
         metadata={"foo": "bar"},
         safe_serialization=True,
+        is_main_process=True,
     )
     safe_state_dict_mock.assert_called_once_with(
         state_dict=model_mock.state_dict.return_value,
@@ -273,6 +274,7 @@ def test_save_torch_model(mocker: MockerFixture, tmp_path: Path) -> None:
         max_shard_size="3GB",
         metadata={"foo": "bar"},
         safe_serialization=True,
+        is_main_process=True,
     )


@@ -472,3 +474,27 @@ def test_save_torch_state_dict_delete_existing_files(
     assert (tmp_path / "pytorch_model-00001-of-00003.bin").is_file()
     assert (tmp_path / "pytorch_model-00002-of-00003.bin").is_file()
     assert (tmp_path / "pytorch_model-00003-of-00003.bin").is_file()
+
+
+def test_save_torch_state_dict_not_main_process(
+    tmp_path: Path,
+    torch_state_dict: Dict[str, "torch.Tensor"],
+) -> None:
+    """
+    Test that previous files in the directory are not deleted when is_main_process=False.
+    When is_main_process=True, previous files should be deleted,
+    this is already tested in `test_save_torch_state_dict_delete_existing_files`.
+    """
+    # Create some .safetensors files before saving a new state dict.
+    (tmp_path / "model.safetensors").touch()
+    (tmp_path / "model-00001-of-00002.safetensors").touch()
+    (tmp_path / "model-00002-of-00002.safetensors").touch()
+    (tmp_path / "model.safetensors.index.json").touch()
+    # Save with is_main_process=False
+    save_torch_state_dict(torch_state_dict, tmp_path, is_main_process=False)
+
+    # Previous files should still exist (not deleted)
+    assert (tmp_path / "model.safetensors").is_file()
+    assert (tmp_path / "model-00001-of-00002.safetensors").is_file()
+    assert (tmp_path / "model-00002-of-00002.safetensors").is_file()
+    assert (tmp_path / "model.safetensors.index.json").is_file()
