MeshGraphNet Performance: Automatically Use Transformer Engine for LayerNorm. #1036
Conversation
… training speed up on GPU.
FYI: this PR looks bigger than it is. Many tests are updated, but most are not substantively changed; I use environment variables to force the torch version of LayerNorm in tests that explicitly run on CPU, which prevents Transformer Engine from being used for CPU models. There are, however, new tests for LayerNorm that should be reviewed before merge.
Update docstring to use torch layernorm (for CPU tests).
Disable TE for docstring tests.
LGTM with minor comments!
- Remove warnings about deprecation.
- Add a warning if the environment variable PHYSICSNEMO_FORCE_TE is set, but to an unexpected value (a sketch of this follows below).
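A minimal sketch of the suggested validation, assuming PHYSICSNEMO_FORCE_TE is expected to hold a boolean-like string; the accepted spellings and the helper name `resolve_force_te` are illustrative, not PhysicsNeMo's actual implementation:

```python
import os
import warnings


def resolve_force_te() -> bool | None:
    """Parse PHYSICSNEMO_FORCE_TE; warn (rather than error) on unexpected values."""
    raw = os.environ.get("PHYSICSNEMO_FORCE_TE")
    if raw is None:
        return None  # unset: let the backend be chosen automatically
    value = raw.strip().lower()
    if value in ("1", "true", "yes", "on"):
        return True
    if value in ("0", "false", "no", "off"):
        return False
    warnings.warn(
        f"PHYSICSNEMO_FORCE_TE is set to unexpected value {raw!r}; "
        "expected a boolean-like string such as 'true' or 'false'. "
        "Falling back to automatic backend selection."
    )
    return None
```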
PhysicsNeMo Pull Request
Description
This PR provides performance enhancements for MeshGraphNet on GPUs by automatically using Transformer Engine (TE) for LayerNorm: when a model runs on GPU, the TE LayerNorm implementation is selected; on CPU, it falls back to the torch implementation.
A note about test changes: Transformer Engine is not supported on CPU, so the tests are modified to let PhysicsNeMo select the LayerNorm backend for best performance, except in CPU tests, where the backend is forced to torch instead of TE. The MGN tests therefore exercise both the torch and TE backends, and both show agreement after restoring from file.
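For illustration, here is a minimal sketch of what that backend selection could look like. The helper name `get_layernorm_class` and the exact semantics of PHYSICSNEMO_FORCE_TE are assumptions for this sketch, not the PR's actual code:

```python
import os

import torch


def get_layernorm_class(device: torch.device):
    """Pick the LayerNorm backend: TE on CUDA, torch on CPU or when forced."""
    forced = os.environ.get("PHYSICSNEMO_FORCE_TE", "").strip().lower()
    if device.type == "cuda" and forced not in ("0", "false"):
        try:
            import transformer_engine.pytorch as te

            return te.LayerNorm  # fused CUDA implementation
        except ImportError:
            pass  # TE not installed: fall back to torch
    return torch.nn.LayerNorm


# Example: a CPU test sets PHYSICSNEMO_FORCE_TE=false, so this returns
# torch.nn.LayerNorm even on machines where CUDA is available.
norm_cls = get_layernorm_class(torch.device("cpu"))
layer = norm_cls(128)
```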
On synthetic data, the performance improvement is substantial on large graphs. I measured with both PyG and DGL, up to 200k nodes and over 1M edges (larger graphs run out of memory). On small graphs the backends perform similarly; on large graphs Transformer Engine is faster, especially during training.
Additionally, PyG is faster than DGL in all measurements.
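For context, here is a rough sketch of how such a per-layer comparison can be timed with CUDA events. This is not the harness that produced the numbers below; it assumes Transformer Engine is installed and a CUDA device is available, and the hidden size is an arbitrary choice:

```python
import torch
import transformer_engine.pytorch as te


def time_layernorm(ln, x, iters=100):
    """Average forward+backward time in ms, measured with CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        ln(x).sum().backward()  # training path: include the backward pass
        x.grad = None
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


hidden = 256  # illustrative hidden size
# Node-feature-like input at the scale mentioned above (~200k nodes).
x = torch.randn(200_000, hidden, device="cuda", requires_grad=True)

torch_ms = time_layernorm(torch.nn.LayerNorm(hidden).cuda(), x)
te_ms = time_layernorm(te.LayerNorm(hidden).cuda(), x)
print(f"torch LayerNorm: {torch_ms:.3f} ms  |  TE LayerNorm: {te_ms:.3f} ms")
```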
Including performance measurements here in the PR for posterity.
Float32
Training: [performance plot]
Inference: [performance plot]
Float16
Training: [performance plot]
Inference: [performance plot]
BFloat16
Training: [performance plot]
Inference: [performance plot]
Checklist
Dependencies