Skip to content

Commit 566fb92

Browse files
committed
Tidy up bridging densities and pCN sampler
1 parent 24f2c04 commit 566fb92

File tree

3 files changed

+23
-24
lines changed

3 files changed

+23
-24
lines changed

deep_tensor/bridging_densities/bridge.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,16 @@
66

77

88
class Bridge(abc.ABC):
9-
10-
@property
11-
@abc.abstractmethod
12-
def is_adaptive(self) -> bool:
13-
return
149

1510
@property
1611
@abc.abstractmethod
1712
def is_last(self) -> bool:
18-
return
13+
pass
1914

2015
@property
2116
@abc.abstractmethod
2217
def params_dict(self) -> Dict:
23-
return
18+
pass
2419

2520
@property
2621
def n_layers(self) -> int:
@@ -77,7 +72,7 @@ def _get_ratio_func(
7772
function evaluated for each sample.
7873
7974
"""
80-
return
75+
pass
8176

8277
@abc.abstractmethod
8378
def _compute_log_weights(
@@ -112,7 +107,7 @@ def _compute_log_weights(
112107
previous bridging density evaluated at each sample.
113108
114109
"""
115-
return
110+
pass
116111

117112
def _set_init(self, neglogliks: Tensor) -> None:
118113
"""Computes the properties of the initial bridging density.

deep_tensor/bridging_densities/tempering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def max_layers(self, value: int) -> None:
4040
def is_last(self) -> bool:
4141
max_layers_reached = self.n_layers == self.max_layers
4242
final_beta_reached = (self.betas[self.n_layers-1] - 1.0).abs() < 1e-6
43-
return max_layers_reached or final_beta_reached
43+
return bool(max_layers_reached or final_beta_reached)
4444

4545
@property
4646
def params_dict(self) -> Dict:

deep_tensor/debiasing/mcmc.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ def run_dirt_pcn(
134134
dirt: AbstractDIRT,
135135
n: int,
136136
dt: float = 2.0,
137-
y_obs: Tensor | None = None,
138137
x0: Tensor | None = None,
138+
ys: Tensor | None = None,
139139
subset: str = "first",
140140
verbose: bool = True
141141
) -> MCMCResult:
@@ -157,8 +157,6 @@ def run_dirt_pcn(
157157
unnormalised) target density at a given sample.
158158
dirt:
159159
A previously-constructed DIRT object.
160-
y_obs:
161-
A tensor containing the observations.
162160
n:
163161
The length of the Markov chain to construct.
164162
dt:
@@ -169,6 +167,12 @@ def run_dirt_pcn(
169167
be applied to it to generate the starting location for sampling
170168
from the pullback of the target density. Otherwise, the mean of
171169
the reference density will be used.
170+
ys:
171+
A tensor containing a set of values to condition on.
172+
subset:
173+
If `ys` are passed in, whether they correspond to the first
174+
$k$ variables (`subset='first'`) or the final $k$ variables
175+
(`subset='last'`).
172176
verbose:
173177
Whether to print diagnostic information during the sampling
174178
process.
@@ -218,43 +222,43 @@ def run_dirt_pcn(
218222
msg = "Stepsize must be positive."
219223
raise Exception(msg)
220224

221-
if y_obs is not None:
222-
223-
y_obs = torch.atleast_2d(y_obs)
224-
dim = dirt.dim - y_obs.shape[1]
225+
if ys is None:
226+
227+
dim = dirt.dim
225228

226229
def negloglik_pullback(rs: Tensor) -> Tensor:
227230
"""Returns the difference between the negative logarithm of the
228231
pullback of the target function under the DIRT mapping and the
229232
negative log-prior density.
230233
"""
231234
rs = torch.atleast_2d(rs)
232-
neglogfr = dirt.eval_cirt_pullback(potential, y_obs, rs, subset=subset)
235+
neglogfr = dirt.eval_irt_pullback(potential, rs, subset=subset)
233236
neglogref = dirt.reference.eval_potential(rs)[0]
234237
return neglogfr - neglogref
235-
238+
236239
def irt_func(rs: Tensor) -> Tensor:
237240
rs = torch.atleast_2d(rs)
238-
ms = dirt.eval_cirt(y_obs, rs, subset=subset)[0]
241+
ms = dirt.eval_irt(rs, subset=subset)[0]
239242
return ms
240243

241244
else:
242245

243-
dim = dirt.dim
246+
ys = torch.atleast_2d(ys)
247+
dim = dirt.dim - ys.shape[1]
244248

245249
def negloglik_pullback(rs: Tensor) -> Tensor:
246250
"""Returns the difference between the negative logarithm of the
247251
pullback of the target function under the DIRT mapping and the
248252
negative log-prior density.
249253
"""
250254
rs = torch.atleast_2d(rs)
251-
neglogfr = dirt.eval_irt_pullback(potential, rs, subset=subset)
255+
neglogfr = dirt.eval_cirt_pullback(potential, ys, rs, subset=subset)
252256
neglogref = dirt.reference.eval_potential(rs)[0]
253257
return neglogfr - neglogref
254-
258+
255259
def irt_func(rs: Tensor) -> Tensor:
256260
rs = torch.atleast_2d(rs)
257-
ms = dirt.eval_irt(rs, subset=subset)[0]
261+
ms = dirt.eval_cirt(ys, rs, subset=subset)[0]
258262
return ms
259263

260264
res = _run_irt_pcn(

0 commit comments

Comments
 (0)