Skip to content

Commit 79e4efc

Browse files
committed
Adjust tests and fix bug introduced in previous commit
1 parent 7c29d25 commit 79e4efc

File tree

4 files changed

+8
-7
lines changed

4 files changed

+8
-7
lines changed

deep_tensor/irt/sirt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,9 @@ def _marginalise_backward(self) -> None:
207207

208208
def _eval_rt_local_forward(self, ls: Tensor) -> Tensor:
209209

210-
dim_ls = ls.shape[1]
210+
num_ls, dim_ls = ls.shape
211211
zs = torch.zeros_like(ls)
212-
Gs_prod = torch.ones_like(ls[:, 0])
212+
Gs_prod = torch.ones((num_ls, 1), device=ls.device)
213213

214214
cores = self.ftt.cores
215215
Bs = self._Bs_f

tests/test_domains/test_bounded_domain.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
class TestLinearDomain(unittest.TestCase):
1212

1313
def setup_domain(self):
14-
bounds = torch.tensor([-2.0, 4.0])
15-
domain = dt.BoundedDomain(bounds=bounds)
14+
bounds = [-2.0, 4.0]
15+
domain = dt.BoundedDomain(bounds)
1616
return domain
1717

1818
def test_linear_domain(self):
@@ -22,8 +22,9 @@ def test_linear_domain(self):
2222
domain = self.setup_domain()
2323

2424
bounds_true = torch.tensor([-2.0, 4.0])
25+
bounds = torch.tensor(domain.bounds)
2526

26-
self.assertTrue((domain.bounds - bounds_true).abs().max() < 1e-8)
27+
self.assertTrue((bounds - bounds_true).abs().max() < 1e-8)
2728
self.assertAlmostEqual(domain.dxdl, 3.)
2829
self.assertAlmostEqual(domain.mean, 1.)
2930
self.assertAlmostEqual(domain.left, -2.)

tests/test_polynomials/test_cdf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class are inverses of one another.
3535
for poly in polys:
3636
with self.subTest(poly=poly):
3737

38-
cdf = dt.construct_cdf(polys[poly])
38+
cdf = dt.construct_cdf(polys[poly], error_tol=1e-10)
3939

4040
ls = torch.linspace(-1.0, 1.0, n_ls)
4141
ps = dummy_pdf(cdf.nodes) + 1e-2

tests/test_polynomials/test_piecewise_cdf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class TestPiecewiseCDF(unittest.TestCase):
1212

1313
def setup_cdf(self):
1414
poly = dt.Lagrange1(num_elems=2)
15-
cdf = dt.Lagrange1CDF(poly=poly)
15+
cdf = dt.Lagrange1CDF(poly=poly, error_tol=1e-10)
1616
return cdf
1717

1818
def test_lagrange_1d_cdf(self):

0 commit comments

Comments
 (0)