
MeshGraphNet Performance: Automatically use Transformer Engine for LayerNorm. #1036


Merged
24 commits merged on Aug 4, 2025
Changes from 17 commits
Commits (24)
092761d
adding layer norm utils, skipping precommit since ORD is so unstable
coreyjadams Jul 16, 2025
b028649
Add doc strings and type annotations to layer_norm implementation
coreyjadams Jul 16, 2025
efa86eb
merge
coreyjadams Jul 17, 2025
e3ecb1b
Merge branch 'NVIDIA:main' into te-ln
coreyjadams Jul 17, 2025
8da32fd
Merge branch 'NVIDIA:main' into te-ln
coreyjadams Jul 29, 2025
44be5ea
Snapshot layernorm optimizations.
coreyjadams Jul 29, 2025
46cd1ec
Enable dynamic selection of layer norm in MeshGraphNet. Yields a goo…
coreyjadams Jul 29, 2025
45031ff
Remove old code.
coreyjadams Jul 29, 2025
801ce6b
Remove unneeded file.
coreyjadams Jul 29, 2025
c230d2f
Update test to avoid te on CPU
coreyjadams Jul 29, 2025
f188566
Merge branch 'main' into te-ln
coreyjadams Jul 30, 2025
927a385
Update formatting
coreyjadams Jul 30, 2025
918fa35
Update meshgraphnet.py
coreyjadams Jul 31, 2025
a5fbb69
Update meshgraphkan.py
coreyjadams Jul 31, 2025
acb4b6a
Update meshgraphnet.py
coreyjadams Jul 31, 2025
52396d8
Fix ruff formatting
coreyjadams Jul 31, 2025
2ec3d68
Formatting ....
coreyjadams Jul 31, 2025
80bdf1b
Merge branch 'NVIDIA:main' into te-ln
coreyjadams Aug 1, 2025
74de9c7
Address PR feedback:
coreyjadams Aug 1, 2025
9c263de
Update tests: env modification coming through a fixture now.
coreyjadams Aug 1, 2025
017cd1e
Address graphcast too: use a fixture instead of contexts.
coreyjadams Aug 1, 2025
ec577ea
Fix layer norm tests too.
coreyjadams Aug 1, 2025
064e067
Merge branch 'main' into te-ln
coreyjadams Aug 1, 2025
d5c44b7
Fix a test
coreyjadams Aug 1, 2025
59 changes: 20 additions & 39 deletions physicsnemo/models/gnn_layers/mesh_graph_mlp.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Optional, Tuple, Union

import torch
@@ -22,17 +23,11 @@
from torch import Tensor
from torch.autograd.function import once_differentiable

from .utils import GraphType, concat_efeat, sum_efeat

try:
    from transformer_engine import pytorch as te

    te_imported = True
except ImportError:
    te_imported = False

from physicsnemo.models.layers.layer_norm import get_layer_norm_class
from physicsnemo.utils.profiling import profile

from .utils import GraphType, concat_efeat, sum_efeat


class CustomSiLuLinearAutogradFunction(torch.autograd.Function):
"""Custom SiLU + Linear autograd function"""
@@ -151,21 +146,14 @@ def __init__(

        self.norm_type = norm_type
        if norm_type is not None:
            if norm_type not in [
                "LayerNorm",
                "TELayerNorm",
            ]:
                raise ValueError(
                    f"Invalid norm type {norm_type}. Supported types are LayerNorm and TELayerNorm."
                )
            if norm_type == "TELayerNorm" and te_imported:
                norm_layer = te.LayerNorm
            elif norm_type == "TELayerNorm" and not te_imported:
                raise ValueError(
                    "TELayerNorm requires transformer-engine to be installed."
                )
            else:
                norm_layer = getattr(nn, norm_type)
            warnings.warn(
                "The MeshGraphNet 'norm_type' argument is deprecated and will be removed in a future release. "
                "In the future, transformer engine will be used automatically for layer norm, if it is installed. "
                "Override this behavior by setting the PHYSICSNEMO_FORCE_TE environment variable to 'False'.",
                DeprecationWarning,
                stacklevel=2,
            )
            norm_layer = get_layer_norm_class()
            layers.append(norm_layer(output_dim))

        self.model = nn.Sequential(*layers)
@@ -360,21 +348,14 @@ def __init__(

        self.norm_type = norm_type
        if norm_type is not None:
            if norm_type not in [
                "LayerNorm",
                "TELayerNorm",
            ]:
                raise ValueError(
                    f"Invalid norm type {norm_type}. Supported types are LayerNorm and TELayerNorm."
                )
            if norm_type == "TELayerNorm" and te_imported:
                norm_layer = te.LayerNorm
            elif norm_type == "TELayerNorm" and not te_imported:
                raise ValueError(
                    "TELayerNorm requires transformer-engine to be installed."
                )
            else:
                norm_layer = getattr(nn, norm_type)
            warnings.warn(
                "The MeshGraphNet 'norm_type' argument is deprecated and will be removed in a future release. "
                "In the future, transformer engine will be used automatically for layer norm, if it is installed. "
                "Override this behavior by setting the PHYSICSNEMO_FORCE_TE environment variable to 'False'.",
                DeprecationWarning,
                stacklevel=2,
            )
            norm_layer = get_layer_norm_class()
            layers.append(norm_layer(output_dim))

        self.model = nn.Sequential(*layers)
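In practice, the branching on norm_type above is replaced by a single call into the new layer_norm module. A minimal sketch of what the MLP now does when it builds its normalization layer (the hidden size of 128 and the input shape are illustrative, not taken from the PR):

import torch

from physicsnemo.models.layers.layer_norm import get_layer_norm_class

# Resolves to transformer_engine's LayerNorm when TE and CUDA are available,
# otherwise to torch.nn.LayerNorm.
norm_layer = get_layer_norm_class()

# The MLP appends norm_layer(output_dim) to its layer list; 128 stands in for output_dim here.
norm = norm_layer(128)
out = norm(torch.randn(32, 128))  # normalizes over the last dimension, as nn.LayerNorm does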
154 changes: 154 additions & 0 deletions physicsnemo/models/layers/layer_norm.py
@@ -0,0 +1,154 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from torch import nn

try:
    import transformer_engine.pytorch as te

    TE_AVAILABLE = True
except ImportError:
    TE_AVAILABLE = False


def remove_extra_state_hook_for_torch(
    module: nn.Module,
    state_dict: dict,
    prefix: str,
    local_metadata: dict,
    strict: bool,
    missing_keys: list,
    unexpected_keys: list,
    error_msgs: list,
) -> None:
    """
    Pre-hook to remove Transformer Engine's extra state from the state_dict when loading into a PyTorch LayerNorm.

    This function scans the state_dict for any keys that match the pattern '{prefix}_extra_state'
    and removes them. These keys are specific to Transformer Engine's LayerNorm and are not needed
    (and may cause errors) when loading into a standard PyTorch LayerNorm.

    Args:
        module (nn.Module): The module into which the state_dict is being loaded.
        state_dict (dict): The state dictionary being loaded.
        prefix (str): The prefix for parameters in this module.
        local_metadata (dict): Metadata for this module.
        strict (bool): Whether to strictly enforce that the keys in state_dict match the keys returned by this module's state_dict function.
        missing_keys (list): List of missing keys.
        unexpected_keys (list): List of unexpected keys.
        error_msgs (list): List of error messages.
    """
    # Go through the state dict, and for any keys that match
    # prefix + "_extra_state", remove those.
    # They are extra from transformer engine and not needed in the
    # torch layernorm.
    keys_to_remove = [
        key for key in state_dict if key.startswith(prefix + "_extra_state")
    ]
    for key in keys_to_remove:
        del state_dict[key]


def ignore_missing_extra_state_key(
    module: nn.Module, incompatible_keys: torch.nn.modules.module._IncompatibleKeys
) -> None:
    """
    Post-hook to ignore the missing 'ln._extra_state' key when loading a state_dict.

    This function removes 'ln._extra_state' from the list of missing keys in
    the IncompatibleKeys object. This is useful when loading a checkpoint saved
    from a PyTorch LayerNorm into a Transformer Engine LayerNorm, where this extra
    state is not present and not needed.

    Args:
        module (nn.Module): The module into which the state_dict is being loaded.
        incompatible_keys: An object with a 'missing_keys' attribute (typically torch.nn.modules.module._IncompatibleKeys).
    """
    # Remove 'ln._extra_state' from the missing keys:
    problem_key = "ln._extra_state"
    if problem_key in incompatible_keys.missing_keys:
        incompatible_keys.missing_keys.remove(problem_key)


def get_layer_norm_class() -> nn.Module:
"""
Dynamically pick the layer norm provider based on availability of transformer engine.
If transformer engine is available, it will use the transformer engine implementation of
LayerNorm. Otherwise, it will use the pytorch implementation of LayerNorm.

Override the default behavior by setting the PHYSICSNEMO_FORCE_TE environment variable.
"""

# This is to allow users to force the use of TE or pytorch layer norm
force_te_setting = os.environ.get("PHYSICSNEMO_FORCE_TE")
te_available = (
TE_AVAILABLE # make a local copy to avoid changing the global variable
)

# Can't use transformer engine without cuda:
if not torch.cuda.is_available():
te_available = False

# Let the users force the setting no matter what:
if force_te_setting is not None:
if force_te_setting.lower() == "true" or force_te_setting.lower() == "1":
te_available = True
elif force_te_setting.lower() == "false" or force_te_setting.lower() == "0":
te_available = False

if te_available:
base = te.LayerNorm
else:
base = nn.LayerNorm

class LayerNorm(base):
"""
Wrapper around layer norm utilities.

This class will default to using the transformer engine implementation of
LayerNorm - it is significantly faster in the backwards pass.

If transformer engine is not available, it will fall back to the
pytorch implementation of LayerNorm.

Additionally, this class registers pre or post hooks to allow you to
train with / without transformer engine, and run inference
with / without transformer engine.

.. note::
Transformer engine adds additional state parameters that affect
fp8 stability. **Do NOT** switch from transformer engine to pytorch
or from pytorch to transformer engine with a checkpoint if you
are using fp8 precision in the layer norm regions.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

if te_available:
self.register_load_state_dict_post_hook(ignore_missing_extra_state_key)
else:
self.register_load_state_dict_pre_hook(
remove_extra_state_hook_for_torch
)

return LayerNorm


LayerNorm = get_layer_norm_class()
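A hedged usage sketch of the override and of the checkpoint hooks described in the class docstring above. The checkpoint file name is made up, and the sketch deliberately forces the PyTorch fallback so that its pre-hook is exercised:

import os

# Setting PHYSICSNEMO_FORCE_TE before the class is resolved forces the choice:
# "False"/"0" selects torch.nn.LayerNorm, "True"/"1" selects te.LayerNorm.
os.environ["PHYSICSNEMO_FORCE_TE"] = "False"

import torch

from physicsnemo.models.layers.layer_norm import get_layer_norm_class

LayerNorm = get_layer_norm_class()  # a torch.nn.LayerNorm subclass in this configuration
ln = LayerNorm(64)

# Loading a checkpoint that was written by a TE-backed LayerNorm: the pre-hook
# registered in __init__ strips the TE-only "_extra_state" entries first, so the
# load can succeed. "te_layernorm.pt" is a hypothetical file.
state = torch.load("te_layernorm.pt")
ln.load_state_dict(state)

Per the note in the docstring, switching between the two implementations with a checkpoint is only safe when fp8 is not used in the layer norm regions.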
7 changes: 7 additions & 0 deletions physicsnemo/models/meshgraphnet/meshgraphkan.py
@@ -124,6 +124,13 @@ class MeshGraphKAN(Module):

Example
-------
>>> # The `norm_type` argument in the MeshGraph MLP layers is deprecated;
>>> # TE will be used automatically for layer norm if it is installed.
>>> # (You don't have to set this variable; using TE is faster!)
>>> # Example of how to disable it:
>>> import os
>>> os.environ['PHYSICSNEMO_FORCE_TE'] = 'False'
>>>
>>> model = MeshGraphKAN(
... input_dim_nodes=4,
... input_dim_edges=3,
7 changes: 7 additions & 0 deletions physicsnemo/models/meshgraphnet/meshgraphnet.py
@@ -127,6 +127,13 @@ class MeshGraphNet(Module):

Example
-------
>>> # The `norm_type` argument in MeshGraphNet is deprecated;
>>> # TE will be used automatically for layer norm if it is installed.
>>> # (You don't have to set this variable; using TE is faster!)
>>> # Example of how to disable it:
>>> import os
>>> os.environ['PHYSICSNEMO_FORCE_TE'] = 'False'
>>>
>>> model = physicsnemo.models.meshgraphnet.MeshGraphNet(
... input_dim_nodes=4,
... input_dim_edges=3,
57 changes: 0 additions & 57 deletions test/distributed/distributed_utils_for_testing.py

This file was deleted.

2 changes: 1 addition & 1 deletion test/distributed/test_autograd.py
@@ -16,7 +16,7 @@

import pytest
import torch
from distributed_utils_for_testing import modify_environment
from pytest_utils import modify_environment

from physicsnemo.distributed import DistributedManager
from physicsnemo.distributed.autograd import (
2 changes: 1 addition & 1 deletion test/distributed/test_ball_query_shard_tensor.py
@@ -31,7 +31,7 @@
)


from distributed_utils_for_testing import modify_environment # noqa: E402
from pytest_utils import modify_environment # noqa: E402
from test_shard_tensor_initialization import init_dist
from torch.distributed.tensor import distribute_module # noqa: E402
from torch.distributed.tensor.placement_types import ( # noqa: E402
2 changes: 1 addition & 1 deletion test/distributed/test_config.py
@@ -16,7 +16,7 @@

import pytest
import torch
from distributed_utils_for_testing import modify_environment
from pytest_utils import modify_environment

from physicsnemo.distributed import (
DistributedManager,
2 changes: 1 addition & 1 deletion test/distributed/test_distributed_fft.py
@@ -17,7 +17,7 @@
import pytest
import torch
import torch.distributed as dist
from distributed_utils_for_testing import modify_environment
from pytest_utils import modify_environment

from physicsnemo.distributed import DistributedManager
from physicsnemo.distributed.fft import DistributedRFFT2
2 changes: 1 addition & 1 deletion test/distributed/test_manager.py
@@ -16,7 +16,7 @@

import pytest
import torch
from distributed_utils_for_testing import modify_environment
from pytest_utils import modify_environment

from physicsnemo.distributed import (
DistributedManager,
2 changes: 1 addition & 1 deletion test/distributed/test_mesh.py
@@ -17,7 +17,7 @@

import pytest
import torch
from distributed_utils_for_testing import modify_environment
from pytest_utils import modify_environment

from physicsnemo.distributed import (
DistributedManager,
@@ -28,7 +28,7 @@
allow_module_level=True,
)

from distributed_utils_for_testing import modify_environment
from pytest_utils import modify_environment
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import Replicate
