Skip to content

MeshGraphNet Performance: Automatically Use transformer engine for LayerNorm. #1036

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Aug 4, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
092761d
adding layer norm utils, skipping precommit since ORD is so unstable
coreyjadams Jul 16, 2025
b028649
Add doc strings and type annotations to layer_norm implementation
coreyjadams Jul 16, 2025
efa86eb
merge
coreyjadams Jul 17, 2025
e3ecb1b
Merge branch 'NVIDIA:main' into te-ln
coreyjadams Jul 17, 2025
8da32fd
Merge branch 'NVIDIA:main' into te-ln
coreyjadams Jul 29, 2025
44be5ea
Snapshot layernorm optimizations.
coreyjadams Jul 29, 2025
46cd1ec
Enable dynamic selection of layer norm in MeshGraphNet. Yields a goo…
coreyjadams Jul 29, 2025
45031ff
Remove old code.
coreyjadams Jul 29, 2025
801ce6b
Remove unneeded file.
coreyjadams Jul 29, 2025
c230d2f
Update test to avoid te on CPU
coreyjadams Jul 29, 2025
f188566
Merge branch 'main' into te-ln
coreyjadams Jul 30, 2025
927a385
Update formatting
coreyjadams Jul 30, 2025
918fa35
Update meshgraphnet.py
coreyjadams Jul 31, 2025
a5fbb69
Update meshgraphkan.py
coreyjadams Jul 31, 2025
acb4b6a
Update meshgraphnet.py
coreyjadams Jul 31, 2025
52396d8
Fix ruff formatting
coreyjadams Jul 31, 2025
2ec3d68
Formatting ....
coreyjadams Jul 31, 2025
80bdf1b
Merge branch 'NVIDIA:main' into te-ln
coreyjadams Aug 1, 2025
74de9c7
Address PR feedback:
coreyjadams Aug 1, 2025
9c263de
Update tests: env modification coming through a fixture now.
coreyjadams Aug 1, 2025
017cd1e
Address graphcast too: use a fixture instead of contexts.
coreyjadams Aug 1, 2025
ec577ea
Fix layer norm tests too.
coreyjadams Aug 1, 2025
064e067
Merge branch 'main' into te-ln
coreyjadams Aug 1, 2025
d5c44b7
Fix a test
coreyjadams Aug 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 20 additions & 39 deletions physicsnemo/models/gnn_layers/mesh_graph_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Optional, Tuple, Union

import torch
Expand All @@ -22,17 +23,11 @@
from torch import Tensor
from torch.autograd.function import once_differentiable

from .utils import GraphType, concat_efeat, sum_efeat

try:
from transformer_engine import pytorch as te

te_imported = True
except ImportError:
te_imported = False

from physicsnemo.models.layers.layer_norm import get_layer_norm_class
from physicsnemo.utils.profiling import profile

from .utils import GraphType, concat_efeat, sum_efeat


