Skip to content

Commit 24f2c04

Browse files
committed
Tidy up polynomials
1 parent 677af52 commit 24f2c04

File tree

11 files changed

+23
-26
lines changed

11 files changed

+23
-26
lines changed

deep_tensor/polynomials/cdf_1d.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,8 @@ def nodes(self, value: Tensor) -> None:
3434
return
3535

3636
@property
37-
@abc.abstractmethod
38-
def cardinality(self) -> Tensor:
39-
"""The number of nodes associated with the polynomial basis of
40-
the CDF.
41-
"""
42-
pass
37+
def cardinality(self) -> int:
38+
return self.nodes.numel()
4339

4440
@property
4541
@abc.abstractmethod

deep_tensor/polynomials/piecewise/lagrange_1.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,6 @@ def __init__(self, num_elems: int):
5151
self.mass_R = torch.linalg.cholesky(mass).T
5252
return
5353

54-
# @property
55-
# def nodes(self) -> Tensor:
56-
# return self._nodes
57-
58-
# @nodes.setter
59-
# def nodes(self, value: Tensor) -> None:
60-
# self._nodes = value
61-
# return
62-
6354
@property
6455
def mass_R(self) -> Tensor:
6556
return self._mass_R

deep_tensor/polynomials/piecewise/lagrange_p.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,11 @@ def int_W(self, value: Tensor) -> None:
185185
return
186186

187187
@property
188-
def cardinality(self):
188+
def cardinality(self) -> int:
189189
return self.nodes.numel()
190190

191191
@property
192-
def domain(self):
192+
def domain(self) -> Tensor:
193193
return torch.tensor([-1.0, 1.0])
194194

195195
@property

deep_tensor/polynomials/piecewise/piecewise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def eval_measure(self, ls: Tensor) -> Tensor:
8282
def eval_log_measure(self, ls: Tensor) -> Tensor:
8383
return torch.full(ls.shape, -self.domain_size.log())
8484

85-
def eval_measure_deriv(obj, ls: Tensor) -> Tensor:
85+
def eval_measure_deriv(self, ls: Tensor) -> Tensor:
8686
return torch.zeros_like(ls)
8787

8888
def eval_log_measure_deriv(self, ls: Tensor) -> Tensor:

deep_tensor/polynomials/spectral/chebyshev_1st.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
import torch
24
from torch import Tensor
35

@@ -91,7 +93,7 @@ def eval_log_measure(self, ls: Tensor) -> Tensor:
9193
self._check_in_domain(ls)
9294
ts = 1.0 - ls.square()
9395
ts[ts < EPS] = EPS
94-
return -0.5 * ts.log() - torch.tensor(torch.pi).log()
96+
return -0.5 * torch.log(ts) - math.log(torch.pi)
9597

9698
def eval_log_measure_deriv(self, ls: Tensor) -> Tensor:
9799
self._check_in_domain(ls)

deep_tensor/polynomials/spectral/chebyshev_1st_trigo_cdf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Tuple
2+
13
import torch
24
from torch import Tensor
35

@@ -59,7 +61,7 @@ def eval_int_basis(self, thetas: Tensor) -> Tensor:
5961
))
6062
return int_pws
6163

62-
def eval_int_basis_newton(self, thetas: Tensor) -> Tensor:
64+
def eval_int_basis_newton(self, thetas: Tensor) -> Tuple[Tensor, Tensor]:
6365
int_pws = self.eval_int_basis(thetas)
6466
thetas = thetas[:, None]
6567
derivs = self.norm * torch.cos(thetas * self.n) / torch.pi

deep_tensor/polynomials/spectral/chebyshev_2nd.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
import torch
24
from torch import Tensor
35
from torch.distributions.beta import Beta
@@ -87,7 +89,7 @@ def eval_measure(self, ls: Tensor) -> Tensor:
8789
def eval_log_measure(self, ls: Tensor) -> Tensor:
8890
self._check_in_domain(ls)
8991
ts = 1.0 - ls.square()
90-
logws = 0.5 * ts.log() + torch.tensor(2.0/torch.pi).log()
92+
logws = 0.5 * ts.log() + math.log(2.0/torch.pi)
9193
return logws
9294

9395
def eval_measure_deriv(self, ls: Tensor) -> Tensor:

deep_tensor/polynomials/spectral/chebyshev_2nd_trigo_cdf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Tuple
2+
13
import torch
24
from torch import Tensor
35

@@ -60,7 +62,7 @@ def eval_int_basis(self, thetas: Tensor) -> Tensor:
6062

6163
return ps
6264

63-
def eval_int_basis_newton(self, thetas: Tensor) -> Tensor:
65+
def eval_int_basis_newton(self, thetas: Tensor) -> Tuple[Tensor, Tensor]:
6466

6567
thetas = thetas[:, None]
6668

deep_tensor/polynomials/spectral/jacobi_11.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
import torch
24
from torch import Tensor
35
from torch.distributions.beta import Beta
@@ -45,7 +47,7 @@ def eval_measure(self, ls: Tensor) -> Tensor:
4547
return ws
4648

4749
def eval_log_measure(self, ls: Tensor) -> Tensor:
48-
ws = (1.0 - ls.square()).log() + torch.tensor(0.75).log()
50+
ws = (1.0 - ls.square()).log() + math.log(0.75)
4951
return ws
5052

5153
def eval_measure_deriv(self, ls: Tensor) -> Tensor:

deep_tensor/polynomials/spectral/recurr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import abc
2-
from typing import Tuple
2+
from typing import Tuple, Union
33

44
import torch
55
from torch import Tensor
@@ -15,7 +15,7 @@ def __init__(
1515
a: Tensor,
1616
b: Tensor,
1717
c: Tensor,
18-
norm: float | Tensor
18+
norm: Tensor
1919
):
2020
"""Class for spectral polynomials for which the three-term
2121
recurrence relation is known. This relation takes the form

0 commit comments

Comments
 (0)