@@ -20,10 +20,9 @@ def __init__(self, mu_init: Tensor, log_sigma_init: Tensor) -> None:
         # check that the input shapes are the same
         assert len(mu_init) == len(log_sigma_init), "init shapes must be equal."
         assert mu_init.shape == log_sigma_init.shape, "init shapes must be equal."
-        self.d = len(mu_init)
         self.total_weights = len(mu_init)
         # saving parameters
-        self.sparse_index = torch.ones(self.d).bool()
+        self.register_buffer("sparse_index", torch.ones(len(mu_init)).bool())
 
         self.mu = mu_init
         self.log_sigma = log_sigma_init
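A note on the `register_buffer` change above: a plain tensor attribute is invisible to the module machinery, while a buffer is saved in `state_dict()` and moved by `.to(device)`. A minimal sketch of that behavior (not from the diff):

```python
import torch
from torch import nn

# Minimal sketch: a registered buffer, unlike a plain tensor attribute,
# is included in state_dict() and follows .to(device).
class WithBuffer(nn.Module):
    def __init__(self, d: int) -> None:
        super().__init__()
        self.register_buffer("sparse_index", torch.ones(d).bool())

m = WithBuffer(4)
print("sparse_index" in m.state_dict())  # True: the mask is checkpointed
```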
@@ -44,6 +43,10 @@ def log_sigma(self) -> Tensor:
     def log_sigma(self, value: Tensor):
         self.__log_sigma = Parameter(value)
 
+    @property
+    def d(self) -> int:
+        return int(self.sparse_index.sum())
+
     def var(self) -> Tensor:
         return torch.exp(self.log_sigma).pow(2)
 
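The new `d` property derives the active dimension from the mask instead of caching it, so it can never go stale after pruning. A small illustration of the invariant:

```python
import torch

# The invariant the new property enforces: d is the number of True entries
# in sparse_index, so pruning can never leave a stale cached dimension.
mask = torch.ones(10, dtype=torch.bool)
print(int(mask.sum()))  # 10: nothing pruned yet
mask[3:7] = False       # prune four weights
print(int(mask.sum()))  # 6: d follows the mask with no manual bookkeeping
```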
@@ -52,7 +55,6 @@ def update_sparse_index(self, sparse_index: Tensor) -> None:
             len(sparse_index) == self.total_weights
         ), "Sparse index should be a bool array masking unimportant weights."
         self.sparse_index = sparse_index
-        self.d = int(self.sparse_index.sum())
 
     def forward(self, n: int) -> Tensor:
         """
@@ -65,7 +67,7 @@ def forward(self, n: int) -> Tensor:
             Tensor: (n,d) reparameterized samples from variational distribution
         """
         sigma = torch.exp(self.log_sigma)
-        return self.mu + torch.randn(n, self.d) * sigma
+        return self.mu + torch.randn(n, len(sigma), device=sigma.device) * sigma
 
 
 class LogNormalMeanFieldVariational(NormalMeanFieldVariational):
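For context, `forward` is the standard Gaussian reparameterization trick: fixed `N(0, I)` noise is scaled and shifted so gradients reach `mu` and `log_sigma`. A minimal sketch:

```python
import torch

# Reparameterization trick as used in forward(): fixed N(0, I) noise is
# scaled and shifted, so the samples are differentiable in mu and log_sigma.
mu = torch.zeros(3, requires_grad=True)
log_sigma = torch.zeros(3, requires_grad=True)
eps = torch.randn(5, 3)                    # (n, d) standard normal noise
samples = mu + eps * torch.exp(log_sigma)  # (n, d) reparameterized samples
samples.sum().backward()
print(mu.grad.shape, log_sigma.grad.shape)  # gradients reach both parameters
```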
@@ -95,9 +97,9 @@ def forward(self, n: int) -> Tensor:
 
 class SVIHalfCauchyPrior(nn.Module):
     """
-    Class for performing sparse Bayesian learning using stochastic variational inference. This class provides 
-    utilities for generating reparameterized samples from the variational distribution and computing the 
-    KL-divergence between the variational distribution and the prior exactly 
+    Class for performing sparse Bayesian learning using stochastic variational inference. This class provides
+    utilities for generating reparameterized samples from the variational distribution and computing the
+    KL-divergence between the variational distribution and the prior exactly.
 
     Args:
         d (int): number of parameters
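Background on the prior named in this docstring, as a standalone sketch using `torch.distributions.HalfCauchy` (nothing here is from the diff itself): a half-Cauchy scale prior has median `tau` and very heavy tails, which is what lets most weight scales collapse toward zero while a few stay large.

```python
import torch
from torch.distributions import HalfCauchy

# Half-Cauchy scale prior in isolation: median equal to tau, heavy tails.
# Most draws are modest, but a non-negligible fraction is huge, which is
# the behavior that induces sparsity in the weight scales.
tau = 2.0
z = HalfCauchy(scale=tau).sample((100_000,))
print(z.median())                     # close to tau = 2.0
print((z > 10 * tau).float().mean())  # roughly 6% exceed 10 * tau
```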
@@ -116,26 +118,52 @@ class SVIHalfCauchyPrior(nn.Module):
 
     def __init__(self, d: int, tau: Union[Tensor, float], w_init: Tensor = None):
         super().__init__()
+        # fixing gamma parameterization mixup
+        tau = 1 / math.sqrt(tau)
         if isinstance(tau, float):
             tau = torch.tensor(tau)
         self.register_buffer("tau", tau)
+        tau_data = torch.tensor(
+            [1.0, 1 / 1e-1, 1 / 1e-2, 1 / 1e-3, 1 / 1e-4, 1 / 1e-5]
+        ).log()
+        tau_data = torch.stack([torch.ones(6), tau_data], dim=-1)
+        mu_data = torch.tensor(
+            [-1.6932, -6.2983, -10.9035, -15.5087, -20.1138, -24.7190]
+        )
+        # scale_data = (tau_data.log() @ mu_data) / (tau_data.log() @ tau_data.log())
+        scale = torch.linalg.solve(tau_data.t() @ tau_data, (tau_data.t() @ mu_data))
+        if w_init is not None:
+            assert len(w_init) == d, "w_init must be a vector of length d."
+            global_scale_init = -1.6931 * torch.ones(1)
+            noise = -5.0
+        else:
+            # initializing such that kl-divergence is minimized
+            global_scale_init = scale[0] + scale[1] * tau.log() * torch.ones(1)
+            noise = 0.3466
+        # global_scale_init = -1.6931 * torch.ones(1)
+        # data fit for tau
         self.s_a = LogNormalMeanFieldVariational(
-            torch.zeros(1), -6.0 + torch.randn(1) * 1e-4
+            global_scale_init, noise + torch.randn(1) * 1e-4
         )
         self.s_b = LogNormalMeanFieldVariational(
-            torch.zeros(1), -6.0 + torch.randn(1) * 1e-4
+            1.6931 * torch.ones(1), noise + torch.randn(1) * 1e-4
         )
         self.gamma_a = LogNormalMeanFieldVariational(
-            torch.zeros(d), -6.0 + torch.randn(d) * 1e-4
+            -1.6931 * torch.ones(d), noise + torch.randn(d) * 1e-4
         )
         self.gamma_b = LogNormalMeanFieldVariational(
-            torch.zeros(d), -6.0 + torch.randn(d) * 1e-4
+            +1.6931 * torch.ones(d), noise + torch.randn(d) * 1e-4
         )
         if w_init is None:
-            w_init = torch.randn(d)
-        self.w_tilde = NormalMeanFieldVariational(w_init, -6.0 + torch.randn(d) * 1e-4)
-        self.register_buffer("sparse_index", torch.arange(d))
-        self.pruning_tol = 0.0
+            w_init = torch.zeros(d)
+            w_tilde_noise = -0.0
+        else:
+            w_tilde_noise = -6.0
+        self.w_tilde = NormalMeanFieldVariational(
+            w_init, w_tilde_noise + torch.randn(d) * 1e-6
+        )
+        self.register_buffer("sparse_index", torch.ones(d, dtype=torch.bool))
+        self.register_buffer("pruning_tol", torch.tensor(0.0))
 
     def _log_normal_reparam(
         self,
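The `tau_data`/`mu_data` block above is a least-squares fit: it regresses tabulated log-scale initializations on `log(tau)` with an intercept via the normal equations, then evaluates the fitted line at the given `tau`. A standalone sketch of the same computation (numbers copied from the diff and taken as given):

```python
import torch

# Regress the tabulated optimal log-scales (y) on log(tau) with an
# intercept by solving the normal equations X^T X s = X^T y.
taus = torch.tensor([1.0, 1 / 1e-1, 1 / 1e-2, 1 / 1e-3, 1 / 1e-4, 1 / 1e-5])
X = torch.stack([torch.ones(6), taus.log()], dim=-1)  # design matrix [1, log tau]
y = torch.tensor([-1.6932, -6.2983, -10.9035, -15.5087, -20.1138, -24.7190])
s = torch.linalg.solve(X.t() @ X, X.t() @ y)          # [intercept, slope]
tau = torch.tensor(0.1)
print(s[0] + s[1] * tau.log())                        # fitted init for this tau
```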
@@ -158,15 +186,15 @@ def _log_normal_reparam(
             log_sigma_b (Tensor): log stdev of r.v. b
 
         Returns:
-            Tensor: (n,d) reparmeterized samples 
+            Tensor: (n,d) reparameterized samples
         """
         mu = 0.5 * (mu_a + mu_b)
         var = 0.25 * (torch.exp(log_sigma_a).pow(2) + torch.exp(log_sigma_b).pow(2))
-        return torch.exp(mu + torch.randn(n, d) * var.sqrt())
+        return torch.exp(mu + torch.randn(n, d, device=mu.device) * var.sqrt())
 
     def get_reparam_weights(self, n: int) -> Tensor:
         """
-        Generate reparameterized samples 
+        Generate reparameterized samples
 
         Args:
             n (int): number of reparam samples
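Why the `0.5`/`0.25` factors in `_log_normal_reparam` are exact: for independent log-normals `a` and `b`, `sqrt(a * b)` is again log-normal with averaged mean and quarter-summed variance, so the method reparameterizes the square root of the product directly. A numerical check:

```python
import torch

# For independent a ~ LogNormal(mu_a, s_a^2) and b ~ LogNormal(mu_b, s_b^2),
# sqrt(a * b) ~ LogNormal(0.5 * (mu_a + mu_b), 0.25 * (s_a^2 + s_b^2)).
mu_a, mu_b, s_a, s_b = 0.0, 1.0, 0.5, 0.3
n = 100_000
a = torch.exp(mu_a + s_a * torch.randn(n))
b = torch.exp(mu_b + s_b * torch.randn(n))
direct = (a * b).sqrt()
mu = 0.5 * (mu_a + mu_b)
var = 0.25 * (s_a**2 + s_b**2)
reparam = torch.exp(mu + var**0.5 * torch.randn(n))
print(direct.log().mean(), reparam.log().mean())  # both near 0.5
print(direct.log().var(), reparam.log().var())    # both near 0.085
```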
@@ -228,7 +256,7 @@ def kl_divergence(self) -> Tensor:
         Computes the KL divergence for the approximating posteriors
 
         Returns:
-            Tensor: kl divergence 
+            Tensor: kl divergence
         """
         kl_sa = self._kl_s_a()
         kl_sb = self._kl_s_b()
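The `_kl_*` helpers themselves are outside these hunks, but the "exactly" in the class docstring refers to closed-form KL terms for the mean-field factors. As a hedged illustration against a standard normal prior (not necessarily the prior used by this class):

```python
import torch
from torch.distributions import Normal, kl_divergence

# Closed-form KL for a Gaussian factor:
# KL(N(m, s^2) || N(0, 1)) = -log(s) + (s^2 + m^2) / 2 - 1/2,
# so no Monte Carlo estimate is needed for this part of the ELBO.
m, log_s = torch.tensor(0.3), torch.tensor(-1.0)
s = log_s.exp()
closed_form = -log_s + 0.5 * (s**2 + m**2) - 0.5
print(closed_form)
print(kl_divergence(Normal(m, s), Normal(0.0, 1.0)))  # matches exactly
```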
@@ -240,7 +268,7 @@ def kl_divergence(self) -> Tensor:
     def _compute_sparsity_tolerance(self, negative_log_mode: Tensor) -> Tensor:
         """
         Provides a reasonable pruning tolerance using the mid range of the
-        negative log modes 
+        negative log modes
 
         Args:
             negative_log_mode (Tensor): negative log mode of the weight est
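The body of `_compute_sparsity_tolerance` is not shown in these hunks; assuming "mid range" means halfway between the extreme scores, a sketch would be:

```python
import torch

# Assumed reading of "mid range": halfway between the smallest and largest
# negative log-modes. Scores at or below the tolerance are kept (see the
# mask update in update_sparse_index below).
negative_log_mode = torch.tensor([0.1, 0.4, 3.0, 8.0, 7.5])
tol = 0.5 * (negative_log_mode.min() + negative_log_mode.max())
print(tol)                       # 4.05 for this toy input
print(negative_log_mode <= tol)  # keeps the first three weights
```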
@@ -252,9 +280,9 @@ def _compute_sparsity_tolerance(self, negative_log_mode: Tensor) -> Tensor:
 
     def update_sparse_index(self) -> Tensor:
         """
-        Updates the sparse_index by pruning based on the negative log-mode 
+        Updates the sparse_index by pruning based on the negative log-mode
 
-        Returns: 
-            Tensor: negative log mode for each parameter
+        Returns:
+            Tensor: pruning margin (tolerance minus negative log-mode) per parameter
         """
         mu_zi = 0.5 * (self.s_a.mu + self.s_b.mu + self.gamma_a.mu + self.gamma_b.mu)
@@ -263,9 +291,19 @@ def update_sparse_index(self) -> Tensor:
         )
         negative_log_mode = var_zi - mu_zi
         self.pruning_tol = self._compute_sparsity_tolerance(negative_log_mode)
-        self.sparse_index = negative_log_mode <= self.pruning_tol
+        self.sparse_index = (negative_log_mode <= self.pruning_tol).cpu()
+        self._propogate_sparse_index(self.sparse_index)
+        return -negative_log_mode + self.pruning_tol
+
+    def reset_sparse_index(self) -> None:
+        """
+        Resets the sparse_index so that no weights are pruned
+
+        Returns:
+            None
+        """
+        self.sparse_index = torch.ones(len(self.sparse_index), dtype=torch.bool)
         self._propogate_sparse_index(self.sparse_index)
-        return negative_log_mode
 
     def _propogate_sparse_index(self, sparse_index) -> None:
         """
@@ -280,4 +318,3 @@ def _propogate_sparse_index(self, sparse_index) -> None:
     print(2.0)
     svi = SVIHalfCauchyPrior(10, torch.tensor(1.0))
     print(svi.get_reparam_weights(20).shape)
-
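Putting the new methods together, a hypothetical end-to-end loop (assuming the class from this diff is importable; the optimizer steps between these calls are omitted):

```python
import torch

# Hypothetical usage of the new API surface from this diff.
svi = SVIHalfCauchyPrior(10, torch.tensor(1.0))
w = svi.get_reparam_weights(20)     # (20, 10) weight samples for the ELBO
kl = svi.kl_divergence()            # closed-form KL term of the ELBO
margin = svi.update_sparse_index()  # prune; non-negative margin means kept
svi.reset_sparse_index()            # undo pruning, e.g. to resume training
print(w.shape, margin)
```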