class CustomSiLuLinearAutogradFunction(torch.autograd.Function):
"""Custom SiLU + Linear autograd function"""
Expand Down Expand Up @@ -147,21 +142,14 @@ def __init__(

self.norm_type = norm_type
if norm_type is not None:
if norm_type not in [
"LayerNorm",
"TELayerNorm",
]:
raise ValueError(
f"Invalid norm type {norm_type}. Supported types are LayerNorm and TELayerNorm."
)
if norm_type == "TELayerNorm" and te_imported:
norm_layer = te.LayerNorm
elif norm_type == "TELayerNorm" and not te_imported:
raise ValueError(
"TELayerNorm requires transformer-engine to be installed."
)
else:
norm_layer = getattr(nn, norm_type)
warnings.warn(
"The MeshGraphNet 'norm_type' argument is deprecated and will be removed in a future release."
"In the future, transformer engine will be used automatically for layer norm, if it is installed."
"Override this behavior by setting the PHYSICSNEMO_FORCE_TE environment variable to 'False'.",
DeprecationWarning,
stacklevel=2,
)
norm_layer = get_layer_norm_class()
layers.append(norm_layer(output_dim))

self.model = nn.Sequential(*layers)
Expand Down Expand Up @@ -356,21 +344,14 @@ def __init__(

self.norm_type = norm_type
if norm_type is not None:
if norm_type not in [
"LayerNorm",
"TELayerNorm",
]:
raise ValueError(
f"Invalid norm type {norm_type}. Supported types are LayerNorm and TELayerNorm."
)
if norm_type == "TELayerNorm" and te_imported:
norm_layer = te.LayerNorm
elif norm_type == "TELayerNorm" and not te_imported:
raise ValueError(
"TELayerNorm requires transformer-engine to be installed."
)
else:
norm_layer = getattr(nn, norm_type)
warnings.warn(
"The MeshGraphNet 'norm_type' argument is deprecated and will be removed in a future release."
"In the future, transformer engine will be used automatically for layer norm, if it is installed."
"Override this behavior by setting the PHYSICSNEMO_FORCE_TE environment variable to 'False'.",
DeprecationWarning,
stacklevel=2,
)
norm_layer = get_layer_norm_class()
layers.append(norm_layer(output_dim))

self.model = nn.Sequential(*layers)
Expand Down
195 changes: 195 additions & 0 deletions physicsnemo/models/layers/layer_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from torch import nn

try:
import transformer_engine.pytorch as te

TE_AVAILABLE = True
except ImportError:
TE_AVAILABLE = False


def remove_extra_state_hook_for_torch(
    module: nn.Module,
    state_dict: dict,
    prefix: str,
    local_metadata: dict,
    strict: bool,
    missing_keys: list,
    unexpected_keys: list,
    error_msgs: list,
) -> None:
    """
    Pre-hook to remove Transformer Engine's extra state from the state_dict when loading into a PyTorch LayerNorm.

    This function scans the state_dict for any keys that start with
    '{prefix}_extra_state' and removes them. These keys are specific to
    Transformer Engine's LayerNorm (it stores fp8 scaling state there) and are
    not needed (and may cause errors) when loading into a standard PyTorch
    LayerNorm.

    Args:
        module (nn.Module): The module into which the state_dict is being loaded.
        state_dict (dict): The state dictionary being loaded.
        prefix (str): The prefix for parameters in this module.
        local_metadata (dict): Metadata for this module.
        strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict function.
        missing_keys (list): List of missing keys.
        unexpected_keys (list): List of unexpected keys.
        error_msgs (list): List of error messages.
    """
    # TE's extra fp8 state is serialized under this key; a plain torch
    # LayerNorm has no use for it and would report it as unexpected.
    target = prefix + "_extra_state"
    # Materialize the list first: never mutate a dict while iterating it.
    keys_to_remove = [key for key in state_dict if key.startswith(target)]
    for key in keys_to_remove:
        del state_dict[key]


def ignore_missing_extra_state_key(
    module: nn.Module, incompatible_keys: torch.nn.modules.module._IncompatibleKeys
) -> None:
    """
    Post-hook to ignore missing '_extra_state' keys when loading a state_dict.

    This function removes any key ending in '_extra_state' from the list of
    missing keys in the IncompatibleKeys object. This is useful when loading a
    checkpoint saved from a PyTorch LayerNorm into a Transformer Engine
    LayerNorm, where this extra (fp8 scaling) state is not present in the
    checkpoint and is not needed to restore the weights.

    Note: the previous implementation only matched the literal key
    'ln._extra_state', which silently failed for any layer not attributed
    exactly ``ln``; matching on the suffix handles every prefix.

    Args:
        module (nn.Module): The module into which the state_dict is being loaded.
        incompatible_keys: An object with a 'missing_keys' attribute (typically torch.nn.modules.module._IncompatibleKeys).
    """
    # Iterate over a copy: we mutate missing_keys while scanning it.
    for key in list(incompatible_keys.missing_keys):
        if key.endswith("_extra_state"):
            incompatible_keys.missing_keys.remove(key)


# class LayerNorm(nn.Module):


# def __init__(self, *args, **kwargs):
# super().__init__()

# # This is to allow users to force the use of TE or pytorch layer norm
# force_te_setting = os.environ.get("PHYSICSNEMO_FORCE_TE")
# te_available = (
# TE_AVAILABLE # make a local copy to avoid changing the global variable
# )
# if force_te_setting is not None:
# if force_te_setting.lower() == "true" or force_te_setting.lower() == "1":
# te_available = True
# elif force_te_setting.lower() == "false" or force_te_setting.lower() == "0":
# te_available = False

# self.use_te = te_available

# # TE uses an extra state to manage fp8 scaling
# # It shows up in the state dict, making the two
# # layers incompatible with each other
# # https://github.com/NVIDIA/TransformerEngine/issues/458

# # As a workaround, when we're loading a te-trained layer norm
# # into torch layer norm, remove that state:

# if self.use_te:
# self.norm = te.LayerNorm(*args, **kwargs)
# self.register_load_state_dict_post_hook(ignore_missing_extra_state_key)
# else:
# self.norm = nn.LayerNorm(*args, **kwargs)
# self.register_load_state_dict_pre_hook(remove_extra_state_hook_for_torch)

# def forward(self, x: torch.Tensor) -> torch.Tensor:
# """
# Pass the layer norm computation onto the sub layer.
# """
# return self.norm(x)


def get_layer_norm_class() -> type[nn.Module]:
    """
    Dynamically pick the layer norm provider based on availability of transformer engine.

    If transformer engine is available (and CUDA is present), the returned
    class subclasses the transformer engine implementation of LayerNorm.
    Otherwise, it subclasses the pytorch implementation of LayerNorm.

    Override the default behavior by setting the PHYSICSNEMO_FORCE_TE
    environment variable to "true"/"1" or "false"/"0".

    Returns:
        type[nn.Module]: A LayerNorm *class* (not an instance) with
        state-dict hooks pre-registered so checkpoints stay loadable when
        switching between TE and torch implementations.

    Raises:
        ImportError: If PHYSICSNEMO_FORCE_TE requests transformer engine but
        transformer-engine is not installed.
    """
    # Work on a local flag so the module-level TE_AVAILABLE is never mutated.
    use_te = TE_AVAILABLE

    # Can't use transformer engine without cuda:
    if not torch.cuda.is_available():
        use_te = False

    # Let the users force the setting no matter what:
    force_te_setting = os.environ.get("PHYSICSNEMO_FORCE_TE")
    if force_te_setting is not None:
        if force_te_setting.lower() in ("true", "1"):
            use_te = True
        elif force_te_setting.lower() in ("false", "0"):
            use_te = False

    if use_te and not TE_AVAILABLE:
        # Previously this fell through to a NameError on `te`; fail with a
        # clear, actionable message instead.
        raise ImportError(
            "PHYSICSNEMO_FORCE_TE requests transformer engine, but "
            "transformer-engine is not installed."
        )

    base = te.LayerNorm if use_te else nn.LayerNorm

    class LayerNorm(base):
        """
        Wrapper around layer norm utilities.

        This class will default to using the transformer engine implementation of
        LayerNorm - it is significantly faster in the backwards pass.

        If transformer engine is not available, it will fall back to the
        pytorch implementation of LayerNorm.

        Additionally, this class registers pre or post hooks to allow you to
        train with / without transformer engine, and run inference
        with / without transformer engine.

        .. note::
            Transformer engine adds additional state parameters that affect
            fp8 stability. **Do NOT** switch from transformer engine to pytorch
            or from pytorch to transformer engine with a checkpoint if you
            are using fp8 precision in the layer norm regions.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

            if use_te:
                # Tolerate the absence of TE's `_extra_state` when loading a
                # torch-trained checkpoint into the TE implementation.
                self.register_load_state_dict_post_hook(ignore_missing_extra_state_key)
            else:
                # Strip TE's `_extra_state` entries when loading a TE-trained
                # checkpoint into the torch implementation.
                self.register_load_state_dict_pre_hook(
                    remove_extra_state_hook_for_torch
                )

    return LayerNorm


# Module-level alias resolved once at import time; the environment variable is
# read here, so set PHYSICSNEMO_FORCE_TE before importing this module.
LayerNorm = get_layer_norm_class()
42 changes: 0 additions & 42 deletions test/distributed/distributed_utils_for_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,45 +13,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import contextlib
import os


@contextlib.contextmanager
def modify_environment(*remove, **update):
    """
    Context manager to allow temporary modification of the environment variables.

    Based on the implementation here:
    https://stackoverflow.com/questions/2059482/temporarily-modify-the-current-processs-environment

    Args:
        *remove: Names of environment variables to unset inside the context.
        **update: Environment variables to set inside the context; values are
            coerced to ``str`` (``os.environ`` only accepts strings).

    Yields:
        None. On exit, the environment is restored to its prior state:
        overwritten/removed variables get their original values back, and
        newly added variables are deleted.
    """
    env = os.environ

    # Make sure all update values are strings:
    update = {k: str(v) for k, v in update.items()}

    # Variables being updated OR removed that already exist in the current
    # env — their original values must be restored on exit.
    stomped = (set(update) | set(remove)) & set(env)
    restore_after = {k: env[k] for k in stomped}

    # Variables being added that did not exist before — they must be purged
    # on exit.
    purge_after = tuple(k for k in update if k not in env)

    try:
        env.update(update)
        for k in remove:
            env.pop(k, None)
        yield
    finally:
        env.update(restore_after)
        for k in purge_after:
            env.pop(k, None)
2 changes: 1 addition & 1 deletion test/distributed/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pytest
import torch
from distributed_utils_for_testing import modify_environment
from pytest_utils import modify_environment

from physicsnemo.distributed import DistributedManager
from physicsnemo.distributed.autograd import (
Expand Down
2 changes: 1 addition & 1 deletion test/distributed/test_ball_query_shard_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)


from distributed_utils_for_testing import modify_environment # noqa: E402
from pytest_utils import modify_environment # noqa: E402
from test_shard_tensor_initialization import init_dist
from torch.distributed.tensor import distribute_module # noqa: E402
from torch.distributed.tensor.placement_types import ( # noqa: E402
Expand Down
2 changes: 1 addition & 1 deletion test/distributed/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pytest
import torch
from distributed_utils_for_testing import modify_environment
from pytest_utils import modify_environment

from physicsnemo.distributed import (
DistributedManager,
Expand Down
2 changes: 1 addition & 1 deletion test/distributed/test_distributed_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest
import torch
import torch.distributed as dist
from distributed_utils_for_testing import modify_environment
from pytest_utils import modify_environment

from physicsnemo.distributed import DistributedManager
from physicsnemo.distributed.fft import DistributedRFFT2
Expand Down
2 changes: 1 addition & 1 deletion test/distributed/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pytest
import torch
from distributed_utils_for_testing import modify_environment
from pytest_utils import modify_environment

from physicsnemo.distributed import (
DistributedManager,
Expand Down
2 changes: 1 addition & 1 deletion test/distributed/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pytest
import torch
from distributed_utils_for_testing import modify_environment
from pytest_utils import modify_environment

from physicsnemo.distributed import (
DistributedManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
allow_module_level=True,
)

from distributed_utils_for_testing import modify_environment
from pytest_utils import modify_environment
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Replicate

Expand Down
Loading