Skip to content

Commit 192cde0

Browse files
committed
Update bridging densities to support case where no debugging samples are passed in
1 parent 64c5e75 commit 192cde0

File tree

6 files changed

+114
-89
lines changed

6 files changed

+114
-89
lines changed

deep_tensor/bridging_densities/bridge.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ def is_last(self) -> bool:
1616
pass
1717

1818
@property
19-
def n_layers(self) -> int:
20-
return self._n_layers
19+
def num_layers(self) -> int:
20+
return self._num_layers
2121

22-
@n_layers.setter
23-
def n_layers(self, value: int) -> None:
24-
self._n_layers = value
22+
@num_layers.setter
23+
def num_layers(self, value: int) -> None:
24+
self._num_layers = value
2525
return
2626

2727
@property
@@ -86,6 +86,15 @@ def apply_preconditioner(self, us: Tensor) -> Tuple[Tensor, Tensor]:
8686
neglogdets = self.preconditioner.neglogdet_Q(us)
8787
return xs, neglogdets
8888

89+
def _eval_pullback(self, us: Tensor) -> Tensor:
90+
"""Evaluates the pullback of the target density under the
91+
preconditioning mapping.
92+
"""
93+
xs, neglogdets = self.apply_preconditioner(us)
94+
neglogfxs = self.target_func(xs)
95+
neglogfus = neglogfxs + neglogdets
96+
return neglogfus
97+
8998
def _reorder(
9099
self,
91100
xs: Tensor,

deep_tensor/bridging_densities/rare_event.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(
6565
betas: Sequence | Tensor | float = 1.0
6666
):
6767
self.gammas, self.betas = self._parse_bridging_params(gammas, betas)
68-
self.n_layers = 0
68+
self.num_layers = 0
6969
self.initialised = False
7070

7171
self._ratio_weight_funcs = {
@@ -77,7 +77,7 @@ def __init__(
7777

7878
@property
7979
def is_last(self) -> bool:
80-
return self.n_layers == (len(self.betas) - 1)
80+
return self.num_layers == (len(self.betas) - 1)
8181

8282
@staticmethod
8383
def _parse_bridging_params(
@@ -132,14 +132,23 @@ def neglogsigmoid(self, gamma: float, responses: Tensor) -> Tensor:
132132
neglogsigmoids = torch.log1p(torch.exp(gamma * dzs))
133133
return neglogsigmoids
134134

135+
def _eval_pullback(self, us: Tensor) -> Tuple[Tensor, Tensor]:
136+
"""Evaluates the pullback of the target density under the
137+
preconditioning mapping.
138+
"""
139+
xs, neglogdets = self.apply_preconditioner(us)
140+
neglogfxs, responses = self.target_func.func(xs)
141+
neglogfus = neglogfxs + neglogdets
142+
return neglogfus, responses
143+
135144
def _compute_neglogbridges(
136145
self,
137146
neglogref_us: Tensor,
138147
neglogfus: Tensor,
139148
responses: Tensor
140149
) -> Tensor:
141150

142-
k = self.n_layers
151+
k = self.num_layers
143152
neglogsigmoids_p = self.neglogsigmoid(self.gammas[k-1], responses)
144153
neglogbridges = (
145154
+ (1.0 - self.betas[k-1]) * neglogref_us
@@ -159,7 +168,7 @@ def _compute_weights_aratio(
159168
the previous bridging density for each particle.
160169
"""
161170

162-
k = self.n_layers
171+
k = self.num_layers
163172
neglogsigmoids = self.neglogsigmoid(self.gammas[k], responses)
164173
neglogsigmoids_p = self.neglogsigmoid(self.gammas[k-1], responses)
165174
neglogsigmoids_p[neglogsigmoids_p.isinf()] = 0.0
@@ -183,7 +192,7 @@ def _compute_weights_eratio(
183192
each particle.
184193
"""
185194

186-
k = self.n_layers
195+
k = self.num_layers
187196
neglogsigmoids = self.neglogsigmoid(self.gammas[k], responses)
188197

189198
neglogweights = (
@@ -243,10 +252,7 @@ def ratio_func(
243252

244253
neglogref_rs = self.reference.eval_potential(rs)[0]
245254
neglogref_us = self.reference.eval_potential(us)[0]
246-
247-
xs, neglogdets = self.apply_preconditioner(us)
248-
neglogfxs, responses = self.target_func.func(xs)
249-
neglogfus = neglogfxs + neglogdets
255+
neglogfus, responses = self._eval_pullback(us)
250256

251257
neglogratios = self._compute_ratio_func(
252258
method,
@@ -264,10 +270,7 @@ def update(self, us: Tensor, neglogfus_dirt: Tensor) -> Tuple[Tensor, Tensor]:
264270
raise Exception("Need to call self.initialise().")
265271

266272
neglogref_us = self.reference.eval_potential(us)[0]
267-
268-
xs, neglogdets = self.apply_preconditioner(us)
269-
neglogfxs, responses = self.target_func.func(xs)
270-
neglogfus = neglogfxs + neglogdets
273+
neglogfus, responses = self._eval_pullback(us)
271274

272275
neglogbridges = self._compute_neglogbridges(
273276
neglogref_us,
@@ -290,14 +293,16 @@ def _get_diagnostics(
290293
neglogfus: Tensor,
291294
neglogfus_dirt: Tensor
292295
) -> List[str]:
296+
297+
msg = [
298+
f"Gamma: {self.gammas[self.num_layers]:.4f}",
299+
f"Beta: {self.betas[self.num_layers]:.4f}"
300+
]
301+
302+
if None in (log_weights, neglogfus, neglogfus_dirt):
303+
return msg
293304

294305
div_h2 = compute_f_divergence(-neglogfus_dirt, -neglogfus)
295306
ess = estimate_ess_ratio(log_weights)
296-
297-
msg = [
298-
f"DHell: {div_h2.sqrt():.4f}",
299-
f"Gamma: {self.gammas[self.n_layers]:.4f}",
300-
f"Beta: {self.betas[self.n_layers]:.4f}",
301-
f"ESS: {ess:.4f}"
302-
]
307+
msg += [f"DHell: {div_h2.sqrt():.4f}", f"ESS: {ess:.4f}"]
303308
return msg

deep_tensor/bridging_densities/single_layer.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class SingleLayer(Bridge):
1414
"""
1515

1616
def __init__(self):
17-
self.n_layers = 0
17+
self.num_layers = 0
1818
self.is_adaptive = False
1919
return
2020

@@ -27,11 +27,7 @@ def update(
2727
us: Tensor,
2828
neglogfus_dirt: Tensor
2929
) -> Tuple[Tensor, Tensor]:
30-
31-
xs, neglogdets = self.apply_preconditioner(us)
32-
neglogfxs = self.target_func(xs)
33-
neglogfus = neglogfxs + neglogdets
34-
30+
neglogfus = self._eval_pullback(us)
3531
log_weights = -neglogfus + neglogfus_dirt
3632
return log_weights, neglogfus
3733

@@ -42,7 +38,4 @@ def ratio_func(
4238
us: Tensor,
4339
neglogfus_dirt: Tensor
4440
) -> Tensor:
45-
xs, neglogdets = self.apply_preconditioner(us)
46-
neglogfxs = self.target_func(xs)
47-
neglogfus = neglogfxs + neglogdets
48-
return neglogfus
41+
return self._eval_pullback(us)

deep_tensor/bridging_densities/tempering.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(
8787
self.init_beta = min_beta
8888
self.max_layers = max_layers
8989
self.is_adaptive = len(self.betas) == 1
90-
self.n_layers = 0
90+
self.num_layers = 0
9191
self.initialised = False
9292

9393
self._ratio_weight_funcs = {
@@ -99,8 +99,8 @@ def __init__(
9999

100100
@property
101101
def is_last(self) -> bool:
102-
max_layers_reached = self.n_layers == self.max_layers
103-
final_beta_reached = abs(self.betas[self.n_layers-1] - 1.0) < 1e-6
102+
max_layers_reached = self.num_layers == self.max_layers
103+
final_beta_reached = abs(self.betas[self.num_layers-1] - 1.0) < 1e-6
104104
return bool(max_layers_reached or final_beta_reached)
105105

106106
def initialise(
@@ -120,7 +120,7 @@ def _compute_neglogbridges(
120120
neglogfus: Tensor
121121
) -> Tensor:
122122

123-
k = self.n_layers
123+
k = self.num_layers
124124
neglogbridges = (
125125
+ (1.0 - self.betas[k-1]) * neglogref_us
126126
+ self.betas[k-1] * neglogfus
@@ -136,7 +136,7 @@ def _compute_weights_aratio(
136136
"""Computes the ratio between the current bridging density and
137137
the previous bridging density for each particle.
138138
"""
139-
k = self.n_layers
139+
k = self.num_layers
140140
neglogweights = (
141141
+ (self.betas[k-1] - self.betas[k]) * neglogref_us
142142
+ (self.betas[k] - self.betas[k-1]) * neglogfus
@@ -153,7 +153,7 @@ def _compute_weights_eratio(
153153
the DIRT approximation to the previous bridging density for
154154
each particle.
155155
"""
156-
k = self.n_layers
156+
k = self.num_layers
157157
neglogweights = (
158158
+ (1.0 - self.betas[k]) * neglogref_us
159159
+ self.betas[k] * neglogfus
@@ -183,7 +183,7 @@ def _compute_log_weights(
183183
neglogfus: Tensor,
184184
neglogfus_dirt: Tensor
185185
) -> Tensor:
186-
beta = self.betas[self.n_layers]
186+
beta = self.betas[self.num_layers]
187187
log_weights = -beta*neglogfus - (1-beta)*neglogrefs + neglogfus_dirt
188188
return log_weights
189189

@@ -200,10 +200,7 @@ def ratio_func(
200200

201201
neglogref_rs = self.reference.eval_potential(rs)[0]
202202
neglogref_us = self.reference.eval_potential(us)[0]
203-
204-
xs, neglogdets = self.apply_preconditioner(us)
205-
neglogfxs = self.target_func(xs)
206-
neglogfus = neglogfxs + neglogdets
203+
neglogfus = self._eval_pullback(us)
207204

208205
neglogratios = self._compute_ratio_func(
209206
method,
@@ -221,11 +218,11 @@ def _adapt_beta(
221218
neglogfus_dirt: Tensor
222219
):
223220

224-
if self.n_layers == 0:
221+
if self.num_layers == 0:
225222
self.betas[0] = self.init_beta
226223
return
227224

228-
k = self.n_layers
225+
k = self.num_layers
229226
self.betas[k] = self.betas[k-1] * self.beta_factor
230227

231228
while True:
@@ -250,10 +247,7 @@ def update(
250247
) -> Tuple[Tensor, Tensor]:
251248

252249
neglogref_us = self.reference.eval_potential(us)[0]
253-
254-
xs, neglogdets = self.apply_preconditioner(us)
255-
neglogfxs = self.target_func(xs)
256-
neglogfus = neglogfxs + neglogdets
250+
neglogfus = self._eval_pullback(us)
257251

258252
if self.is_adaptive:
259253
self._adapt_beta(neglogref_us, neglogfus, neglogfus_dirt)
@@ -277,13 +271,13 @@ def _get_diagnostics(
277271
neglogfus: Tensor,
278272
neglogfus_dirt: Tensor
279273
) -> List[str]:
274+
275+
msg = [f"Beta: {self.betas[self.num_layers]:.4f}"]
276+
277+
if None in (log_weights, neglogfus, neglogfus_dirt):
278+
return msg
280279

281280
div_h2 = compute_f_divergence(-neglogfus_dirt, -neglogfus)
282281
ess = estimate_ess_ratio(log_weights)
283-
284-
msg = [
285-
f"DHell: {div_h2.sqrt():.4f}",
286-
f"Beta: {self.betas[self.n_layers]:.4f}",
287-
f"ESS: {ess:.4f}"
288-
]
282+
msg += [f"DHell: {div_h2.sqrt():.4f}", f"ESS: {ess:.4f}"]
289283
return msg

0 commit comments

Comments
 (0)