Merged
Commits
63 commits
b870c49
Refactor GroupNorm and log unmatched state_dict keys
Jun 24, 2025
72416a6
Merge branch 'main' into improve_diffusion
juliusberner Jun 24, 2025
119c983
Merge branch 'main' into improve_diffusion
CharlelieLrt Jun 27, 2025
59a6ff1
Refactor GroupNorm and log unmatched state_dict keys
Jun 24, 2025
f8e01c7
Add changes from MR996
Jul 9, 2025
08f5368
Merge branch 'improve_diffusion' of https://github.com/juliusberner/p…
CharlelieLrt Jul 22, 2025
f07cbf3
Merge branch 'main' into improve_diffusion
CharlelieLrt Jul 22, 2025
133d925
Made load_state_dict method semi-private
CharlelieLrt Jul 22, 2025
ca424a8
Move the attention migration into a load_state_dict pre-hook
CharlelieLrt Jul 31, 2025
bae6d39
Deleted duplicate line in CHANGELOG.md
CharlelieLrt Jul 31, 2025
235b928
Removed warnings in UNetBlock load_state_dict pre-hook
CharlelieLrt Jul 31, 2025
4fed09c
Added test for UNetBlock checkpoint loading from v1.0.1
CharlelieLrt Jul 31, 2025
647528b
Merge branch 'main' into fix-attention-load-state-dict
CharlelieLrt Aug 1, 2025
f70e0b7
Changed tol in test + added new test with fused_conv_bias=True
CharlelieLrt Aug 1, 2025
b105b32
Merge branch 'fix-attention-load-state-dict' of https://github.com/Ch…
CharlelieLrt Aug 1, 2025
49d78a5
Merge branch 'main' into improve_diffusion
CharlelieLrt Aug 1, 2025
6026b40
Updated CHANGELOG.md
CharlelieLrt Aug 2, 2025
66b201c
Updated CHANGELOG.md
CharlelieLrt Aug 2, 2025
ae0bb0e
Merge branch 'main' into improve_diffusion
CharlelieLrt Aug 2, 2025
170aba0
Changed a GroupNorm into get_group_norm
CharlelieLrt Aug 2, 2025
5400790
Merge branch 'improve_diffusion' of https://github.com/juliusberner/p…
CharlelieLrt Aug 2, 2025
135e132
Improved docstring for get_group_norm
CharlelieLrt Aug 2, 2025
ed76f10
Merge branch 'main' into fix-attention-load-state-dict
CharlelieLrt Aug 2, 2025
81725d0
Removed unused test
CharlelieLrt Aug 3, 2025
f963b91
Merge branch 'main' into improve_diffusion
CharlelieLrt Aug 4, 2025
86b83eb
Initial commit of group_norm tests
CharlelieLrt Aug 4, 2025
2a981b4
Merge branch 'improve_diffusion' of https://github.com/juliusberner/p…
CharlelieLrt Aug 4, 2025
c5e8010
Added non-regression test for GroupNorm
CharlelieLrt Aug 5, 2025
385e14f
Merge branch 'main' into improve_diffusion
CharlelieLrt Aug 5, 2025
f2baa23
Fix BC compatibility of GroupNorm
Aug 5, 2025
3c9a69c
Fixed some formatting in group norm + replaced deprecation warning wi…
CharlelieLrt Aug 6, 2025
8268fe3
New tests for GroupNorm and get_group_norm
CharlelieLrt Aug 6, 2025
c38c8f1
Merge branch 'main' into improve_diffusion
CharlelieLrt Aug 6, 2025
a924f8e
Fixed some tests
CharlelieLrt Aug 6, 2025
4d26964
Merge branch 'main' into fix-attention-load-state-dict
CharlelieLrt Aug 6, 2025
211e9fd
Merge branch 'fix-attention-load-state-dict' into improve_diffusion
CharlelieLrt Aug 6, 2025
743fced
Improvements in UNetBlock docstring
CharlelieLrt Aug 6, 2025
a32c409
Improvements in layers.py docstrings
CharlelieLrt Aug 6, 2025
188b799
Added non-regression test for UNetBlock
CharlelieLrt Aug 7, 2025
0e2e1a2
Removed load_state_dict from UNetBlock
CharlelieLrt Aug 7, 2025
816149a
Refactored group_norm test to use pytest parameterize instead of loops
CharlelieLrt Aug 9, 2025
5d83417
Fix bugs in Attention layer
CharlelieLrt Aug 9, 2025
f0ec057
Some ongoing work on unet_block tests
CharlelieLrt Aug 9, 2025
0eb3f64
Added non-regression checkpoints and data + non-regression test for U…
CharlelieLrt Aug 9, 2025
4866347
Added IDs for group_norm tests
CharlelieLrt Aug 9, 2025
3640bfe
Added new tests for UNet block
CharlelieLrt Aug 11, 2025
96bca9a
Added more param validation in Attention
CharlelieLrt Aug 11, 2025
05debe1
Added tests for new Attention layer
CharlelieLrt Aug 11, 2025
db8a738
Merge branch 'main' into improve_diffusion
CharlelieLrt Aug 11, 2025
f2e850c
Pin C++ backend for attention op
CharlelieLrt Aug 11, 2025
94d4401
Added reference input data for attention tests
CharlelieLrt Aug 11, 2025
e1c8abb
Some files renaming
CharlelieLrt Aug 11, 2025
60b06a8
Reverted back attention to previous implementation
CharlelieLrt Aug 11, 2025
7a2fec7
Updates on new tests
CharlelieLrt Aug 12, 2025
153990a
Updates on new tests
CharlelieLrt Aug 12, 2025
163d1d0
Deleted tests ref data
CharlelieLrt Aug 12, 2025
a166fe4
Group norm test working
CharlelieLrt Aug 12, 2025
07f39d2
Group norm test working
CharlelieLrt Aug 12, 2025
00f39ef
Tests for attention layer passing locally
CharlelieLrt Aug 12, 2025
f3cc71e
Removed backend in attention test
CharlelieLrt Aug 12, 2025
94cc4da
Modified UNetBlock tests
CharlelieLrt Aug 13, 2025
b6291ca
Tests for UNetBlock passing locally
CharlelieLrt Aug 13, 2025
5366a76
Merge branch 'main' into improve_diffusion
CharlelieLrt Aug 13, 2025
13 changes: 12 additions & 1 deletion physicsnemo/launch/utils/checkpoint.py
@@ -386,7 +386,18 @@ def load_checkpoint(
model.load(file_name)
else:
file_to_load = _cache_if_needed(file_name)
model.load_state_dict(torch.load(file_to_load, map_location=device))
missing_keys, unexpected_keys = model.load_state_dict(
torch.load(file_to_load, map_location=device)
)
if missing_keys:
checkpoint_logging.warning(
f"Missing keys when loading {name}: {missing_keys}"
)
if unexpected_keys:
checkpoint_logging.warning(
f"Unexpected keys when loading {name}: {unexpected_keys}"
)

checkpoint_logging.success(
f"Loaded model state dictionary {file_name} to device {device}"
)
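For reference, `torch.nn.Module.load_state_dict` returns the `(missing_keys, unexpected_keys)` pair that the hunk above logs. A minimal torch-only sketch of the same pattern (not part of the PR; it uses `strict=False` so the mismatch is returned rather than raised, and a placeholder logger name):

import logging

import torch

logger = logging.getLogger("checkpoint")  # placeholder logger name

# Two modules whose parameter names deliberately differ.
src = torch.nn.Linear(4, 4)
dst = torch.nn.Sequential(torch.nn.Linear(4, 4))

# strict=False returns the mismatch report instead of raising.
missing_keys, unexpected_keys = dst.load_state_dict(src.state_dict(), strict=False)
if missing_keys:
    logger.warning("Missing keys: %s", missing_keys)  # ['0.weight', '0.bias']
if unexpected_keys:
    logger.warning("Unexpected keys: %s", unexpected_keys)  # ['weight', 'bias']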
1 change: 1 addition & 0 deletions physicsnemo/models/diffusion/__init__.py
@@ -20,6 +20,7 @@
Conv2d,
FourierEmbedding,
GroupNorm,
get_group_norm,
Linear,
PositionalEmbedding,
UNetBlock,
4 changes: 2 additions & 2 deletions physicsnemo/models/diffusion/dhariwal_unet.py
@@ -23,10 +23,10 @@

from physicsnemo.models.diffusion import (
Conv2d,
GroupNorm,
Linear,
PositionalEmbedding,
UNetBlock,
get_group_norm,
)
from physicsnemo.models.meta import ModelMetaData
from physicsnemo.models.module import Module
@@ -264,7 +264,7 @@ def __init__(
attention=(res in attn_resolutions),
**block_kwargs,
)
self.out_norm = GroupNorm(num_channels=cout)
self.out_norm = get_group_norm(num_channels=cout)
self.out_conv = Conv2d(
in_channels=cout, out_channels=out_channels, kernel=3, **init_zero
)
136 changes: 76 additions & 60 deletions physicsnemo/models/diffusion/layers.py
@@ -290,6 +290,64 @@ def forward(self, x):
return x


def get_group_norm(
num_channels: int,
num_groups: int = 32,
min_channels_per_group: int = 4,
eps: float = 1e-5,
use_apex_gn: bool = False,
act: str = None,
amp_mode: bool = False,
):
"""
Utility function to get a GroupNorm layer, either Apex's fused implementation or the custom torch-based one below.

Parameters
----------
num_channels : int
Number of channels in the input tensor.
num_groups : int, optional
Desired number of groups to divide the input channels into, by default 32.
This might be adjusted based on `min_channels_per_group`.
min_channels_per_group : int, optional
Minimum number of channels per group. This ensures that no group has fewer
channels than this number. By default 4.
eps : float, optional
A small number added to the variance to prevent division by zero, by default
1e-5.
use_apex_gn : bool, optional
A boolean flag indicating whether to use Apex GroupNorm for NHWC layout.
Must be set to False on CPU. Defaults to False.
act : str, optional
The activation function to fuse with GroupNorm (Apex) or apply after
normalization. Defaults to None.
amp_mode : bool, optional
A boolean flag indicating whether mixed-precision (AMP) training is enabled. Defaults to False.

Notes
-----
If `num_channels` is not divisible by `num_groups`, the actual number of groups
is adjusted to satisfy the `min_channels_per_group` condition.
"""

num_groups = min(num_groups, num_channels // min_channels_per_group)
if use_apex_gn and not _is_apex_available:
raise ValueError("'apex' is not installed, set `use_apex_gn=False`")

act = act.lower() if act else act
if use_apex_gn:
return ApexGroupNorm(
num_groups=num_groups,
num_channels=num_channels,
eps=eps,
affine=True,
act=act,
)
else:
return GroupNorm(
num_groups=num_groups,
num_channels=num_channels,
eps=eps,
act=act,
amp_mode=amp_mode,
)


class GroupNorm(torch.nn.Module):
"""
A custom Group Normalization layer implementation.
@@ -301,22 +359,13 @@ class GroupNorm(torch.nn.Module):

Parameters
----------
num_groups : int
Desired number of groups to divide the input channels.
num_channels : int
Number of channels in the input tensor.
num_groups : int, optional
Desired number of groups to divide the input channels, by default 32.
This might be adjusted based on the `min_channels_per_group`.
min_channels_per_group : int, optional
Minimum channels required per group. This ensures that no group has fewer
channels than this number. By default 4.
eps : float, optional
A small number added to the variance to prevent division by zero, by default
1e-5.
use_apex_gn : bool, optional
A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout.
Need to set this as False on cpu. Defaults to False.
fused_act : bool, optional
Whether to fuse the activation function with GroupNorm. Defaults to False.
act : str, optional
The activation function to use when fusing activation with GroupNorm. Defaults to None.
amp_mode : bool, optional
@@ -329,61 +378,32 @@

def __init__(
self,
num_groups: int,
num_channels: int,
num_groups: int = 32,
min_channels_per_group: int = 4,
eps: float = 1e-5,
use_apex_gn: bool = False,
fused_act: bool = False,
act: str = None,
amp_mode: bool = False,
):
if fused_act and act is None:
raise ValueError("'act' must be specified when 'fused_act' is set to True.")

super().__init__()
self.num_groups = min(num_groups, num_channels // min_channels_per_group)
self.num_groups = num_groups
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(num_channels))
self.bias = torch.nn.Parameter(torch.zeros(num_channels))
if use_apex_gn and not _is_apex_available:
raise ValueError("'apex' is not installed, set `use_apex_gn=False`")
self.use_apex_gn = use_apex_gn
self.fused_act = fused_act
self.act = act.lower() if act else act
self.act_fn = None
self.amp_mode = amp_mode
if self.use_apex_gn:
if self.act:
self.gn = ApexGroupNorm(
num_groups=self.num_groups,
num_channels=num_channels,
eps=self.eps,
affine=True,
act=self.act,
)

else:
self.gn = ApexGroupNorm(
num_groups=self.num_groups,
num_channels=num_channels,
eps=self.eps,
affine=True,
)
if self.fused_act:
if self.act is not None:
self.act_fn = self.get_activation_function()
self.amp_mode = amp_mode

def forward(self, x):
weight, bias = self.weight, self.bias
if not self.amp_mode:
if not self.use_apex_gn:
if weight.dtype != x.dtype:
weight = self.weight.to(x.dtype)
if bias.dtype != x.dtype:
bias = self.bias.to(x.dtype)
if self.use_apex_gn:
x = self.gn(x)
elif self.training:
if weight.dtype != x.dtype:
weight = self.weight.to(x.dtype)
if bias.dtype != x.dtype:
bias = self.bias.to(x.dtype)

if self.training:
# Use default torch implementation of GroupNorm for training
# This does not support channels last memory format
x = torch.nn.functional.group_norm(
@@ -393,8 +413,6 @@ def forward(self, x):
bias=bias,
eps=self.eps,
)
if self.fused_act:
x = self.act_fn(x)
else:
# Use custom GroupNorm implementation that supports channels last
# memory layout for inference
@@ -411,8 +429,8 @@ def forward(self, x):
bias = rearrange(bias, "c -> 1 c 1 1")
x = x * weight + bias

if self.fused_act:
x = self.act_fn(x)
if self.act_fn is not None:
x = self.act_fn(x)
return x

def get_activation_function(self):
@@ -574,11 +592,10 @@ def __init__(
self.adaptive_scale = adaptive_scale
self.profile_mode = profile_mode
self.amp_mode = amp_mode
self.norm0 = GroupNorm(
self.norm0 = get_group_norm(
num_channels=in_channels,
eps=eps,
use_apex_gn=use_apex_gn,
fused_act=True,
act=act,
amp_mode=amp_mode,
)
@@ -600,19 +617,18 @@ def __init__(
**init,
)
if self.adaptive_scale:
self.norm1 = GroupNorm(
self.norm1 = get_group_norm(
num_channels=out_channels,
eps=eps,
use_apex_gn=use_apex_gn,
amp_mode=amp_mode,
)
else:
self.norm1 = GroupNorm(
self.norm1 = get_group_norm(
num_channels=out_channels,
eps=eps,
use_apex_gn=use_apex_gn,
act=act,
fused_act=True,
amp_mode=amp_mode,
)
self.conv1 = Conv2d(
@@ -641,7 +657,7 @@ def __init__(
)

if self.num_heads:
self.norm2 = GroupNorm(
self.norm2 = get_group_norm(
num_channels=out_channels,
eps=eps,
use_apex_gn=use_apex_gn,
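For illustration, a small usage sketch of the new factory (not part of the diff; it assumes the package is installed and runs on CPU, so `use_apex_gn` stays at its default `False`):

import torch

from physicsnemo.models.diffusion import get_group_norm

# With use_apex_gn=False (default, required on CPU) this returns the custom
# torch-based GroupNorm; with use_apex_gn=True and Apex installed it would
# return Apex's fused GroupNorm instead. An activation name can also be
# passed via `act` to apply it after normalization, as UNetBlock does above.
norm = get_group_norm(num_channels=64)  # num_groups adjusted to min(32, 64 // 4) = 16
x = torch.randn(2, 64, 32, 32)
print(type(norm).__name__, norm(x).shape)  # GroupNorm torch.Size([2, 64, 32, 32])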
8 changes: 5 additions & 3 deletions physicsnemo/models/diffusion/song_unet.py
@@ -27,10 +27,10 @@
from physicsnemo.models.diffusion import (
Conv2d,
FourierEmbedding,
GroupNorm,
Linear,
PositionalEmbedding,
UNetBlock,
get_group_norm,
)
from physicsnemo.models.meta import ModelMetaData
from physicsnemo.models.module import Module
@@ -479,7 +479,7 @@ def __init__(
resample_filter=resample_filter,
amp_mode=amp_mode,
)
self.dec[f"{res}x{res}_aux_norm"] = GroupNorm(
self.dec[f"{res}x{res}_aux_norm"] = get_group_norm(
num_channels=cout,
eps=1e-6,
use_apex_gn=use_apex_gn,
@@ -825,7 +825,9 @@ def __init__(
if self.gridtype == "learnable":
self.pos_embd = self._get_positional_embedding()
else:
self.register_buffer("pos_embd", self._get_positional_embedding().float())
self.register_buffer(
"pos_embd", self._get_positional_embedding().float(), persistent=False
)
self.lead_time_mode = lead_time_mode
if self.lead_time_mode:
self.lead_time_channels = lead_time_channels
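The `persistent=False` flag above keeps the precomputed positional embedding out of the model's `state_dict`, so it is rebuilt at construction time instead of being serialized with checkpoints. A self-contained sketch of the behavior with a toy module (not `SongUNet` itself):

import torch


class WithBuffer(torch.nn.Module):
    def __init__(self, persistent: bool):
        super().__init__()
        # Non-persistent buffers still move with .to()/.cuda() but are
        # excluded from state_dict() and therefore from saved checkpoints.
        self.register_buffer("pos_embd", torch.arange(4).float(), persistent=persistent)


print(list(WithBuffer(persistent=True).state_dict()))   # ['pos_embd']
print(list(WithBuffer(persistent=False).state_dict()))  # []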
35 changes: 28 additions & 7 deletions physicsnemo/models/module.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import importlib
import inspect
import json
@@ -30,11 +29,36 @@

import physicsnemo
from physicsnemo.models.meta import ModelMetaData
from physicsnemo.models.util_compatibility import convert_ckp_apex
from physicsnemo.registry import ModelRegistry
from physicsnemo.utils.filesystem import _download_cached, _get_fs


def load_state_dict_with_logging(
module: torch.nn.Module, state_dict: Dict[str, Any], *args, **kwargs
):
"""Load state dictionary and log missing and unexpected keys

Parameters
----------
module : torch.nn.Module
Module to load state dictionary into
state_dict : Dict[str, Any]
State dictionary to load
*args, **kwargs
Additional arguments to pass to load_state_dict
"""
missing_keys, unexpected_keys = module.load_state_dict(state_dict, *args, **kwargs)
if missing_keys:
logging.warning(
f"Missing keys when loading {module.__class__.__name__}: {missing_keys}"
)
if unexpected_keys:
logging.warning(
f"Unexpected keys when loading {module.__class__.__name__}: {unexpected_keys}"
)
return missing_keys, unexpected_keys


class Module(torch.nn.Module):
"""The base class for all network models in PhysicsNeMo.

@@ -381,7 +405,7 @@ def load(
model_dict = torch.load(
local_path.joinpath("model.pt"), map_location=device
)
self.load_state_dict(model_dict, strict=strict)
load_state_dict_with_logging(self, model_dict, strict=strict)

@classmethod
def from_checkpoint(
@@ -424,8 +448,6 @@ def from_checkpoint(
with open(local_path.joinpath("args.json"), "r") as f:
args = json.load(f)

ckp_args = copy.deepcopy(args)

# Load metadata to get version
with open(local_path.joinpath("metadata.json"), "r") as f:
metadata = json.load(f)
@@ -461,8 +483,7 @@ def from_checkpoint(
local_path.joinpath("model.pt"), map_location=model.device
)

model_dict = convert_ckp_apex(ckp_args, model_args, model_dict)
model.load_state_dict(model_dict, strict=False)
load_state_dict_with_logging(model, model_dict, strict=False)
return model

@staticmethod
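A small usage sketch for the new helper (not part of the diff; the toy module and deliberately incomplete state dict are placeholders):

import torch

from physicsnemo.models.module import load_state_dict_with_logging

# Toy module and a deliberately incomplete state dict, for illustration only.
model = torch.nn.Linear(8, 8)
state_dict = {"weight": torch.zeros(8, 8)}  # "bias" is intentionally absent

# strict=False tolerates the mismatch; the helper logs warnings and returns
# the two key lists, mirroring the from_checkpoint() path above.
missing_keys, unexpected_keys = load_state_dict_with_logging(model, state_dict, strict=False)
print(missing_keys, unexpected_keys)  # ['bias'] []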