
Commit 58adcdc

pszemraj (Peter Szemraj) and a co-author authored
Training Utils, Clean up (#3)
* 🐛
* 🎨
* suppress fla FutureWarning
* 🔥 remove fused rotary emb code
* ✨ model summary, auto tf32
* 📝 document config params
* 🔊 more detailed model summary
* more closely match samba421m cfg

Signed-off-by: Peter Szemraj <peterszemraj+dev@gmail.com>
Co-authored-by: Peter Szemraj <peterszemraj+dev@gmail.com>
1 parent 19a0ab6 commit 58adcdc

File tree

7 files changed: +187, -130 lines

samba_pytorch/config.py

Lines changed: 54 additions & 1 deletion
@@ -15,7 +15,60 @@
 @dataclass
 class Config:
-    org: str = "Lightning-AI"
+    """Configuration class for SAMBA (Simple Hybrid State Space Models) architecture.
+
+    The SAMBA architecture combines Mamba (selective state space model) with
+    Sliding Window Attention (SWA) and Multi-Layer Perceptrons (MLP) in a layer-wise fashion.
+
+    Attributes:
+        org (str): Organization name, defaults to "samba-pytorch"
+        name (str): Model name, defaults to "lit-GPT"
+        block_size (int): Maximum sequence length for the model, defaults to 4096
+        vocab_size (int): Size of the vocabulary, defaults to 50254
+        padding_multiple (int): Padding factor for vocab size optimization, defaults to 512
+        padded_vocab_size (Optional[int]): Actual padded vocabulary size after adjustment
+        n_layer (int): Number of transformer layers, defaults to 16
+        n_head (int): Number of attention heads, defaults to 32
+        n_embd (int): Embedding dimension / hidden state size, defaults to 4096
+        rotary_percentage (float): Fraction of dimensions to apply rotary embeddings to, defaults to 0.25
+        parallel_residual (bool): Whether to use parallel residual connections, defaults to True
+        bias (bool): Whether to include bias terms in linear layers, defaults to True
+
+        # SAMBA-specific parameters
+        local_window (int): Size of sliding window for attention, -1 means full attention
+        mlp (bool): Whether to include MLP layers, defaults to True
+        full_per_layer (int): Number of tokens for full attention per layer
+        mb_per_layer (int): Number of Mamba layers per block
+        ret_per_layer (int): Number of RetNet layers per block
+        gla_per_layer (int): Number of GLA (Gated Linear Attention) layers per block
+        nope (bool): If True, skip rotary position embeddings (NoPE)
+        mamba (bool): Whether to use Mamba layers, defaults to False
+        sc_attn (bool): Whether to use short convolution in attention, defaults to False
+        rms_norm (bool): Use RMSNorm instead of LayerNorm, defaults to True
+
+        # Performance optimizations
+        residual_in_fp32 (bool): Keep residual connections in fp32, defaults to True
+        fused_add_norm (bool): Use fused add+norm operations, defaults to True
+        mamba_init (bool): Use specialized Mamba initialization, defaults to False
+        attn_layer_pos (str): Position of attention layers in architecture
+        n_query_groups (Optional[int]): Number of query groups for grouped-query attention
+        shared_attention_norm (bool): Share normalization across attention heads, defaults to False
+
+        _norm_class (str): Normalization layer class to use ("LayerNorm" or "RMSNorm")
+        norm_eps (float): Epsilon for normalization layers, defaults to 1e-5
+        _mlp_class (str): MLP implementation class ("GptNeoxMLP" or "LLaMAMLP")
+        intermediate_size (Optional[int]): Size of intermediate MLP layers
+        condense_ratio (int): Ratio for condensing layers, defaults to 1
+
+    Key Implementation Details from Paper:
+        - SAMBA combines Mamba, SWA and MLP through layer-wise interleaving
+        - Default sliding window size is 2048 tokens
+        - Uses PreNorm and skip connections for each intermediate layer
+        - Mamba layers capture time-dependent semantics and provide efficient decoding
+        - SWA handles complex non-Markovian dependencies
+        - MLPs handle factual knowledge recall
+    """
+
+    org: str = "samba-pytorch"
     name: str = "lit-GPT"
     block_size: int = 4096
     vocab_size: int = 50254
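For context, a minimal sketch (not part of this commit) of how the documented fields might be combined when instantiating a config. It assumes the attributes listed in the docstring are all dataclass fields with the stated defaults; the concrete values below are purely illustrative, not a preset from the repository.

from samba_pytorch.config import Config

cfg = Config(
    name="samba-demo",      # illustrative name, not a repo preset
    block_size=4096,        # maximum sequence length
    n_layer=16,
    n_head=32,
    n_embd=1536,            # narrower than the 4096 default, for illustration
    rotary_percentage=0.25,
    local_window=2048,      # sliding window size noted in the docstring
    mb_per_layer=2,         # interleave Mamba per the layer-wise scheme (illustrative)
    mlp=True,
    rms_norm=True,
)
print(cfg.n_layer, cfg.local_window)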

samba_pytorch/modules/__init__.py

Lines changed: 0 additions & 6 deletions
@@ -1,18 +1,12 @@
 """Core model component modules."""

-from samba_pytorch.modules.fused_rotary_embedding import (
-    ApplyRotaryEmb,
-    apply_rotary_emb_func,
-)
 from samba_pytorch.modules.gla import GatedLinearAttention
 from samba_pytorch.modules.mamba_simple import Mamba
 from samba_pytorch.modules.multiscale_retention import MultiScaleRetention
 from samba_pytorch.modules.rmsnorm import RMSNorm, rms_norm
 from samba_pytorch.modules.rotary import RotaryEmbedding, apply_rotary_emb

 __all__ = [
-    "apply_rotary_emb_func",
-    "ApplyRotaryEmb",
     "GatedLinearAttention",
     "Mamba",
     "MultiScaleRetention",

samba_pytorch/modules/fused_rotary_embedding.py

Lines changed: 0 additions & 99 deletions
This file was deleted.

samba_pytorch/modules/rmsnorm.py

Lines changed: 3 additions & 2 deletions
@@ -1,7 +1,8 @@
+from typing import Optional, Tuple, Union
+
 import torch
-from torch import nn
 from einops import rearrange
-from typing import Optional, Tuple, Union
+from torch import nn


 def maybe_align(x: torch.Tensor, alignment_in_bytes: int = 16) -> torch.Tensor:

samba_pytorch/samba.py

Lines changed: 9 additions & 10 deletions
@@ -5,15 +5,16 @@
 # see LICENSE file at https://github.com/Lightning-AI/litgpt/blob/main/LICENSE

 import math
+import warnings
 from functools import partial
 from typing import Any, List, Optional, Tuple

 import torch
 import torch.nn as nn
+from rotary_embedding_torch import RotaryEmbedding
 from torch import Tensor
 from typing_extensions import Self
 from xformers.ops import SwiGLU
-from rotary_embedding_torch import RotaryEmbedding

 try:
     from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
@@ -22,12 +23,12 @@
     from causal_conv1d import causal_conv1d_fn
     from einops import rearrange

-from samba_pytorch.config import Config
-
-from samba_pytorch.modules.gla import GatedLinearAttention
-from samba_pytorch.modules.mamba_simple import Mamba
-from samba_pytorch.modules.multiscale_retention import MultiScaleRetention
+warnings.filterwarnings("ignore", category=FutureWarning, module="fla.ops")

+from samba_pytorch.config import Config  # noqa
+from samba_pytorch.modules.gla import GatedLinearAttention  # noqa
+from samba_pytorch.modules.mamba_simple import Mamba  # noqa
+from samba_pytorch.modules.multiscale_retention import MultiScaleRetention  # noqa

 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -74,7 +75,7 @@ def __init__(self, config: Config) -> None:
         self.config = config

         self.rotary_emb = RotaryEmbedding(
-            dim=int(config.rotary_percentage * config.head_size), # TODO: validate
+            dim=int(config.rotary_percentage * config.head_size),  # TODO: validate
             use_xpos=getattr(config, "use_xpos", False),
             interpolate_factor=getattr(config, "interpolate_factor", 1.0),
         )
@@ -243,7 +244,7 @@ def forward(

         # Initialize rotary embedding variables
         if self.config.nope:
-            rope = None # Set rope to None if config.nope
+            rope = None  # Set rope to None if config.nope
         else:
             # Using rotary_emb to rotate queries and keys in attention modules
             rope = self.rotary_emb
@@ -668,5 +669,3 @@ def __init__(
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.swiglu(x)
         return x
-
-
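The import reshuffle above is order-sensitive: the FutureWarning filter must be registered before the fla-backed modules (gla, multiscale_retention) are imported, otherwise their import-time warnings fire first; that is why those imports now sit below the filterwarnings call and carry # noqa markers (module-level imports after executable code would otherwise trip lint rules such as E402). A minimal sketch of the same pattern in isolation, with only the filter line taken from the diff:

import warnings

# Install the filter first so FutureWarnings raised while fla.ops is imported
# (directly, or via the samba_pytorch modules that depend on it) are suppressed.
warnings.filterwarnings("ignore", category=FutureWarning, module="fla.ops")

from samba_pytorch.modules.gla import GatedLinearAttention  # noqa: E402
from samba_pytorch.modules.multiscale_retention import MultiScaleRetention  # noqa: E402

On the RotaryEmbedding hunk: assuming head_size is derived as n_embd // n_head (4096 // 32 = 128 with the defaults documented in config.py), dim=int(config.rotary_percentage * config.head_size) works out to int(0.25 * 128) = 32 rotary dimensions per head.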

samba_pytorch/utils.py

Lines changed: 107 additions & 1 deletion
@@ -16,7 +16,7 @@
 from functools import partial
 from io import BytesIO
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, TypeVar, Union
+from typing import Any, Dict, List, Mapping, Optional, Tuple, TypeVar, Union

 import torch
 import torch.nn as nn
@@ -669,3 +669,109 @@ def get_default_supported_precision(training: bool, tpu: bool = False) -> str:
     if not torch.cuda.is_available() or torch.cuda.is_bf16_supported():
         return "bf16-mixed" if training else "bf16-true"
     return "16-mixed" if training else "16-true"
+
+
+def activate_tf32_if_available():
+    """
+    Check if the GPU supports NVIDIA Ampere or later and enable TF32 in PyTorch if it does.
+    """
+    # Check if CUDA is available
+    if not torch.cuda.is_available():
+        warnings.warn("No GPU detected, running on CPU.")
+        return
+
+    try:
+        device = torch.cuda.current_device()
+        capability = torch.cuda.get_device_capability(device)
+        major, minor = capability
+
+        # Check if the GPU is Ampere or newer
+        if major >= 8:
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+            gpu_name = torch.cuda.get_device_name(device)
+            print(
+                f"{gpu_name} (compute capability {major}.{minor}) supports NVIDIA Ampere or later, enabled TF32 in PyTorch."
+            )
+        else:
+            gpu_name = torch.cuda.get_device_name(device)
+            print(
+                f"{gpu_name} (compute capability {major}.{minor}) does not support NVIDIA Ampere or later."
+            )
+
+    except Exception as e:
+        warnings.warn(f"Error occurred while checking GPU: {e}")
+
+
+def model_summary(
+    model: nn.Module, max_depth: int = 4, show_input_size: bool = False
+) -> None:
+    """
+    Prints an accurate summary of the model, avoiding double-counting of parameters.
+
+    :param nn.Module model: torch model to summarize
+    :param int max_depth: maximum depth of the model to print, defaults to 4
+    :param bool show_input_size: whether to show input size for each layer, defaults to False
+    """
+
+    def format_params(num_params: int) -> str:
+        return f"{num_params:,}" if num_params > 0 else "--"
+
+    def format_size(size: Optional[List[int]]) -> str:
+        return "x".join(str(x) for x in size) if size else "N/A"
+
+    def count_parameters(module: nn.Module) -> Tuple[int, int]:
+        total_params = sum(p.numel() for p in module.parameters())
+        trainable_params = sum(
+            p.numel() for p in module.parameters() if p.requires_grad
+        )
+        return total_params, trainable_params
+
+    def recursive_summarize(
+        module: nn.Module, depth: int, idx: List[int], prefix: str = ""
+    ) -> List[Tuple[str, int, int, int, Optional[List[int]], nn.Module]]:
+        summary = []
+
+        total_params, trainable_params = count_parameters(module)
+
+        if depth <= max_depth:
+            layer_name = f"{prefix}{type(module).__name__}"
+            layer_index = ".".join(map(str, idx))
+            param_shape = next(
+                (p.shape for p in module.parameters(recurse=False) if p.requires_grad),
+                None,
+            )
+            summary.append(
+                (layer_name, depth, total_params, trainable_params, param_shape, module)
+            )

+        for i, (name, child) in enumerate(module.named_children(), 1):
+            child_summary = recursive_summarize(
+                child, depth + 1, idx + [i], prefix + " "
+            )
+            summary.extend(child_summary)
+
+        return summary
+
+    summary = recursive_summarize(model, 1, [1])
+
+    max_name_length = max(len(name) for name, _, _, _, _, _ in summary)
+    max_shape_length = max(len(format_size(shape)) for _, _, _, _, shape, _ in summary)
+
+    print("=" * (max_name_length + 50))
+    header = f"{'Layer (type:depth-idx)':<{max_name_length}} {'Output Shape':>{max_shape_length}} {'Param #':>12} {'Trainable':>10}"
+    print(header)
+    print("=" * (max_name_length + 50))
+
+    for name, depth, num_params, trainable_params, shape, _ in summary:
+        shape_str = format_size(shape) if show_input_size else ""
+        print(
+            f"{name:<{max_name_length}} {shape_str:>{max_shape_length}} {format_params(num_params):>12} {str(trainable_params > 0):>10}"
+        )
+
+    total_params, trainable_params = count_parameters(model)
+    print("=" * (max_name_length + 50))
+    print(f"Total params: {format_params(total_params)}")
+    print(f"Trainable params: {format_params(trainable_params)}")
+    print(f"Non-trainable params: {format_params(total_params - trainable_params)}")
+    print("=" * (max_name_length + 50))

0 commit comments
