Skip to content

Commit c234414

Browse files
committed
Scale the defensive parameter with the normalising constant of the FTT approximation to the target density
1 parent 681441f commit c234414

File tree

1 file changed

+33
-24
lines changed

1 file changed

+33
-24
lines changed

deep_tensor/irt/sirt.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import torch
55
from torch import Tensor
66

7-
from ..domains import Domain
87
from ..ftt import ApproxBases, Direction, FTT
98
from ..linalg import batch_mul, n_mode_prod, unfold_left, unfold_right
109
from ..polynomials import CDF1D, construct_cdf
@@ -27,9 +26,11 @@ class SIRT():
2726
returns an n-dimensional vector containing the potential
2827
function of the target density evaluated at each sample.
2928
ftt:
30-
TODO
29+
The functional tensor train to use to approximate the
30+
square root of the ratio between the target density and
31+
weighting function.
3132
reference:
32-
TODO
33+
The reference density.
3334
domain:
3435
The domain of the reference.
3536
defensive:
@@ -56,11 +57,11 @@ def __init__(
5657
self.domain = reference.domain
5758
self.defensive = defensive
5859
self.cdfs = self.construct_cdfs(self.bases, cdf_tol)
59-
6060
self.ftt.approximate(self._target_func, reference)
6161

62-
# Compute coefficient tensors and marginalisation coefficents,
63-
# from the first core to the last and the last core to the first
62+
# Precompute coefficient tensors and marginalisation
63+
# coefficents, from the first core to the last and the last
64+
# core to the first.
6465
self._Bs_f: Dict[int, Tensor] = {}
6566
self._Rs_f: Dict[int, Tensor] = {}
6667
self._Bs_b: Dict[int, Tensor] = {}
@@ -71,7 +72,15 @@ def __init__(
7172

7273
@property
7374
def z(self) -> Tensor:
74-
return self.defensive + self.z_func
75+
return (1.0 * self.defensive) * self.z_func
76+
77+
@property
78+
def coef_defensive(self) -> Tensor:
79+
# Note: this is a slight change from the defensive parameter
80+
# defined in @CuiDolgov2022. The defensive parameter now scales
81+
# according to the normalising constant of the FTT approximation
82+
# to the target density.
83+
return self.defensive * self.z_func
7584

7685
@property
7786
def num_eval(self) -> int:
@@ -244,7 +253,7 @@ def _eval_rt_local_forward(self, ls: Tensor) -> Tensor:
244253
# Compute (unnormalised) conditional PDF for each sample
245254
Ps = FTT.eval_core(self.bases[k], Bs[k], self.cdfs[k].nodes)
246255
gs = torch.einsum("jl, ilk -> ijk", Gs_prod, Ps)
247-
ps = gs.square().sum(dim=2) + self.defensive
256+
ps = gs.square().sum(dim=2) + self.coef_defensive
248257

249258
# Evaluate CDF to obtain corresponding uniform variates
250259
zs[:, k] = self.cdfs[k].eval_cdf(ps, ls[:, k])
@@ -270,7 +279,7 @@ def _eval_rt_local_backward(self, ls: Tensor) -> Tensor:
270279
# Compute (unnormalised) conditional PDF for each sample
271280
Ps = FTT.eval_core(self.bases[k], Bs[k], self.cdfs[k].nodes)
272281
gs = torch.einsum("ijl, lk -> ijk", Ps, Gs_prod)
273-
ps = gs.square().sum(dim=1) + self.defensive
282+
ps = gs.square().sum(dim=1) + self.coef_defensive
274283

275284
# Evaluate CDF to obtain corresponding uniform variates
276285
zs[:, -i] = self.cdfs[k].eval_cdf(ps, ls[:, -i])
@@ -337,7 +346,7 @@ def _eval_irt_local_forward(self, zs: Tensor) -> Tuple[Tensor, Tensor]:
337346

338347
Ps = FTT.eval_core(self.bases[k], Bs[k], self.cdfs[k].nodes)
339348
gls = n_mode_prod(Ps, gs, n=1)
340-
ps = gls.square().sum(dim=2) + self.defensive
349+
ps = gls.square().sum(dim=2) + self.coef_defensive
341350
ls[:, k] = self.cdfs[k].invert_cdf(ps, zs[:, k])
342351

343352
Gs = FTT.eval_core(self.bases[k], self.ftt.cores[k], ls[:, k])
@@ -379,7 +388,7 @@ def _eval_irt_local_backward(self, zs: Tensor) -> Tuple[Tensor, Tensor]:
379388

380389
Ps = FTT.eval_core_rev(self.bases[k], Bs[k], self.cdfs[k].nodes)
381390
gls = n_mode_prod(Ps, gs, n=1)
382-
ps = gls.square().sum(dim=2) + self.defensive
391+
ps = gls.square().sum(dim=2) + self.coef_defensive
383392
ls[:, -i] = self.cdfs[k].invert_cdf(ps, zs[:, -i])
384393

385394
Gs = FTT.eval_core_rev(self.bases[k], cores[k], ls[:, -i])
@@ -424,7 +433,7 @@ def _eval_irt_local(
424433

425434
indices = self._get_transform_indices(zs.shape[1], direction)
426435

427-
neglogpls = -(gs_sq + self.defensive).log()
436+
neglogpls = -(gs_sq + self.coef_defensive).log()
428437
neglogwls = self.bases.eval_measure_potential(ls, indices)
429438
neglogfls = self.z.log() + neglogpls + neglogwls
430439

@@ -453,7 +462,7 @@ def _eval_cirt_local_forward(
453462

454463
Ps = FTT.eval_core(self.bases[k], Bs[k], ls_x[:, k])
455464
gs_marg = batch_mul(Gs_prod, Ps)
456-
ps_marg = gs_marg.square().sum(dim=(1, 2)) + self.defensive
465+
ps_marg = gs_marg.square().sum(dim=(1, 2)) + self.coef_defensive
457466

458467
Gs = FTT.eval_core(self.bases[k], cores[k], ls_x[:, k])
459468
Gs_prod = batch_mul(Gs_prod, Gs)
@@ -463,13 +472,13 @@ def _eval_cirt_local_forward(
463472

464473
Ps = FTT.eval_core(self.bases[k], Bs[k], self.cdfs[k].nodes)
465474
gs = torch.einsum("mij, ljk -> lmk", Gs_prod, Ps)
466-
ps = gs.square().sum(dim=2) + self.defensive
475+
ps = gs.square().sum(dim=2) + self.coef_defensive
467476
ls_y[:, i] = self.cdfs[k].invert_cdf(ps, zs[:, i])
468477

469478
Gs = FTT.eval_core(self.bases[k], cores[k], ls_y[:, i])
470479
Gs_prod = batch_mul(Gs_prod, Gs)
471480

472-
ps = Gs_prod.flatten().square() + self.defensive
481+
ps = Gs_prod.flatten().square() + self.coef_defensive
473482

474483
indices = d_xs + torch.arange(d_zs)
475484
neglogwls_y = self.bases.eval_measure_potential(ls_y, indices)
@@ -497,7 +506,7 @@ def _eval_cirt_local_backward(
497506

498507
Ps = FTT.eval_core(self.bases[d_zs], Bs[d_zs], ls_x[:, 0])
499508
gs_marg = batch_mul(Ps, Gs_prod)
500-
ps_marg = gs_marg.square().sum(dim=(1, 2)) + self.defensive
509+
ps_marg = gs_marg.square().sum(dim=(1, 2)) + self.coef_defensive
501510

502511
Gs = FTT.eval_core(self.bases[d_zs], cores[d_zs], ls_x[:, 0])
503512
Gs_prod = batch_mul(Gs, Gs_prod)
@@ -507,13 +516,13 @@ def _eval_cirt_local_backward(
507516

508517
Ps = FTT.eval_core(self.bases[k], Bs[k], self.cdfs[k].nodes)
509518
gs = torch.einsum("lij, mjk -> lmi", Ps, Gs_prod)
510-
ps = gs.square().sum(dim=2) + self.defensive
519+
ps = gs.square().sum(dim=2) + self.coef_defensive
511520
ls_y[:, k] = self.cdfs[k].invert_cdf(ps, zs[:, k])
512521

513522
Gs = FTT.eval_core(self.bases[k], cores[k], ls_y[:, k])
514523
Gs_prod = batch_mul(Gs, Gs_prod)
515524

516-
ps = Gs_prod.flatten().square() + self.defensive
525+
ps = Gs_prod.flatten().square() + self.coef_defensive
517526

518527
indices = torch.arange(d_zs-1, -1, -1)
519528
neglogwls_y = self.bases.eval_measure_potential(ls_y, indices)
@@ -582,7 +591,7 @@ def _eval_potential_grad_local(self, ls: Tensor) -> Tensor:
582591
zs = self._eval_rt_local_forward(ls)
583592
ls, gs_sq = self._eval_irt_local_forward(zs)
584593
n_ls = ls.shape[0]
585-
ps = gs_sq + self.defensive
594+
ps = gs_sq + self.coef_defensive
586595
neglogws = self.bases.eval_measure_potential(ls)
587596
ws = torch.exp(-neglogws)
588597
fs = ps * ws # Don't need to normalise as derivative ends up being a ratio
@@ -662,11 +671,11 @@ def _eval_rt_jac_local_forward(self, ls: Tensor) -> Tensor:
662671
# Evaluate marginal probability for the first k elements of
663672
# each sample
664673
gs = batch_mul(Gs_prod[k-1], Ps[k])
665-
ps_marg[k] = gs.square().sum(dim=(1, 2)) + self.defensive
674+
ps_marg[k] = gs.square().sum(dim=(1, 2)) + self.coef_defensive
666675

667676
# Compute (unnormalised) marginal PDF at CDF nodes for each sample
668677
gs_grid = torch.einsum("mij, ljk -> lmik", Gs_prod[k-1], Ps_grid[k])
669-
ps_grid[k] = gs_grid.square().sum(dim=(2, 3)) + self.defensive
678+
ps_grid[k] = gs_grid.square().sum(dim=(2, 3)) + self.coef_defensive
670679

671680
# Derivatives of marginal PDF
672681
for k in range(self.dim-1):
@@ -757,11 +766,11 @@ def _eval_rt_jac_local_backward(self, ls: Tensor) -> Tensor:
757766
# Evaluate marginal probability for the first k elements of
758767
# each sample
759768
gs = batch_mul(Gs_prod[k+1], Ps[k])
760-
ps_marg[k] = gs.square().sum(dim=(1, 2)) + self.defensive
769+
ps_marg[k] = gs.square().sum(dim=(1, 2)) + self.coef_defensive
761770

762771
# Compute (unnormalised) marginal PDF at CDF nodes for each sample
763772
gs_grid = torch.einsum("mij, ljk -> lmik", Gs_prod[k+1], Ps_grid[k])
764-
ps_grid[k] = gs_grid.square().sum(dim=(2, 3)) + self.defensive
773+
ps_grid[k] = gs_grid.square().sum(dim=(2, 3)) + self.coef_defensive
765774

766775
# Derivatives of marginal PDF
767776
for k in range(1, self.dim):
@@ -878,7 +887,7 @@ def _eval_potential_local(self, ls: Tensor, direction: Direction) -> Tensor:
878887
gs_sq = (self._Rs_b[self.dim-dim_l-1] @ gs.T).square().sum(dim=0)
879888

880889
neglogwls = self.bases.eval_measure_potential(ls, indices)
881-
neglogfls = self.z.log() - (gs_sq + self.defensive).log() + neglogwls
890+
neglogfls = self.z.log() - (gs_sq + self.coef_defensive).log() + neglogwls
882891
return neglogfls
883892

884893
def _eval_potential(self, xs: Tensor, subset: str) -> Tensor:

0 commit comments

Comments
 (0)