Skip to content

Commit 681441f

Browse files
committed
Add check for situations where the normalising constant of the target density is very small
1 parent 1e6bc83 commit 681441f

File tree

3 files changed

+35
-18
lines changed

3 files changed

+35
-18
lines changed

deep_tensor/ftt/ftt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def approximate(
315315
Parameters
316316
----------
317317
target_func:
318-
The target function, $f : [-1, 1]^{d} \rightarrow \mathbb{R}$.
318+
The target function, $f : [-1, 1]^{d} \rightarrow \mathbb{R}$.
319319
reference:
320320
The reference measure. If provided, this will be used to
321321
generate the initial index sets for the underlying TT.

deep_tensor/irt/dirt.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ def _get_new_layer(self) -> SIRT:
152152
self.eval_ratio_func,
153153
ftt,
154154
self.reference,
155-
self.domain,
156155
self.defensive,
157156
self.cdf_tol
158157
)

deep_tensor/irt/sirt.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Callable, Dict, Tuple
2+
import warnings
23

34
import torch
45
from torch import Tensor
@@ -7,6 +8,7 @@
78
from ..ftt import ApproxBases, Direction, FTT
89
from ..linalg import batch_mul, n_mode_prod, unfold_left, unfold_right
910
from ..polynomials import CDF1D, construct_cdf
11+
from ..references import Reference
1012

1113

1214
SUBSET2DIRECTION = {
@@ -16,38 +18,33 @@
1618

1719

1820
class SIRT():
19-
r"""Squared inverse Rosenblatt transport.
21+
"""Squared inverse Rosenblatt transport.
2022
2123
Parameters
2224
----------
2325
potential:
24-
A function that receives an $n \times d$ matrix of samples and
25-
returns an $n$-dimensional vector containing the potential
26+
A function that receives an n * d matrix of samples and
27+
returns an n-dimensional vector containing the potential
2628
function of the target density evaluated at each sample.
2729
ftt:
2830
TODO
29-
defensive:
31+
reference:
3032
TODO
33+
domain:
34+
The domain of the reference.
3135
defensive:
32-
The defensive parameter, $\tau$, which ensures that the tails
33-
of the approximation are sufficiently heavy.
36+
The defensive parameter.
3437
cdf_tol:
35-
TODO
36-
37-
References
38-
----------
39-
Cui, T and Dolgov, S (2022). *[Deep composition of tensor-trains
40-
using squared inverse Rosenblatt transports](https://doi.org/10.1007/s10208-021-09537-5).*
41-
Foundations of Computational Mathematics, **22**, 1863--1922.
38+
The tolerance used when solving the rootfinding problem to
39+
evaluate the inverse of each conditional CDF.
4240
4341
"""
4442

4543
def __init__(
4644
self,
4745
target_func: Callable[[Tensor], Tensor],
4846
ftt: FTT,
49-
reference,
50-
domain: Domain,
47+
reference: Reference,
5148
defensive: float,
5249
cdf_tol: float
5350
):
@@ -56,7 +53,7 @@ def __init__(
5653
self.ftt = ftt
5754
self.bases = self.ftt.bases
5855
self.dim = self.ftt.dim
59-
self.domain = domain
56+
self.domain = reference.domain
6057
self.defensive = defensive
6158
self.cdfs = self.construct_cdfs(self.bases, cdf_tol)
6259

@@ -175,6 +172,26 @@ def _target_func(self, ls: Tensor) -> Tensor:
175172
neglogwxs = self.eval_measure_potential(xs)[0]
176173
gs = torch.exp(-0.5 * (neglogfxs - neglogwxs))
177174
return gs
175+
176+
@staticmethod
177+
def _check_z_func(z_func) -> None:
178+
179+
dtype = torch.get_default_dtype()
180+
msg = (
181+
"The normalising constant of the current SIRT layer is very small "
182+
f"({z_func:.2e}). This may cause numerical instability. "
183+
)
184+
if dtype == torch.float32 and z_func < 1.0e-5:
185+
msg += (
186+
"Consider rescaling the potential function "
187+
"or changing to double precision."
188+
)
189+
warnings.warn(msg)
190+
elif dtype == torch.float64 and z_func < 1.0e-10:
191+
msg += "Consider rescaling the potential function."
192+
warnings.warn(msg)
193+
194+
return
178195

179196
def _marginalise_forward(self) -> None:
180197
"""Computes each coefficient tensor required to evaluate the
@@ -192,6 +209,7 @@ def _marginalise_forward(self) -> None:
192209
self._Rs_f[k] = torch.linalg.qr(C_k, mode="reduced")[1].T
193210

194211
self.z_func = self._Rs_f[0].square().sum()
212+
self._check_z_func(self.z_func)
195213
return
196214

197215
def _marginalise_backward(self) -> None:

0 commit comments

Comments
 (0)