From 8bd73e5c7f9f456ed9356c0843df897adda3b16f Mon Sep 17 00:00:00 2001 From: dnn Date: Thu, 10 Jul 2025 10:56:54 +0800 Subject: [PATCH] [Fix] Add PyTorch 2.6+ and 2.7+ compatibility fixes This commit addresses compatibility issues with PyTorch 2.6+ and 2.7+ that cause runtime errors in MMEngine. **PyTorch 2.6+ JIT Compilation Fix:** - Add safe import mechanism for ZeroRedundancyOptimizer in zero_optimizer.py - Temporarily disable JIT compilation during distributed optimizer import - Apply fix for PyTorch >=2.6.0 where JIT compilation issues were introduced - Graceful fallback when distributed optimizers are unavailable - Resolves: RuntimeError during torch.distributed.optim import **PyTorch 2.6+ torch.load weights_only Fix:** - Add _safe_torch_load function in checkpoint.py with automatic version detection - Handle weights_only parameter changes with safe globals for numpy arrays - Fallback to weights_only=False for compatibility with existing checkpoints - Resolves: "Weights only load failed" errors when loading models **Key Features:** - Maintains full backward compatibility with older PyTorch versions - Automatic version detection and appropriate handling - Conservative approach: only applies fixes to versions that need them - Comprehensive error handling and user warnings - Follows MMEngine coding standards **Version Support:** - PyTorch 2.6+ JIT compilation issues handled - PyTorch 2.6+ weights_only parameter changes handled - Full compatibility maintained for PyTorch 1.6-2.5 **Files Changed:** - mmengine/optim/optimizer/zero_optimizer.py: Safe distributed optimizer import - mmengine/runner/checkpoint.py: Safe torch.load with weights_only handling --- mmengine/optim/optimizer/zero_optimizer.py | 81 ++++++++++++++++++++-- mmengine/runner/checkpoint.py | 62 +++++++++++++++-- 2 files changed, 132 insertions(+), 11 deletions(-) diff --git a/mmengine/optim/optimizer/zero_optimizer.py b/mmengine/optim/optimizer/zero_optimizer.py index 0c5630a765..68c71d94fc 100644 --- a/mmengine/optim/optimizer/zero_optimizer.py +++ b/mmengine/optim/optimizer/zero_optimizer.py @@ -7,13 +7,74 @@ from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION -try: - from torch.distributed.optim import \ - ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer -except ImportError: - _ZeroRedundancyOptimizer = object +# Handle PyTorch 2.6+ compatibility issues with distributed optimizers -from .builder import OPTIMIZERS + +def _safe_import_zero_optimizer(): + """Safely import ZeroRedundancyOptimizer to avoid JIT compilation issues. + + Starting from PyTorch 2.6.0, JIT compilation issues can occur when + importing torch.distributed.optim. This function provides a safe import + mechanism with fallback options. + """ + try: + # PyTorch 2.6+ introduced changes that can cause JIT compilation issues + # when importing torch.distributed.optim. Apply safe import for 2.6+ + if digit_version(TORCH_VERSION) >= digit_version('2.6.0'): + import os + import torch + + # Strategy: Use dynamic import with JIT disabled + # Save original state + old_jit_enabled = os.environ.get('PYTORCH_JIT', '1') + old_jit_disable = os.environ.get('PYTORCH_JIT_DISABLE', '0') + + # Disable JIT compilation + os.environ['PYTORCH_JIT'] = '0' + os.environ['PYTORCH_JIT_DISABLE'] = '1' + + try: + # Try to disable JIT via torch.jit if available + if hasattr(torch.jit, 'set_enabled'): + old_jit_torch_enabled = torch.jit.is_enabled() + torch.jit.set_enabled(False) + else: + old_jit_torch_enabled = None + + try: + # Import with JIT disabled + from torch.distributed.optim import \ + ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer + return _ZeroRedundancyOptimizer + finally: + # Restore torch.jit state + if (old_jit_torch_enabled is not None and + hasattr(torch.jit, 'set_enabled')): + torch.jit.set_enabled(old_jit_torch_enabled) + finally: + # Restore environment variables + os.environ['PYTORCH_JIT'] = old_jit_enabled + os.environ['PYTORCH_JIT_DISABLE'] = old_jit_disable + else: + from torch.distributed.optim import \ + ZeroRedundancyOptimizer as _ZeroRedundancyOptimizer + return _ZeroRedundancyOptimizer + except (ImportError, RuntimeError, AttributeError) as e: + # If import fails due to JIT compilation or other issues, return object + import warnings + warnings.warn( + f"Failed to import ZeroRedundancyOptimizer from " + f"torch.distributed.optim. This is likely due to PyTorch " + f"version compatibility issues. ZeroRedundancyOptimizer will " + f"not be available. Error: {e}", + UserWarning + ) + return object + + +_ZeroRedundancyOptimizer = _safe_import_zero_optimizer() + +from .builder import OPTIMIZERS # noqa: E402 @OPTIMIZERS.register_module() @@ -57,6 +118,14 @@ def __init__(self, params, optimizer_type: str, **kwargs): '`torch.distributed.optim.ZeroReundancyOptimizer` is only ' 'available when pytorch version >= 1.8.') assert is_available(), 'torch.distributed.rpc is not available.' + + # Check if ZeroRedundancyOptimizer is actually available + if _ZeroRedundancyOptimizer is object: + raise ImportError( + 'ZeroRedundancyOptimizer is not available. This might be ' + 'due to PyTorch version compatibility issues. Please check ' + 'if your PyTorch version is compatible with MMEngine.' + ) # Avoid the generator becoming empty after the following check params = list(params) assert ( diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 2bf5f50f7c..a9c6be584c 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -20,7 +20,58 @@ from mmengine.model import BaseTTAModel, is_model_wrapper from mmengine.utils import (apply_to, deprecated_function, digit_version, mkdir_or_exist) -from mmengine.utils.dl_utils import load_url +from mmengine.utils.dl_utils import load_url, TORCH_VERSION + + +def _safe_torch_load(file, map_location=None, weights_only=None): + """Safe torch.load that handles PyTorch 2.6+ weights_only compatibility. + + PyTorch 2.6+ changed the default value of weights_only from False to True + for security reasons. This function provides backward compatibility by + automatically handling the parameter based on PyTorch version. + + Args: + file: File path or file-like object to load from + map_location: Device to load the checkpoint to + weights_only: Whether to load only weights. If None, auto-detect + based on PyTorch version + + Returns: + The loaded checkpoint + """ + # Auto-detect weights_only behavior for PyTorch 2.6+ + if weights_only is None: + if digit_version(TORCH_VERSION) >= digit_version('2.6.0'): + # For PyTorch 2.6+, first try with weights_only=True + # If that fails, fall back to weights_only=False for compatibility + try: + # Add safe globals for common numpy operations + import torch.serialization + safe_globals = [ + 'numpy._core.multiarray._reconstruct', + 'numpy.core.multiarray._reconstruct', + 'numpy.dtype', + 'numpy.ndarray', + 'builtins.slice', + 'collections.OrderedDict', + ] + + with torch.serialization.safe_globals(safe_globals): + return torch.load(file, map_location=map_location, + weights_only=True) + except Exception: + # If weights_only=True fails, fall back to weights_only=False + # This is safe for checkpoints from trusted sources + return torch.load(file, map_location=map_location, + weights_only=False) + else: + # For older PyTorch versions, use default behavior + return torch.load(file, map_location=map_location) + else: + # Use explicit weights_only setting + return torch.load(file, map_location=map_location, + weights_only=weights_only) + # `MMENGINE_HOME` is the highest priority directory to save checkpoints # downloaded from Internet. If it is not set, as a workaround, using @@ -344,7 +395,7 @@ def load_from_local(filename, map_location): filename = osp.expanduser(filename) if not osp.isfile(filename): raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) + checkpoint = _safe_torch_load(filename, map_location=map_location) return checkpoint @@ -412,7 +463,8 @@ def load_from_pavi(filename, map_location=None): with TemporaryDirectory() as tmp_dir: downloaded_file = osp.join(tmp_dir, model.name) model.download(downloaded_file) - checkpoint = torch.load(downloaded_file, map_location=map_location) + checkpoint = _safe_torch_load(downloaded_file, + map_location=map_location) return checkpoint @@ -435,7 +487,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'): file_backend = get_file_backend( filename, backend_args={'backend': backend}) with io.BytesIO(file_backend.get(filename)) as buffer: - checkpoint = torch.load(buffer, map_location=map_location) + checkpoint = _safe_torch_load(buffer, map_location=map_location) return checkpoint @@ -504,7 +556,7 @@ def load_from_openmmlab(filename, map_location=None): filename = osp.join(_get_mmengine_home(), model_url) if not osp.isfile(filename): raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) + checkpoint = _safe_torch_load(filename, map_location=map_location) return checkpoint