[Fix] Add PyTorch 2.6+ compatibility fixes #1654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 75 additions & 6 deletions mmengine/optim/optimizer/zero_optimizer.py
@@ -7,13 +7,74 @@
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

try:
    from torch.distributed.optim import \
        ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
except ImportError:
    _ZeroRedundancyOptimizer = object
# Handle PyTorch 2.6+ compatibility issues with distributed optimizers

from .builder import OPTIMIZERS

def _safe_import_zero_optimizer():
    """Safely import ZeroRedundancyOptimizer to avoid JIT compilation issues.

    Starting from PyTorch 2.6.0, JIT compilation issues can occur when
    importing torch.distributed.optim. This function provides a safe import
    mechanism with fallback options.
    """
    try:
        # PyTorch 2.6+ introduced changes that can cause JIT compilation
        # issues when importing torch.distributed.optim. Apply safe import
        # for 2.6+
        if digit_version(TORCH_VERSION) >= digit_version('2.6.0'):
            import os
            import torch

            # Strategy: Use dynamic import with JIT disabled
            # Save original state
            old_jit_enabled = os.environ.get('PYTORCH_JIT', '1')
            old_jit_disable = os.environ.get('PYTORCH_JIT_DISABLE', '0')

            # Disable JIT compilation
            os.environ['PYTORCH_JIT'] = '0'
            os.environ['PYTORCH_JIT_DISABLE'] = '1'

            try:
                # Try to disable JIT via torch.jit if available
                if hasattr(torch.jit, 'set_enabled'):
                    old_jit_torch_enabled = torch.jit.is_enabled()
                    torch.jit.set_enabled(False)
                else:
                    old_jit_torch_enabled = None

                try:
                    # Import with JIT disabled
                    from torch.distributed.optim import \
                        ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
                    return _ZeroRedundancyOptimizer
                finally:
                    # Restore torch.jit state
                    if (old_jit_torch_enabled is not None and
                            hasattr(torch.jit, 'set_enabled')):
                        torch.jit.set_enabled(old_jit_torch_enabled)
            finally:
                # Restore environment variables
                os.environ['PYTORCH_JIT'] = old_jit_enabled
                os.environ['PYTORCH_JIT_DISABLE'] = old_jit_disable
        else:
            from torch.distributed.optim import \
                ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer
            return _ZeroRedundancyOptimizer
    except (ImportError, RuntimeError, AttributeError) as e:
        # If import fails due to JIT compilation or other issues,
        # return object
        import warnings
        warnings.warn(
            f"Failed to import ZeroRedundancyOptimizer from "
            f"torch.distributed.optim. This is likely due to PyTorch "
            f"version compatibility issues. ZeroRedundancyOptimizer will "
            f"not be available. Error: {e}",
            UserWarning
        )
        return object


_ZeroRedundancyOptimizer = _safe_import_zero_optimizer()

from .builder import OPTIMIZERS # noqa: E402


@OPTIMIZERS.register_module()
@@ -57,6 +118,14 @@ def __init__(self, params, optimizer_type: str, **kwargs):
            '`torch.distributed.optim.ZeroReundancyOptimizer` is only '
            'available when pytorch version >= 1.8.')
        assert is_available(), 'torch.distributed.rpc is not available.'

        # Check if ZeroRedundancyOptimizer is actually available
        if _ZeroRedundancyOptimizer is object:
            raise ImportError(
                'ZeroRedundancyOptimizer is not available. This might be '
                'due to PyTorch version compatibility issues. Please check '
                'if your PyTorch version is compatible with MMEngine.'
            )
        # Avoid the generator becoming empty after the following check
        params = list(params)
        assert (
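The module-level fallback and the constructor check above work together: when the guarded import fails, `_ZeroRedundancyOptimizer` is the sentinel `object`, and the constructor turns that sentinel into an explicit ImportError instead of a confusing failure later. Below is a minimal, self-contained sketch of the same pattern; it is illustrative only, and the names `_demo_safe_import` and `DemoZeroOptimizer` are hypothetical rather than MMEngine or PyTorch APIs.

import warnings


def _demo_safe_import():
    # Guarded import: fall back to the sentinel ``object`` when the
    # distributed optimizer cannot be imported on this install.
    try:
        from torch.distributed.optim import ZeroRedundancyOptimizer
        return ZeroRedundancyOptimizer
    except (ImportError, RuntimeError, AttributeError) as e:
        warnings.warn(f'ZeroRedundancyOptimizer unavailable: {e}', UserWarning)
        return object


_Zero = _demo_safe_import()


class DemoZeroOptimizer(_Zero):  # type: ignore[misc]

    def __init__(self, params, optimizer_type: str, **kwargs):
        if _Zero is object:
            # Fail fast with a clear message instead of an obscure
            # TypeError from ``object.__init__`` receiving extra arguments.
            raise ImportError('ZeroRedundancyOptimizer is not available.')
        import torch
        optimizer_class = getattr(torch.optim, optimizer_type)
        super().__init__(params, optimizer_class=optimizer_class, **kwargs)
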
62 changes: 57 additions & 5 deletions mmengine/runner/checkpoint.py
@@ -20,7 +20,58 @@
from mmengine.model import BaseTTAModel, is_model_wrapper
from mmengine.utils import (apply_to, deprecated_function, digit_version,
                            mkdir_or_exist)
from mmengine.utils.dl_utils import load_url
from mmengine.utils.dl_utils import load_url, TORCH_VERSION


def _safe_torch_load(file, map_location=None, weights_only=None):
    """Safe torch.load that handles PyTorch 2.6+ weights_only compatibility.

    PyTorch 2.6+ changed the default value of weights_only from False to True
    for security reasons. This function provides backward compatibility by
    automatically handling the parameter based on PyTorch version.

    Args:
        file: File path or file-like object to load from
        map_location: Device to load the checkpoint to
        weights_only: Whether to load only weights. If None, auto-detect
            based on PyTorch version

    Returns:
        The loaded checkpoint
    """
    # Auto-detect weights_only behavior for PyTorch 2.6+
    if weights_only is None:
        if digit_version(TORCH_VERSION) >= digit_version('2.6.0'):
            # For PyTorch 2.6+, first try with weights_only=True.
            # If that fails, fall back to weights_only=False for compatibility
            try:
                # Add safe globals for common numpy operations
                import torch.serialization
                safe_globals = [
                    'numpy._core.multiarray._reconstruct',
                    'numpy.core.multiarray._reconstruct',
                    'numpy.dtype',
                    'numpy.ndarray',
                    'builtins.slice',
                    'collections.OrderedDict',
                ]

                with torch.serialization.safe_globals(safe_globals):
                    return torch.load(file, map_location=map_location,
                                      weights_only=True)
            except Exception:
                # If weights_only=True fails, fall back to weights_only=False.
                # This is safe for checkpoints from trusted sources
                return torch.load(file, map_location=map_location,
                                  weights_only=False)
        else:
            # For older PyTorch versions, use default behavior
            return torch.load(file, map_location=map_location)
    else:
        # Use explicit weights_only setting
        return torch.load(file, map_location=map_location,
                          weights_only=weights_only)


# `MMENGINE_HOME` is the highest priority directory to save checkpoints
# downloaded from Internet. If it is not set, as a workaround, using
@@ -344,7 +395,7 @@ def load_from_local(filename, map_location):
    filename = osp.expanduser(filename)
    if not osp.isfile(filename):
        raise FileNotFoundError(f'{filename} can not be found.')
    checkpoint = _safe_torch_load(filename, map_location=map_location)
    return checkpoint


@@ -412,7 +463,8 @@ def load_from_pavi(filename, map_location=None):
    with TemporaryDirectory() as tmp_dir:
        downloaded_file = osp.join(tmp_dir, model.name)
        model.download(downloaded_file)
        checkpoint = _safe_torch_load(downloaded_file,
                                      map_location=map_location)
    return checkpoint


@@ -435,7 +487,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
    file_backend = get_file_backend(
        filename, backend_args={'backend': backend})
    with io.BytesIO(file_backend.get(filename)) as buffer:
        checkpoint = _safe_torch_load(buffer, map_location=map_location)
    return checkpoint


@@ -504,7 +556,7 @@ def load_from_openmmlab(filename, map_location=None):
        filename = osp.join(_get_mmengine_home(), model_url)
        if not osp.isfile(filename):
            raise FileNotFoundError(f'{filename} can not be found.')
        checkpoint = _safe_torch_load(filename, map_location=map_location)
    return checkpoint


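The checkpoint helper above exists because PyTorch 2.6 changed the default of `torch.load(..., weights_only=...)` from False to True, so older checkpoints that pickle numpy objects can stop loading after an upgrade. The following short sketch of that behavior difference is illustrative only: `demo.pth` is a throwaway file, and whether the strict load actually fails depends on the exact PyTorch and numpy versions installed.

import numpy as np
import torch

# A checkpoint that stores a plain numpy array in its metadata, as many
# pre-2.6 training runs did.
torch.save({'meta': {'classes': np.array([0, 1, 2])}, 'state_dict': {}},
           'demo.pth')

try:
    # PyTorch >= 2.6: weights_only defaults to True and may reject the
    # pickled numpy globals unless they are allowlisted via
    # torch.serialization.add_safe_globals().
    ckpt = torch.load('demo.pth', map_location='cpu')
except Exception as e:
    print(f'strict (weights_only) load failed: {e}')
    # Fallback used by this PR for checkpoints from trusted sources.
    ckpt = torch.load('demo.pth', map_location='cpu', weights_only=False)

print(ckpt['meta']['classes'])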