Skip to content

Commit 4b5ba6c

Browse files
committed
Tidy up resetting of bridging densities
1 parent b6e933f commit 4b5ba6c

File tree

4 files changed

+12
-11
lines changed

4 files changed

+12
-11
lines changed

deep_tensor/bridging_densities/bridge.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,19 @@ def update(self, us: Tensor, neglogfus_dirt: Tensor) -> Tuple[Tensor, Tensor]:
7171
"""
7272
pass
7373

74+
@abc.abstractmethod
75+
def reset(self) -> None:
76+
"""Resets the parameters of the bridging density to those at
77+
initialisation.
78+
"""
79+
pass
80+
7481
def initialise(
7582
self,
7683
preconditioner: Preconditioner,
7784
target_func: TargetFunc
7885
) -> None:
86+
self.reset()
7987
self.preconditioner = preconditioner
8088
self.reference = self.preconditioner.reference
8189
self.target_func = target_func

deep_tensor/bridging_densities/rare_event.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(
6767
self.gammas, self.betas = self._parse_bridging_params(gammas, betas)
6868
self.num_layers = 0
6969
self.initialised = False
70+
self.is_adaptive = False
7071

7172
self._ratio_weight_funcs = {
7273
"aratio": self._compute_weights_aratio,
@@ -84,7 +85,7 @@ def _parse_bridging_params(
8485
gammas,
8586
betas
8687
) -> Tuple[Dict[int, float], Dict[int, float]]:
87-
"""TODO: this needs tidying up, I think..."""
88+
# TODO: this could be tidied up.
8889

8990
if isinstance(gammas, Tensor):
9091
gammas = gammas.tolist()
@@ -125,9 +126,7 @@ def initialise(
125126
msg = "Target function must be of type 'RareEventFunc'."
126127
raise Exception(msg)
127128

128-
self.preconditioner = preconditioner
129-
self.reference = preconditioner.reference
130-
self.target_func = target_func
129+
Bridge.initialise(self, preconditioner, target_func)
131130
self.initialised = True
132131
return
133132

deep_tensor/bridging_densities/tempering.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,7 @@ def initialise(
115115
preconditioner: Preconditioner,
116116
target_func: TargetFunc
117117
) -> None:
118-
self.preconditioner = preconditioner
119-
self.reference = self.preconditioner.reference
120-
self.target_func = target_func
118+
Bridge.initialise(self, preconditioner, target_func)
121119
self.initialised = True
122120
return
123121

deep_tensor/irt/dirt.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,6 @@ def __init__(
4848
bridge: Bridge | None = None,
4949
options: DIRTOptions | None = None
5050
):
51-
# TODO: need to reset the bridge prior to starting. Ideally we
52-
# should be able to use the same bridge object to build
53-
# multiple DIRT objects.
5451

5552
if not isinstance(target_func, TargetFunc):
5653
target_func = TargetFunc(target_func)
@@ -66,7 +63,6 @@ def __init__(
6663
self.domain = self.reference.domain
6764
self.ftt = ftt
6865
self.bridge = bridge
69-
self.bridge.reset()
7066
self.bridge.initialise(preconditioner, target_func)
7167

7268
self.ratio_type = options.ratio_type

0 commit comments

Comments
 (0)