11from typing import Callable , Dict , Tuple
2+ import warnings
23
34import torch
45from torch import Tensor
78from ..ftt import ApproxBases , Direction , FTT
89from ..linalg import batch_mul , n_mode_prod , unfold_left , unfold_right
910from ..polynomials import CDF1D , construct_cdf
11+ from ..references import Reference
1012
1113
1214SUBSET2DIRECTION = {
1618
1719
1820class SIRT ():
    """Squared inverse Rosenblatt transport.

    Parameters
    ----------
    potential:
        A function that receives an n * d matrix of samples and
        returns an n-dimensional vector containing the potential
        function of the target density evaluated at each sample.
    ftt:
        TODO
    reference:
        TODO
    domain:
        The domain of the reference.
    defensive:
        The defensive parameter.
    cdf_tol:
        The tolerance used when solving the rootfinding problem to
        evaluate the inverse of each conditional CDF.

    """
4442
4543 def __init__ (
4644 self ,
4745 target_func : Callable [[Tensor ], Tensor ],
4846 ftt : FTT ,
49- reference ,
50- domain : Domain ,
47+ reference : Reference ,
5148 defensive : float ,
5249 cdf_tol : float
5350 ):
@@ -56,7 +53,7 @@ def __init__(
5653 self .ftt = ftt
5754 self .bases = self .ftt .bases
5855 self .dim = self .ftt .dim
59- self .domain = domain
56+ self .domain = reference . domain
6057 self .defensive = defensive
6158 self .cdfs = self .construct_cdfs (self .bases , cdf_tol )
6259
@@ -175,6 +172,26 @@ def _target_func(self, ls: Tensor) -> Tensor:
175172 neglogwxs = self .eval_measure_potential (xs )[0 ]
176173 gs = torch .exp (- 0.5 * (neglogfxs - neglogwxs ))
177174 return gs
175+
176+ @staticmethod
177+ def _check_z_func (z_func ) -> None :
178+
179+ dtype = torch .get_default_dtype ()
180+ msg = (
181+ "The normalising constant of the current SIRT layer is very small "
182+ f"({ z_func :.2e} ). This may cause numerical instability. "
183+ )
184+ if dtype == torch .float32 and z_func < 1.0e-5 :
185+ msg += (
186+ "Consider rescaling the potential function "
187+ "or changing to double precision."
188+ )
189+ warnings .warn (msg )
190+ elif dtype == torch .float64 and z_func < 1.0e-10 :
191+ msg += "Consider rescaling the potential function."
192+ warnings .warn (msg )
193+
194+ return
178195
179196 def _marginalise_forward (self ) -> None :
180197 """Computes each coefficient tensor required to evaluate the
@@ -192,6 +209,7 @@ def _marginalise_forward(self) -> None:
192209 self ._Rs_f [k ] = torch .linalg .qr (C_k , mode = "reduced" )[1 ].T
193210
194211 self .z_func = self ._Rs_f [0 ].square ().sum ()
212+ self ._check_z_func (self .z_func )
195213 return
196214
197215 def _marginalise_backward (self ) -> None :
0 commit comments