@@ -41,6 +41,7 @@ def save_torch_model(
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
     metadata: Optional[Dict[str, str]] = None,
     safe_serialization: bool = True,
+    is_main_process: bool = True,
 ):
     """
     Saves a given torch model to disk, handling sharding and shared tensors issues.
@@ -88,6 +89,10 @@ def save_torch_model(
             Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
             Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
             in a future version.
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful when in distributed training like
+            TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
+            the main process to avoid race conditions. Defaults to True.
 
     Example:
 
@@ -112,6 +117,7 @@ def save_torch_model(
         metadata=metadata,
         safe_serialization=safe_serialization,
         save_directory=save_directory,
+        is_main_process=is_main_process,
     )
 
 
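For context, here is a minimal usage sketch of the new argument (not part of this diff): every process calls `save_torch_model`, and only rank 0 is flagged as the main process, as the added docstring recommends. The `torch.distributed` rank check and the top-level `huggingface_hub` import path are illustrative assumptions.

```py
# Hypothetical caller: gate `is_main_process` on the process rank so that
# only rank 0 is treated as the main process (per the docstring above).
import torch.distributed as dist
from huggingface_hub import save_torch_model  # assumed import path


def save_from_all_ranks(model, save_directory):
    is_main = (not dist.is_initialized()) or dist.get_rank() == 0
    save_torch_model(model, save_directory, is_main_process=is_main)
```
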
@@ -124,6 +130,7 @@ def save_torch_state_dict(
     max_shard_size: Union[int, str] = MAX_SHARD_SIZE,
     metadata: Optional[Dict[str, str]] = None,
     safe_serialization: bool = True,
+    is_main_process: bool = True,
 ) -> None:
     """
     Save a model state dictionary to the disk, handling sharding and shared tensors issues.
@@ -171,7 +178,10 @@ def save_torch_state_dict(
             Whether to save as safetensors, which is the default behavior. If `False`, the shards are saved as pickle.
             Safe serialization is recommended for security reasons. Saving as pickle is deprecated and will be removed
             in a future version.
-
+        is_main_process (`bool`, *optional*):
+            Whether the process calling this is the main process or not. Useful when in distributed training like
+            TPUs and need to call this function from all processes. In this case, set `is_main_process=True` only on
+            the main process to avoid race conditions. Defaults to True.
     Example:
 
     ```py
@@ -222,15 +232,18 @@ def save_torch_state_dict(
         state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
     )
 
-    # Clean the folder from previous save
-    existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
-    for filename in os.listdir(save_directory):
-        if existing_files_regex.match(filename):
-            try:
-                logger.debug(f"Removing existing file '{filename}' from folder.")
-                os.remove(os.path.join(save_directory, filename))
-            except Exception as e:
-                logger.warning(f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing...")
+    # Only main process should clean up existing files to avoid race conditions in distributed environment
+    if is_main_process:
+        existing_files_regex = re.compile(filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?")
+        for filename in os.listdir(save_directory):
+            if existing_files_regex.match(filename):
+                try:
+                    logger.debug(f"Removing existing file '{filename}' from folder.")
+                    os.remove(os.path.join(save_directory, filename))
+                except Exception as e:
+                    logger.warning(
+                        f"Error when trying to remove existing '{filename}' from folder: {e}. Continuing..."
+                    )
 
     # Save each shard
     per_file_metadata = {"format": "pt"}
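To make the guarded cleanup concrete, below is a small sketch of what `existing_files_regex` matches, i.e. which leftovers the main process would delete before writing new shards. The default pattern `"model{suffix}.safetensors"` is an assumption for illustration; the hunk above only references `filename_pattern`.

```py
import re

# Assumed default pattern; the hunk above only references `filename_pattern`.
filename_pattern = "model{suffix}.safetensors"
existing_files_regex = re.compile(
    filename_pattern.format(suffix=r"(-\d{5}-of-\d{5})?") + r"(\.index\.json)?"
)

for name in [
    "model.safetensors",                 # single-file save from a previous run
    "model-00001-of-00002.safetensors",  # old shard
    "model.safetensors.index.json",      # old shard index
    "optimizer.pt",                      # unrelated file, never touched
]:
    print(name, bool(existing_files_regex.match(name)))
# Prints True for the first three names and False for "optimizer.pt":
# with `is_main_process=True`, only the matching leftovers are removed.
```
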
@@ -442,7 +455,7 @@ def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
         from torch.utils._python_dispatch import is_traceable_wrapper_subclass
 
         if is_traceable_wrapper_subclass(tensor):
-            return _get_unique_id(tensor)
+            return _get_unique_id(tensor)  # type: ignore
     except ImportError:
         # for torch version less than 2.1, we can fallback to original implementation
         pass