MeshGraphNet Performance: Automatically use Transformer Engine for LayerNorm #1036
Merged

Commits (24, all by coreyjadams):
092761d  adding layer norm utils, skipping precommit since ORD is so unstable
b028649  Add doc strings and type annotations to layer_norm implementation
efa86eb  merge
e3ecb1b  Merge branch 'NVIDIA:main' into te-ln
8da32fd  Merge branch 'NVIDIA:main' into te-ln
44be5ea  Snapshot layernorm optimizations.
46cd1ec  Enable dynamic selection of layer norm in MeshGraphNet. Yields a goo…
45031ff  Remove old code.
801ce6b  Remove unneeded file.
c230d2f  Update test to avoid te on CPU
f188566  Merge branch 'main' into te-ln
927a385  Update formatting
918fa35  Update meshgraphnet.py
a5fbb69  Update meshgraphkan.py
acb4b6a  Update meshgraphnet.py
52396d8  Fix ruff formatting
2ec3d68  Formatting ....
80bdf1b  Merge branch 'NVIDIA:main' into te-ln
74de9c7  Address PR feedback:
9c263de  Update tests: env modification coming through a fixture now.
017cd1e  Address graphcast too: use a fixture instead of contexts.
ec577ea  Fix layer norm tests too.
064e067  Merge branch 'main' into te-ln
d5c44b7  Fix a test
@@ -0,0 +1,154 @@

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from torch import nn

try:
    import transformer_engine.pytorch as te

    TE_AVAILABLE = True
except ImportError:
    TE_AVAILABLE = False


def remove_extra_state_hook_for_torch(
    module: nn.Module,
    state_dict: dict,
    prefix: str,
    local_metadata: dict,
    strict: bool,
    missing_keys: list,
    unexpected_keys: list,
    error_msgs: list,
) -> None:
    """
    Pre-hook to remove Transformer Engine's extra state from the state_dict when
    loading into a PyTorch LayerNorm.

    This function scans the state_dict for any keys that match the pattern
    '{prefix}_extra_state' and removes them. These keys are specific to Transformer
    Engine's LayerNorm and are not needed (and may cause errors) when loading into
    a standard PyTorch LayerNorm.

    Args:
        module (nn.Module): The module into which the state_dict is being loaded.
        state_dict (dict): The state dictionary being loaded.
        prefix (str): The prefix for parameters in this module.
        local_metadata (dict): Metadata for this module.
        strict (bool): Whether to strictly enforce that the keys in state_dict match
            the keys returned by this module's state_dict function.
        missing_keys (list): List of missing keys.
        unexpected_keys (list): List of unexpected keys.
        error_msgs (list): List of error messages.
    """
    # Go through the state dict, and remove any keys that start with
    # prefix + "_extra_state". They are extra state from Transformer Engine
    # and are not needed in the torch LayerNorm.
    keys_to_remove = [
        key for key in state_dict if key.startswith(prefix + "_extra_state")
    ]
    for key in keys_to_remove:
        del state_dict[key]


def ignore_missing_extra_state_key(
    module: nn.Module, incompatible_keys: torch.nn.modules.module._IncompatibleKeys
) -> None:
    """
    Post-hook to ignore the missing 'ln._extra_state' key when loading a state_dict.

    This function removes 'ln._extra_state' from the list of missing keys in
    the IncompatibleKeys object. This is useful when loading a checkpoint saved
    from a PyTorch LayerNorm into a Transformer Engine LayerNorm, where this extra
    state is not present in the checkpoint.

    Args:
        module (nn.Module): The module into which the state_dict is being loaded.
        incompatible_keys: An object with a 'missing_keys' attribute (typically
            torch.nn.modules.module._IncompatibleKeys).
    """
    # Remove 'ln._extra_state' from the missing keys:
    problem_key = "ln._extra_state"
    if problem_key in incompatible_keys.missing_keys:
        incompatible_keys.missing_keys.remove(problem_key)


def get_layer_norm_class() -> type[nn.Module]:
    """
    Dynamically pick the layer norm provider based on the availability of Transformer
    Engine. If Transformer Engine is available, its implementation of LayerNorm is
    used. Otherwise, the PyTorch implementation of LayerNorm is used.

    Override the default behavior by setting the PHYSICSNEMO_FORCE_TE environment
    variable.
    """

    # This is to allow users to force the use of TE or pytorch layer norm
    force_te_setting = os.environ.get("PHYSICSNEMO_FORCE_TE")
    te_available = (
        TE_AVAILABLE  # make a local copy to avoid changing the global variable
    )

    # Can't use transformer engine without cuda:
    if not torch.cuda.is_available():
        te_available = False

    # Let the users force the setting no matter what:
    if force_te_setting is not None:
        if force_te_setting.lower() == "true" or force_te_setting.lower() == "1":
            te_available = True
        elif force_te_setting.lower() == "false" or force_te_setting.lower() == "0":
            te_available = False

    if te_available:
        base = te.LayerNorm
    else:
        base = nn.LayerNorm

    class LayerNorm(base):
        """
        Wrapper around layer norm utilities.

        This class defaults to the Transformer Engine implementation of LayerNorm,
        which is significantly faster in the backwards pass.

        If Transformer Engine is not available, it falls back to the PyTorch
        implementation of LayerNorm.

        Additionally, this class registers pre- or post-hooks so that you can
        train with or without Transformer Engine and run inference with or
        without Transformer Engine.

        .. note::
            Transformer Engine adds additional state parameters that affect
            fp8 stability. **Do NOT** switch from Transformer Engine to PyTorch
            or from PyTorch to Transformer Engine with a checkpoint if you
            are using fp8 precision in the layer norm regions.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

            if te_available:
                self.register_load_state_dict_post_hook(ignore_missing_extra_state_key)
            else:
                self.register_load_state_dict_pre_hook(
                    remove_extra_state_hook_for_torch
                )

    return LayerNorm


LayerNorm = get_layer_norm_class()
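For reference, a minimal usage sketch of the new wrapper. The import path below is hypothetical (this diff does not show where the file lives inside the physicsnemo package); the constructor call and the PHYSICSNEMO_FORCE_TE override come from the code above. Note that the backend is chosen when the module is imported, since LayerNorm = get_layer_norm_class() runs at import time, so the environment variable must be set beforehand.

# Usage sketch (hypothetical import path; adjust to wherever this file lands in physicsnemo).
import os

# Optional: force the pure-PyTorch backend. Must be set before the import below,
# because the backend is selected at module import time.
os.environ["PHYSICSNEMO_FORCE_TE"] = "false"

import torch
from physicsnemo.models.layers.layer_norm import LayerNorm  # hypothetical path

ln = LayerNorm(128)           # same constructor call as torch.nn.LayerNorm(128)
x = torch.randn(32, 128)
y = ln(x)                     # forward pass is the same for either backend
print(type(ln).__mro__[1])    # shows which base was selected: te.LayerNorm or nn.LayerNorm

If a checkpoint was produced with the other backend, the load_state_dict hooks registered above reconcile the _extra_state keys, with the fp8 caveat noted in the class docstring.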