
Commit 34c3ddb

Authored and committed by Tim Joseph
refactor(tensor_distribution): simplify TensorLaplace initialization

Leverage `torch.distributions.utils.broadcast_all` to streamline parameter broadcasting and type conversion, which also simplifies shape and device inference. This change removes manual validation and broadcasting logic, making the implementation more concise and consistent with other distributions in the module. Redundant overrides for `log_prob` and `variance` have also been removed, as they are handled by the base class.
1 parent 57e60e9 commit 34c3ddb
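
For context only (not part of the commit): a minimal sketch of what `broadcast_all` does, which is why the refactored `__init__` can read shape and device straight off the broadcast result.

# Illustration, not code from this repository: broadcast_all promotes Python
# numbers to tensors and broadcasts all arguments to a common shape.
import torch
from torch.distributions.utils import broadcast_all

loc, scale = broadcast_all(torch.zeros(3, 1), 2.0)
print(loc.shape, scale.shape)  # torch.Size([3, 1]) torch.Size([3, 1])
print(scale[0, 0])             # tensor(2.)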

File tree

1 file changed: +5 -24 lines changed

src/tensorcontainer/tensor_distribution/laplace.py

Lines changed: 5 additions & 24 deletions
@@ -2,9 +2,9 @@
 
 from typing import Any, Dict, Optional
 
-import torch
 from torch import Tensor
 from torch.distributions import Laplace
+from torch.distributions.utils import broadcast_all
 
 from .base import TensorDistribution
 
@@ -22,21 +22,9 @@ def __init__(
         scale: Tensor | float,
         validate_args: Optional[bool] = None,
     ):
-        # Store the parameters in annotated attributes before calling super().__init__()
-        # This is required because super().__init__() calls self.dist() which needs these attributes
-        self._loc = loc if isinstance(loc, Tensor) else torch.tensor(loc)
-        self._scale = scale if isinstance(scale, Tensor) else torch.tensor(scale)
-
-        if torch.any(self._scale <= 0):
-            raise ValueError("scale must be positive")
-
-        try:
-            torch.broadcast_tensors(self._loc, self._scale)
-        except RuntimeError as e:
-            raise ValueError(f"loc and scale must have compatible shapes: {e}")
-
-        shape = torch.broadcast_shapes(self._loc.shape, self._scale.shape)
-        device = self._loc.device if self._loc.numel() > 0 else self._scale.device
+        self._loc, self._scale = broadcast_all(loc, scale)
+        shape = self._loc.shape
+        device = self._loc.device
         super().__init__(shape, device, validate_args)
 
     @classmethod
@@ -55,9 +43,6 @@ def dist(self) -> Laplace:
             validate_args=self._validate_args,
         )
 
-    def log_prob(self, value: Tensor) -> Tensor:
-        return self.dist().log_prob(value)
-
     @property
     def loc(self) -> Optional[Tensor]:
         """Returns the loc used to initialize the distribution."""
@@ -68,8 +53,4 @@ def scale(self) -> Optional[Tensor]:
         """Returns the scale used to initialize the distribution."""
         return self.dist().scale
 
-    @property
-    def variance(self) -> Tensor:
-        """Returns the variance of the Laplace distribution."""
-        assert self._scale is not None
-        return self.dist().variance
+
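
A hypothetical usage sketch of the simplified constructor (the import path is assumed from the file location in the diff and is not confirmed by the commit): a float `scale` now broadcasts against a tensor `loc` without the removed manual checks, and `log_prob` falls through to the `TensorDistribution` base class.

# Hypothetical usage; import path assumed from src/tensorcontainer/tensor_distribution/laplace.py.
import torch
from tensorcontainer.tensor_distribution.laplace import TensorLaplace

d = TensorLaplace(loc=torch.zeros(4), scale=1.0)  # float scale is broadcast to shape (4,)
x = torch.randn(4)
print(d.log_prob(x))  # handled by the base class after this change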
