11"""Defines the Higher-Order Neural Units (HONU) model."""
22
3- from __future__ import annotations
4-
53import math
64from itertools import combinations_with_replacement
75from typing import Any
86
97import torch
108from torch import Tensor , nn
119
12- __version__ = "0.0.1 "
10+ __version__ = "0.0.2 "
1311
1412
1513class HONU (nn .Module ):
@@ -55,6 +53,10 @@ def __init__(
5553 order (int): Polynomial order of the model.
5654 in_features (int): Number of input features.
5755 _weight_divisor (float): Divisor used to scale the randomly initialized weights.
56+ _weight_init_mode (str): Method for initializing weights, can be "random", "zeros",
57+ "ones", "xavier", "kaiming_normal", or "kaiming_uniform".
58+ _activation (str): Activation function to be used.
59+ _activation_function (callable): The actual activation function to apply.
5860 _bias (bool): Indicates whether a bias term is included in the model.
5961 weight (nn.Parameter): Trainable weights of the model.
6062 _num_combinations (int): Number of polynomial feature combinations.
@@ -71,6 +73,7 @@ def __init__(
7173 msg = f"weight_divisor must be a number or string, got { type (weight_divisor )} "
7274 raise TypeError (msg )
7375 self ._weight_divisor = float (weight_divisor )
76+ self ._weight_init_mode = kwargs .get ("weight_init_mode" , "random" )
7477 self ._bias = kwargs .get ("bias" , True )
7578 self ._activation = activation
7679 if self ._activation in ["identity" , "linear" ]:
@@ -125,6 +128,15 @@ def _initialize_weights(self) -> Tensor:
125128 - n is the number of states, calculated as the input length + 1 if a bias is included.
126129 - r is the polynomial order of the neuron.
127130
131+ The weights are initialized using various methods based on the `_weight_init_mode`:
132+
133+ - "random": Uniformly distributed random values scaled by `_weight_divisor`.
134+ - "zeros": All weights initialized to zero.
135+ - "ones": All weights initialized to one.
136+ - "xavier": Xavier initialization for 1D tensors.
137+ - "kaiming_normal": Kaiming normal initialization for 1D tensors.
138+ - "kaiming_uniform": Kaiming uniform initialization for 1D tensors.
139+
128140 Returns:
129141 Array of initialized weights.
130142 """
@@ -134,8 +146,33 @@ def _initialize_weights(self) -> Tensor:
134146 math .factorial (n_weights + self .order - 1 )
135147 / (math .factorial (self .order ) * math .factorial (n_weights - 1 ))
136148 )
137- # Initialize weights randomly and scale them
138- return torch .rand (num_weights ) / self ._weight_divisor
149+ # Initialize weights using PyTorch native methods
150+ weights = torch .empty (num_weights )
151+
152+ if self ._weight_init_mode == "random" :
153+ torch .nn .init .uniform_ (weights , 0 , 1 )
154+ weights /= self ._weight_divisor
155+ elif self ._weight_init_mode == "zeros" :
156+ torch .nn .init .zeros_ (weights )
157+ elif self ._weight_init_mode == "ones" :
158+ torch .nn .init .ones_ (weights )
159+ elif self ._weight_init_mode == "xavier" :
160+ # For 1D tensors, manually compute Xavier initialization
161+ limit = math .sqrt (6 / (self .in_features + num_weights ))
162+ torch .nn .init .uniform_ (weights , - limit , limit )
163+ elif self ._weight_init_mode == "kaiming_normal" :
164+ # For 1D tensors, manually compute Kaiming normal initialization
165+ std = math .sqrt (2 / self .in_features )
166+ torch .nn .init .normal_ (weights , mean = 0 , std = std )
167+ elif self ._weight_init_mode == "kaiming_uniform" :
168+ # For 1D tensors, manually compute Kaiming uniform initialization
169+ limit = math .sqrt (6 / self .in_features )
170+ torch .nn .init .uniform_ (weights , - limit , limit )
171+ else :
172+ msg = f"Unknown weight initialization mode: { self ._weight_init_mode } "
173+ raise ValueError (msg )
174+
175+ return weights
139176
140177 def _get_combinations (self ) -> Tensor :
141178 """Precompute and return all index combinations for the given input length and order.
@@ -196,6 +233,10 @@ def forward(self, x: Tensor) -> Tensor:
196233 Returns:
197234 Tensor[B, 1]: Output tensor from the model.
198235 """
236+ # Check if input shape matches the expected input length
237+ if x .size (1 ) != self .in_features :
238+ msg = f"Input shape mismatch: expected { self .in_features } features, got { x .size (1 )} ."
239+ raise ValueError (msg )
199240 # Get the polynomial feature map
200241 colx = self ._get_colx (x )
201242
0 commit comments