44import torch
55from torch import Tensor
66
7- from ..domains import Domain
87from ..ftt import ApproxBases , Direction , FTT
98from ..linalg import batch_mul , n_mode_prod , unfold_left , unfold_right
109from ..polynomials import CDF1D , construct_cdf
@@ -27,9 +26,11 @@ class SIRT():
2726 returns an n-dimensional vector containing the potential
2827 function of the target density evaluated at each sample.
2928 ftt:
30- TODO
29+ The functional tensor train to use to approximate the
30+ square root of the ratio between the target density and
31+ weighting function.
3132 reference:
32- TODO
33+ The reference density.
3334 domain:
3435 The domain of the reference.
3536 defensive:
@@ -56,11 +57,11 @@ def __init__(
5657 self .domain = reference .domain
5758 self .defensive = defensive
5859 self .cdfs = self .construct_cdfs (self .bases , cdf_tol )
59-
6060 self .ftt .approximate (self ._target_func , reference )
6161
62- # Compute coefficient tensors and marginalisation coefficents,
63- # from the first core to the last and the last core to the first
62+ # Precompute coefficient tensors and marginalisation
63+ # coefficents, from the first core to the last and the last
64+ # core to the first.
6465 self ._Bs_f : Dict [int , Tensor ] = {}
6566 self ._Rs_f : Dict [int , Tensor ] = {}
6667 self ._Bs_b : Dict [int , Tensor ] = {}
@@ -71,7 +72,15 @@ def __init__(
7172
7273 @property
7374 def z (self ) -> Tensor :
74- return self .defensive + self .z_func
75+ return (1.0 * self .defensive ) * self .z_func
76+
77+ @property
78+ def coef_defensive (self ) -> Tensor :
79+ # Note: this is a slight change from the defensive parameter
80+ # defined in @CuiDolgov2022. The defensive parameter now scales
81+ # according to the normalising constant of the FTT approximation
82+ # to the target density.
83+ return self .defensive * self .z_func
7584
7685 @property
7786 def num_eval (self ) -> int :
@@ -244,7 +253,7 @@ def _eval_rt_local_forward(self, ls: Tensor) -> Tensor:
244253 # Compute (unnormalised) conditional PDF for each sample
245254 Ps = FTT .eval_core (self .bases [k ], Bs [k ], self .cdfs [k ].nodes )
246255 gs = torch .einsum ("jl, ilk -> ijk" , Gs_prod , Ps )
247- ps = gs .square ().sum (dim = 2 ) + self .defensive
256+ ps = gs .square ().sum (dim = 2 ) + self .coef_defensive
248257
249258 # Evaluate CDF to obtain corresponding uniform variates
250259 zs [:, k ] = self .cdfs [k ].eval_cdf (ps , ls [:, k ])
@@ -270,7 +279,7 @@ def _eval_rt_local_backward(self, ls: Tensor) -> Tensor:
270279 # Compute (unnormalised) conditional PDF for each sample
271280 Ps = FTT .eval_core (self .bases [k ], Bs [k ], self .cdfs [k ].nodes )
272281 gs = torch .einsum ("ijl, lk -> ijk" , Ps , Gs_prod )
273- ps = gs .square ().sum (dim = 1 ) + self .defensive
282+ ps = gs .square ().sum (dim = 1 ) + self .coef_defensive
274283
275284 # Evaluate CDF to obtain corresponding uniform variates
276285 zs [:, - i ] = self .cdfs [k ].eval_cdf (ps , ls [:, - i ])
@@ -337,7 +346,7 @@ def _eval_irt_local_forward(self, zs: Tensor) -> Tuple[Tensor, Tensor]:
337346
338347 Ps = FTT .eval_core (self .bases [k ], Bs [k ], self .cdfs [k ].nodes )
339348 gls = n_mode_prod (Ps , gs , n = 1 )
340- ps = gls .square ().sum (dim = 2 ) + self .defensive
349+ ps = gls .square ().sum (dim = 2 ) + self .coef_defensive
341350 ls [:, k ] = self .cdfs [k ].invert_cdf (ps , zs [:, k ])
342351
343352 Gs = FTT .eval_core (self .bases [k ], self .ftt .cores [k ], ls [:, k ])
@@ -379,7 +388,7 @@ def _eval_irt_local_backward(self, zs: Tensor) -> Tuple[Tensor, Tensor]:
379388
380389 Ps = FTT .eval_core_rev (self .bases [k ], Bs [k ], self .cdfs [k ].nodes )
381390 gls = n_mode_prod (Ps , gs , n = 1 )
382- ps = gls .square ().sum (dim = 2 ) + self .defensive
391+ ps = gls .square ().sum (dim = 2 ) + self .coef_defensive
383392 ls [:, - i ] = self .cdfs [k ].invert_cdf (ps , zs [:, - i ])
384393
385394 Gs = FTT .eval_core_rev (self .bases [k ], cores [k ], ls [:, - i ])
@@ -424,7 +433,7 @@ def _eval_irt_local(
424433
425434 indices = self ._get_transform_indices (zs .shape [1 ], direction )
426435
427- neglogpls = - (gs_sq + self .defensive ).log ()
436+ neglogpls = - (gs_sq + self .coef_defensive ).log ()
428437 neglogwls = self .bases .eval_measure_potential (ls , indices )
429438 neglogfls = self .z .log () + neglogpls + neglogwls
430439
@@ -453,7 +462,7 @@ def _eval_cirt_local_forward(
453462
454463 Ps = FTT .eval_core (self .bases [k ], Bs [k ], ls_x [:, k ])
455464 gs_marg = batch_mul (Gs_prod , Ps )
456- ps_marg = gs_marg .square ().sum (dim = (1 , 2 )) + self .defensive
465+ ps_marg = gs_marg .square ().sum (dim = (1 , 2 )) + self .coef_defensive
457466
458467 Gs = FTT .eval_core (self .bases [k ], cores [k ], ls_x [:, k ])
459468 Gs_prod = batch_mul (Gs_prod , Gs )
@@ -463,13 +472,13 @@ def _eval_cirt_local_forward(
463472
464473 Ps = FTT .eval_core (self .bases [k ], Bs [k ], self .cdfs [k ].nodes )
465474 gs = torch .einsum ("mij, ljk -> lmk" , Gs_prod , Ps )
466- ps = gs .square ().sum (dim = 2 ) + self .defensive
475+ ps = gs .square ().sum (dim = 2 ) + self .coef_defensive
467476 ls_y [:, i ] = self .cdfs [k ].invert_cdf (ps , zs [:, i ])
468477
469478 Gs = FTT .eval_core (self .bases [k ], cores [k ], ls_y [:, i ])
470479 Gs_prod = batch_mul (Gs_prod , Gs )
471480
472- ps = Gs_prod .flatten ().square () + self .defensive
481+ ps = Gs_prod .flatten ().square () + self .coef_defensive
473482
474483 indices = d_xs + torch .arange (d_zs )
475484 neglogwls_y = self .bases .eval_measure_potential (ls_y , indices )
@@ -497,7 +506,7 @@ def _eval_cirt_local_backward(
497506
498507 Ps = FTT .eval_core (self .bases [d_zs ], Bs [d_zs ], ls_x [:, 0 ])
499508 gs_marg = batch_mul (Ps , Gs_prod )
500- ps_marg = gs_marg .square ().sum (dim = (1 , 2 )) + self .defensive
509+ ps_marg = gs_marg .square ().sum (dim = (1 , 2 )) + self .coef_defensive
501510
502511 Gs = FTT .eval_core (self .bases [d_zs ], cores [d_zs ], ls_x [:, 0 ])
503512 Gs_prod = batch_mul (Gs , Gs_prod )
@@ -507,13 +516,13 @@ def _eval_cirt_local_backward(
507516
508517 Ps = FTT .eval_core (self .bases [k ], Bs [k ], self .cdfs [k ].nodes )
509518 gs = torch .einsum ("lij, mjk -> lmi" , Ps , Gs_prod )
510- ps = gs .square ().sum (dim = 2 ) + self .defensive
519+ ps = gs .square ().sum (dim = 2 ) + self .coef_defensive
511520 ls_y [:, k ] = self .cdfs [k ].invert_cdf (ps , zs [:, k ])
512521
513522 Gs = FTT .eval_core (self .bases [k ], cores [k ], ls_y [:, k ])
514523 Gs_prod = batch_mul (Gs , Gs_prod )
515524
516- ps = Gs_prod .flatten ().square () + self .defensive
525+ ps = Gs_prod .flatten ().square () + self .coef_defensive
517526
518527 indices = torch .arange (d_zs - 1 , - 1 , - 1 )
519528 neglogwls_y = self .bases .eval_measure_potential (ls_y , indices )
@@ -582,7 +591,7 @@ def _eval_potential_grad_local(self, ls: Tensor) -> Tensor:
582591 zs = self ._eval_rt_local_forward (ls )
583592 ls , gs_sq = self ._eval_irt_local_forward (zs )
584593 n_ls = ls .shape [0 ]
585- ps = gs_sq + self .defensive
594+ ps = gs_sq + self .coef_defensive
586595 neglogws = self .bases .eval_measure_potential (ls )
587596 ws = torch .exp (- neglogws )
588597 fs = ps * ws # Don't need to normalise as derivative ends up being a ratio
@@ -662,11 +671,11 @@ def _eval_rt_jac_local_forward(self, ls: Tensor) -> Tensor:
662671 # Evaluate marginal probability for the first k elements of
663672 # each sample
664673 gs = batch_mul (Gs_prod [k - 1 ], Ps [k ])
665- ps_marg [k ] = gs .square ().sum (dim = (1 , 2 )) + self .defensive
674+ ps_marg [k ] = gs .square ().sum (dim = (1 , 2 )) + self .coef_defensive
666675
667676 # Compute (unnormalised) marginal PDF at CDF nodes for each sample
668677 gs_grid = torch .einsum ("mij, ljk -> lmik" , Gs_prod [k - 1 ], Ps_grid [k ])
669- ps_grid [k ] = gs_grid .square ().sum (dim = (2 , 3 )) + self .defensive
678+ ps_grid [k ] = gs_grid .square ().sum (dim = (2 , 3 )) + self .coef_defensive
670679
671680 # Derivatives of marginal PDF
672681 for k in range (self .dim - 1 ):
@@ -757,11 +766,11 @@ def _eval_rt_jac_local_backward(self, ls: Tensor) -> Tensor:
757766 # Evaluate marginal probability for the first k elements of
758767 # each sample
759768 gs = batch_mul (Gs_prod [k + 1 ], Ps [k ])
760- ps_marg [k ] = gs .square ().sum (dim = (1 , 2 )) + self .defensive
769+ ps_marg [k ] = gs .square ().sum (dim = (1 , 2 )) + self .coef_defensive
761770
762771 # Compute (unnormalised) marginal PDF at CDF nodes for each sample
763772 gs_grid = torch .einsum ("mij, ljk -> lmik" , Gs_prod [k + 1 ], Ps_grid [k ])
764- ps_grid [k ] = gs_grid .square ().sum (dim = (2 , 3 )) + self .defensive
773+ ps_grid [k ] = gs_grid .square ().sum (dim = (2 , 3 )) + self .coef_defensive
765774
766775 # Derivatives of marginal PDF
767776 for k in range (1 , self .dim ):
@@ -878,7 +887,7 @@ def _eval_potential_local(self, ls: Tensor, direction: Direction) -> Tensor:
878887 gs_sq = (self ._Rs_b [self .dim - dim_l - 1 ] @ gs .T ).square ().sum (dim = 0 )
879888
880889 neglogwls = self .bases .eval_measure_potential (ls , indices )
881- neglogfls = self .z .log () - (gs_sq + self .defensive ).log () + neglogwls
890+ neglogfls = self .z .log () - (gs_sq + self .coef_defensive ).log () + neglogwls
882891 return neglogfls
883892
884893 def _eval_potential (self , xs : Tensor , subset : str ) -> Tensor :
0 commit comments