
Commit 8dcaa44

adding initialization fixes
1 parent 876caa1 commit 8dcaa44

File tree

1 file changed, +62 -25 lines changed


variationalsparsebayes/svi_half_cauchy.py

Lines changed: 62 additions & 25 deletions
@@ -20,10 +20,9 @@ def __init__(self, mu_init: Tensor, log_sigma_init: Tensor) -> None:
         # check that the input shapes are the same
         assert len(mu_init) == len(log_sigma_init), "init shapes must be equal."
         assert mu_init.shape == log_sigma_init.shape, "init shapes must be equal."
-        self.d = len(mu_init)
         self.total_weights = len(mu_init)
         # saving parameters
-        self.sparse_index = torch.ones(self.d).bool()
+        self.register_buffer("sparse_index", torch.ones(len(mu_init)).bool())
 
         self.mu = mu_init
         self.log_sigma = log_sigma_init
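Note on the change above: registering sparse_index as a buffer, instead of assigning a plain tensor attribute, lets the mask follow the module through .to(device) and be saved and restored with the state_dict. A minimal standalone sketch of that behaviour (the toy module below is illustrative, not part of the library):

import torch
import torch.nn as nn

class MaskedToy(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(d))
        # a buffer moves with .to(...) and is included in the state_dict
        self.register_buffer("sparse_index", torch.ones(d, dtype=torch.bool))

m = MaskedToy(4)
print("sparse_index" in m.state_dict())  # True, because it is a registered buffer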
@@ -44,6 +43,10 @@ def log_sigma(self) -> Tensor:
     def log_sigma(self, value: Tensor):
         self.__log_sigma = Parameter(value)
 
+    @property
+    def d(self) -> int:
+        return int(self.sparse_index.sum())
+
     def var(self) -> Tensor:
         return torch.exp(self.log_sigma).pow(2)
 
@@ -52,7 +55,6 @@ def update_sparse_index(self, sparse_index: Tensor) -> None:
             len(sparse_index) == self.total_weights
         ), "Sparse index should be a bool array masking unimportant weights."
         self.sparse_index = sparse_index
-        self.d = int(self.sparse_index.sum())
 
     def forward(self, n: int) -> Tensor:
         """
@@ -65,7 +67,7 @@ def forward(self, n: int) -> Tensor:
             Tensor: (n,d) reparameterized samples from variational distribution
         """
         sigma = torch.exp(self.log_sigma)
-        return self.mu + torch.randn(n, self.d) * sigma
+        return self.mu + torch.randn(n, len(sigma), device=sigma.device) * sigma
 
 
 class LogNormalMeanFieldVariational(NormalMeanFieldVariational):
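The forward fix above sizes the noise from sigma and allocates it on sigma's device, so sampling keeps working after the module is moved to GPU. A small sketch of the same reparameterization step with placeholder values (not library code):

import torch

mu = torch.zeros(5)                        # stands in for self.mu
sigma = torch.exp(torch.full((5,), -1.0))  # stands in for exp(self.log_sigma)

# draw the noise on the same device as the parameters to avoid CPU/GPU mismatches
eps = torch.randn(3, len(sigma), device=sigma.device)
samples = mu + eps * sigma                 # (3, 5) reparameterized draws
print(samples.shape)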
@@ -95,9 +97,9 @@ def forward(self, n: int) -> Tensor:
 
 class SVIHalfCauchyPrior(nn.Module):
     """
-    Class for performing sparse Bayesian learning using stochastic variational inference. This class provides
-    utilities for generating reparameterized samples from the variational distribution and computing the
-    KL-divergence between the variational distribution and the prior exactly
+    Class for performing sparse Bayesian learning using stochastic variational inference. This class provides
+    utilities for generating reparameterized samples from the variational distribution and computing the
+    KL-divergence between the variational distribution and the prior exactly
 
     Args:
         d (int): number of parameters
@@ -116,26 +118,52 @@ class SVIHalfCauchyPrior(nn.Module):
 
     def __init__(self, d: int, tau: Union[Tensor, float], w_init: Tensor = None):
         super().__init__()
+        # fixing gamma parameterization mixup
+        tau = 1 / math.sqrt(tau)
         if isinstance(tau, float):
             tau = torch.tensor(tau)
         self.register_buffer("tau", tau)
+        tau_data = torch.tensor(
+            [1.0, 1 / 1e-1, 1 / 1e-2, 1 / 1e-3, 1 / 1e-4, 1 / 1e-5]
+        ).log()
+        tau_data = torch.stack([torch.ones(6), tau_data], dim=-1)
+        mu_data = torch.tensor(
+            [-1.6932, -6.2983, -10.9035, -15.5087, -20.1138, -24.7190]
+        )
+        # scale_data = (tau_data.log() @ mu_data) / (tau_data.log() @ tau_data.log())
+        scale = torch.linalg.solve(tau_data.t() @ tau_data, (tau_data.t() @ mu_data))
+        if w_init is not None:
+            assert len(w_init) == d, "w_init must be a vector of length d."
+            global_scale_init = -1.6931 * torch.ones(1)
+            noise = -5.0
+        else:
+            # initializing such that kl-divergence is minimized
+            global_scale_init = scale[0] + scale[1] * tau.log() * torch.ones(1)
+            noise = 0.3466
+        # global_scale_init = -1.6931 * torch.ones(1)
+        # data fit for tau
         self.s_a = LogNormalMeanFieldVariational(
-            torch.zeros(1), -6.0 + torch.randn(1) * 1e-4
+            global_scale_init, noise + torch.randn(1) * 1e-4
         )
         self.s_b = LogNormalMeanFieldVariational(
-            torch.zeros(1), -6.0 + torch.randn(1) * 1e-4
+            1.6931 * torch.ones(1), noise + torch.randn(1) * 1e-4
         )
         self.gamma_a = LogNormalMeanFieldVariational(
-            torch.zeros(d), -6.0 + torch.randn(d) * 1e-4
+            -1.6931 * torch.ones(d), noise + torch.randn(d) * 1e-4
         )
         self.gamma_b = LogNormalMeanFieldVariational(
-            torch.zeros(d), -6.0 + torch.randn(d) * 1e-4
+            +1.6931 * torch.ones(d), noise + torch.randn(d) * 1e-4
         )
         if w_init is None:
-            w_init = torch.randn(d)
-        self.w_tilde = NormalMeanFieldVariational(w_init, -6.0 + torch.randn(d) * 1e-4)
-        self.register_buffer("sparse_index", torch.arange(d))
-        self.pruning_tol = 0.0
+            w_init = torch.zeros(d)
+            w_tilde_noise = -0.0
+        else:
+            w_tilde_noise = -6.0
+        self.w_tilde = NormalMeanFieldVariational(
+            w_init, w_tilde_noise + torch.randn(d) * 1e-6
+        )
+        self.register_buffer("sparse_index", torch.ones(d, dtype=torch.bool))
+        self.register_buffer("pruning_tol", torch.tensor(0.0))
 
     def _log_normal_reparam(
         self,
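As I read the new initialization code above, the global-scale mean is set by fitting a line (intercept and slope against log(1/tau)) through precomputed KL-minimizing values, solved with the normal equations via torch.linalg.solve. A standalone sketch of that fit, reusing the constants from the diff:

import torch

tau_vals = torch.tensor([1.0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5])
# design matrix [1, log(1 / tau)] and the tabulated KL-minimizing log-scale inits
X = torch.stack([torch.ones(6), (1.0 / tau_vals).log()], dim=-1)
y = torch.tensor([-1.6932, -6.2983, -10.9035, -15.5087, -20.1138, -24.7190])

# ordinary least squares via the normal equations: (X^T X) beta = X^T y
beta = torch.linalg.solve(X.t() @ X, X.t() @ y)
print(beta)  # [intercept, slope]; the diff then uses scale[0] + scale[1] * tau.log()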
@@ -158,15 +186,15 @@ def _log_normal_reparam(
             log_sigma_b (Tensor): log stdev of r.v. b
 
         Returns:
-            Tensor: (n,d) reparmeterized samples
+            Tensor: (n,d) reparmeterized samples
         """
         mu = 0.5 * (mu_a + mu_b)
         var = 0.25 * (torch.exp(log_sigma_a).pow(2) + torch.exp(log_sigma_b).pow(2))
-        return torch.exp(mu + torch.randn(n, d) * var.sqrt())
+        return torch.exp(mu + torch.randn(n, d, device=mu.device) * var.sqrt())
 
     def get_reparam_weights(self, n: int) -> Tensor:
         """
-        Generate reparameterized samples
+        Generate reparameterized samples
 
         Args:
             n (int): number of reparam samples
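The helper above draws samples of sqrt(a * b) for two independent log-normals a and b: log sqrt(a * b) = 0.5 * (log a + log b), which is Gaussian with mean 0.5 * (mu_a + mu_b) and variance 0.25 * (sigma_a**2 + sigma_b**2), exactly the mu and var computed in the hunk. A quick Monte Carlo check of that identity with toy values (not library code):

import torch

torch.manual_seed(0)
mu_a, mu_b, sig_a, sig_b = 0.3, -0.7, 0.5, 0.2

a = torch.exp(mu_a + sig_a * torch.randn(1_000_000))
b = torch.exp(mu_b + sig_b * torch.randn(1_000_000))
log_sqrt_ab = 0.5 * (a.log() + b.log())

print(log_sqrt_ab.mean())  # ~ 0.5 * (mu_a + mu_b) = -0.2
print(log_sqrt_ab.var())   # ~ 0.25 * (sig_a**2 + sig_b**2) = 0.0725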
@@ -228,7 +256,7 @@ def kl_divergence(self) -> Tensor:
         Computes the KL divergence for the approximating posteriors
 
         Returns:
-            Tensor: kl divergence
+            Tensor: kl divergence
         """
         kl_sa = self._kl_s_a()
         kl_sb = self._kl_s_b()
@@ -240,7 +268,7 @@ def kl_divergence(self) -> Tensor:
     def _compute_sparsity_tolerance(self, negative_log_mode: Tensor) -> Tensor:
         """
         Provides a reasonable pruning tolerance using the mid range of the
-        negative log modes
+        negative log modes
 
         Args:
             negative_log_mode (Tensor): negative log mode of the weight est
@@ -252,9 +280,9 @@ def _compute_sparsity_tolerance(self, negative_log_mode: Tensor) -> Tensor:
 
     def update_sparse_index(self) -> Tensor:
         """
-        Updates the sparse_index by pruning based on the negative log-mode
+        Updates the sparse_index by pruning based on the negative log-mode
 
-        Returns:
+        Returns:
             Tensor: negative log mode for each parameter
         """
         mu_zi = 0.5 * (self.s_a.mu + self.s_b.mu + self.gamma_a.mu + self.gamma_b.mu)
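The pruning score used here relies on the mode of a log-normal: if log z ~ N(mu, var), the mode of z is exp(mu - var), so the negative log-mode is var - mu (the negative_log_mode = var_zi - mu_zi computed in the next hunk); the larger it is, the more that weight's scale posterior concentrates near zero and the more likely the weight is pruned. A tiny numeric illustration with made-up values:

import torch

mu, var = torch.tensor(-1.0), torch.tensor(0.25)

mode = torch.exp(mu - var)          # analytic mode of the log-normal
negative_log_mode = var - mu        # the quantity thresholded against pruning_tol
print(-mode.log(), negative_log_mode)  # both equal 1.25

pruning_tol = torch.tensor(1.0)
keep = negative_log_mode <= pruning_tol
print(keep)  # False: this weight would be dropped from the sparse index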
@@ -263,9 +291,19 @@ def update_sparse_index(self) -> Tensor:
         )
         negative_log_mode = var_zi - mu_zi
         self.pruning_tol = self._compute_sparsity_tolerance(negative_log_mode)
-        self.sparse_index = negative_log_mode <= self.pruning_tol
+        self.sparse_index = (negative_log_mode <= self.pruning_tol).cpu()
+        self._propogate_sparse_index(self.sparse_index)
+        return -negative_log_mode + self.pruning_tol
+
+    def reset_sparse_index(self) -> None:
+        """
+        Resets the sparse_index so that all weights are marked active again
+
+        Returns:
+            None
+        """
+        self.sparse_index = torch.ones(len(self.sparse_index), dtype=torch.bool)
         self._propogate_sparse_index(self.sparse_index)
-        return negative_log_mode
 
     def _propogate_sparse_index(self, sparse_index) -> None:
         """
@@ -280,4 +318,3 @@ def _propogate_sparse_index(self, sparse_index) -> None:
 print(2.0)
 svi = SVIHalfCauchyPrior(10, torch.tensor(1.0))
 print(svi.get_reparam_weights(20).shape)
-
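The smoke test at the bottom of the file only checks shapes; in actual use, a training step would draw reparameterized weights, score the data under them, and add the KL term before pruning. A hedged sketch of such a loop (the toy regression data, learning rate, and KL scaling are illustrative; only SVIHalfCauchyPrior, get_reparam_weights, kl_divergence, and update_sparse_index come from this file):

import torch
from variationalsparsebayes.svi_half_cauchy import SVIHalfCauchyPrior

X, y = torch.randn(128, 10), torch.randn(128)   # placeholder regression data

prior = SVIHalfCauchyPrior(d=10, tau=1e-3)
opt = torch.optim.Adam(prior.parameters(), lr=1e-2)

for step in range(200):
    opt.zero_grad()
    w = prior.get_reparam_weights(8)                # (8, 10) weight samples
    nll = ((X @ w.t() - y[:, None]) ** 2).mean()    # Monte Carlo data-fit term
    loss = nll + prior.kl_divergence() / len(X)     # ELBO up to constants
    loss.backward()
    opt.step()

prior.update_sparse_index()  # prune weights whose scale posterior collapses toward zero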