File tree Expand file tree Collapse file tree 4 files changed +8
-7
lines changed Expand file tree Collapse file tree 4 files changed +8
-7
lines changed Original file line number Diff line number Diff line change @@ -207,9 +207,9 @@ def _marginalise_backward(self) -> None:
207207
208208    def  _eval_rt_local_forward (self , ls : Tensor ) ->  Tensor :
209209
210-         dim_ls  =  ls .shape [ 1 ] 
210+         num_ls ,  dim_ls  =  ls .shape 
211211        zs  =  torch .zeros_like (ls )
212-         Gs_prod  =  torch .ones_like ( ls [:,  0 ] )
212+         Gs_prod  =  torch .ones (( num_ls ,  1 ),  device = ls . device )
213213
214214        cores  =  self .ftt .cores 
215215        Bs  =  self ._Bs_f  
Original file line number Diff line number Diff line change 1111class  TestLinearDomain (unittest .TestCase ):
1212
1313    def  setup_domain (self ):
14-         bounds  =  torch . tensor ( [- 2.0 , 4.0 ]) 
15-         domain  =  dt .BoundedDomain (bounds = bounds )
14+         bounds  =  [- 2.0 , 4.0 ]
15+         domain  =  dt .BoundedDomain (bounds )
1616        return  domain 
1717
1818    def  test_linear_domain (self ):
@@ -22,8 +22,9 @@ def test_linear_domain(self):
2222        domain  =  self .setup_domain ()
2323
2424        bounds_true  =  torch .tensor ([- 2.0 , 4.0 ])
25+         bounds  =  torch .tensor (domain .bounds )
2526
26-         self .assertTrue ((domain . bounds  -  bounds_true ).abs ().max () <  1e-8 )
27+         self .assertTrue ((bounds  -  bounds_true ).abs ().max () <  1e-8 )
2728        self .assertAlmostEqual (domain .dxdl , 3. )
2829        self .assertAlmostEqual (domain .mean , 1. )
2930        self .assertAlmostEqual (domain .left , - 2. )
Original file line number Diff line number Diff line change @@ -35,7 +35,7 @@ class are inverses of one another.
3535        for  poly  in  polys :
3636            with  self .subTest (poly = poly ):
3737
38-                 cdf  =  dt .construct_cdf (polys [poly ])
38+                 cdf  =  dt .construct_cdf (polys [poly ],  error_tol = 1e-10 )
3939
4040                ls  =  torch .linspace (- 1.0 , 1.0 , n_ls )
4141                ps  =  dummy_pdf (cdf .nodes ) +  1e-2 
Original file line number Diff line number Diff line change @@ -12,7 +12,7 @@ class TestPiecewiseCDF(unittest.TestCase):
1212
1313    def  setup_cdf (self ):
1414        poly  =  dt .Lagrange1 (num_elems = 2 )
15-         cdf  =  dt .Lagrange1CDF (poly = poly )
15+         cdf  =  dt .Lagrange1CDF (poly = poly ,  error_tol = 1e-10 )
1616        return  cdf 
1717
1818    def  test_lagrange_1d_cdf (self ):
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments