Skip to content

Commit d904ef8

Browse files
committed
feat: honu weight init options update. Honu dim check in forward call. torch dependency in pyproject.toml
1 parent 02be34b commit d904ef8

File tree

3 files changed

+48
-7
lines changed

3 files changed

+48
-7
lines changed

ghonn_models_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""GHONN - A Python package for Gated Higher Order Neural Networks."""
22

3-
__version__ = "0.1.1"
3+
__version__ = "0.1.2"
44

55
from .core import GHONN, GHONU, HONN, HONU
66
from .datasets import load_example_dataset

ghonn_models_pytorch/core/honu.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
"""Defines the Higher-Order Neural Units (HONU) model."""
22

3-
from __future__ import annotations
4-
53
import math
64
from itertools import combinations_with_replacement
75
from typing import Any
86

97
import torch
108
from torch import Tensor, nn
119

12-
__version__ = "0.0.1"
10+
__version__ = "0.0.2"
1311

1412

1513
class HONU(nn.Module):
@@ -55,6 +53,10 @@ def __init__(
5553
order (int): Polynomial order of the model.
5654
in_features (int): Number of input features.
5755
_weight_divisor (float): Divisor used to scale the randomly initialized weights.
56+
_weight_init_mode (str): Method for initializing weights, can be "random", "zeros",
57+
"ones", "xavier", "kaiming_normal", or "kaiming_uniform".
58+
_activation (str): Activation function to be used.
59+
_activation_function (callable): The actual activation function to apply.
5860
_bias (bool): Indicates whether a bias term is included in the model.
5961
weight (nn.Parameter): Trainable weights of the model.
6062
_num_combinations (int): Number of polynomial feature combinations.
@@ -71,6 +73,7 @@ def __init__(
7173
msg = f"weight_divisor must be a number or string, got {type(weight_divisor)}"
7274
raise TypeError(msg)
7375
self._weight_divisor = float(weight_divisor)
76+
self._weight_init_mode = kwargs.get("weight_init_mode", "random")
7477
self._bias = kwargs.get("bias", True)
7578
self._activation = activation
7679
if self._activation in ["identity", "linear"]:
@@ -125,6 +128,15 @@ def _initialize_weights(self) -> Tensor:
125128
- n is the number of states, calculated as the input length + 1 if a bias is included.
126129
- r is the polynomial order of the neuron.
127130
131+
The weights are initialized using various methods based on the `_weight_init_mode`:
132+
133+
- "random": Uniformly distributed random values scaled by `_weight_divisor`.
134+
- "zeros": All weights initialized to zero.
135+
- "ones": All weights initialized to one.
136+
- "xavier": Xavier initialization for 1D tensors.
137+
- "kaiming_normal": Kaiming normal initialization for 1D tensors.
138+
- "kaiming_uniform": Kaiming uniform initialization for 1D tensors.
139+
128140
Returns:
129141
Array of initialized weights.
130142
"""
@@ -134,8 +146,33 @@ def _initialize_weights(self) -> Tensor:
134146
math.factorial(n_weights + self.order - 1)
135147
/ (math.factorial(self.order) * math.factorial(n_weights - 1))
136148
)
137-
# Initialize weights randomly and scale them
138-
return torch.rand(num_weights) / self._weight_divisor
149+
# Initialize weights using PyTorch native methods
150+
weights = torch.empty(num_weights)
151+
152+
if self._weight_init_mode == "random":
153+
torch.nn.init.uniform_(weights, 0, 1)
154+
weights /= self._weight_divisor
155+
elif self._weight_init_mode == "zeros":
156+
torch.nn.init.zeros_(weights)
157+
elif self._weight_init_mode == "ones":
158+
torch.nn.init.ones_(weights)
159+
elif self._weight_init_mode == "xavier":
160+
# For 1D tensors, manually compute Xavier initialization
161+
limit = math.sqrt(6 / (self.in_features + num_weights))
162+
torch.nn.init.uniform_(weights, -limit, limit)
163+
elif self._weight_init_mode == "kaiming_normal":
164+
# For 1D tensors, manually compute Kaiming normal initialization
165+
std = math.sqrt(2 / self.in_features)
166+
torch.nn.init.normal_(weights, mean=0, std=std)
167+
elif self._weight_init_mode == "kaiming_uniform":
168+
# For 1D tensors, manually compute Kaiming uniform initialization
169+
limit = math.sqrt(6 / self.in_features)
170+
torch.nn.init.uniform_(weights, -limit, limit)
171+
else:
172+
msg = f"Unknown weight initialization mode: {self._weight_init_mode}"
173+
raise ValueError(msg)
174+
175+
return weights
139176

140177
def _get_combinations(self) -> Tensor:
141178
"""Precompute and return all index combinations for the given input length and order.
@@ -196,6 +233,10 @@ def forward(self, x: Tensor) -> Tensor:
196233
Returns:
197234
Tensor[B, 1]: Output tensor from the model.
198235
"""
236+
# Check if input shape matches the expected input length
237+
if x.size(1) != self.in_features:
238+
msg = f"Input shape mismatch: expected {self.in_features} features, got {x.size(1)}."
239+
raise ValueError(msg)
199240
# Get the polynomial feature map
200241
colx = self._get_colx(x)
201242

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ classifiers = [
2121
"Operating System :: OS Independent",
2222
]
2323
dependencies = [
24-
"torch~=2.7.0",
24+
"torch~=2.7.1",
2525
"numpy~=2.2.0",
2626
"pandas~=2.2.0",
2727
]

0 commit comments

Comments
 (0)