Skip to content

Commit 1d7d2d2

Browse files
authored
Merge pull request #10 from mctigger/unify-tensor_distribution-implementations
Unify tensor distribution implementations
2 parents 5a9fac3 + 58f8698 commit 1d7d2d2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+881
-1070
lines changed

src/tensorcontainer/tensor_annotated.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, Dict, Iterable, List, Tuple, TypeVar, Union, get_args
3+
from typing import Any, Iterable, TypeVar, Union, get_args
44

55
import torch
66
from torch import Tensor
@@ -20,7 +20,7 @@
2020
class TensorAnnotated(TensorContainer, PytreeRegistered):
2121
def __init__(
2222
self,
23-
shape: torch.Size | List[int] | Tuple[int],
23+
shape: torch.Size | list[int] | tuple[int, ...],
2424
device: str | torch.device | int | None,
2525
):
2626
super().__init__(shape, device, True)
@@ -76,14 +76,14 @@ def _get_meta_attributes(self):
7676
return meta_attributes
7777

7878
def _get_pytree_context(
79-
self, flat_names: List[str], flat_leaves: List[TDCompatible], meta_data
80-
) -> Tuple:
79+
self, flat_names: list[str], flat_leaves: list[TDCompatible], meta_data
80+
) -> tuple:
8181
batch_ndim = len(self.shape)
8282
event_ndims = tuple(leaf.ndim - batch_ndim for leaf in flat_leaves)
8383

8484
return flat_names, event_ndims, meta_data, self.device
8585

86-
def _pytree_flatten(self) -> Tuple[List[Any], Any]:
86+
def _pytree_flatten(self) -> tuple[list[Any], Any]:
8787
tensor_attributes = self._get_tensor_attributes()
8888
flat_names = list(tensor_attributes.keys())
8989
flat_values = list(tensor_attributes.values())
@@ -132,8 +132,8 @@ def _pytree_unflatten(cls, leaves: Iterable[Any], context: pytree.Context) -> Se
132132
@classmethod
133133
def _init_from_reconstructed(
134134
cls,
135-
tensor_attributes: Dict[str, TDCompatible],
136-
meta_attributes: Dict[str, Any],
135+
tensor_attributes: dict[str, TDCompatible],
136+
meta_attributes: dict[str, Any],
137137
device,
138138
shape,
139139
):

src/tensorcontainer/tensor_distribution/base.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from abc import abstractmethod
4-
from typing import Any, Dict, List, Tuple
4+
from typing import Any
55

66
from torch import Size, Tensor
77
from torch._C import device
@@ -47,10 +47,10 @@ class TensorDistribution(TensorAnnotated):
4747
--------
4848
```python
4949
class TensorNormal(TensorDistribution):
50-
_loc: Optional[Tensor] = None
51-
_scale: Optional[Tensor] = None
50+
_loc: Tensor | None = None
51+
_scale: Tensor | None = None
5252
53-
def __init__(self, loc: Tensor, scale: Tensor, validate_args: Optional[bool] = None):
53+
def __init__(self, loc: Tensor, scale: Tensor, validate_args: bool | None = None):
5454
self._loc = loc
5555
self._scale = scale
5656
super().__init__(loc.shape, loc.device, validate_args)
@@ -64,7 +64,7 @@ def dist(self) -> Distribution:
6464

6565
def __init__(
6666
self,
67-
shape: Size | List[int] | Tuple[int],
67+
shape: Size | list[int] | tuple[int, ...],
6868
device: str | device | int | None,
6969
validate_args: bool | None = None,
7070
):
@@ -89,8 +89,8 @@ def __init__(
8989
@classmethod
9090
def _init_from_reconstructed(
9191
cls,
92-
tensor_attributes: Dict[str, TDCompatible],
93-
meta_attributes: Dict[str, Any],
92+
tensor_attributes: dict[str, TDCompatible],
93+
meta_attributes: dict[str, Any],
9494
device,
9595
shape,
9696
):
@@ -112,7 +112,7 @@ def _init_from_reconstructed(
112112
return cls._unflatten_distribution({**tensor_attributes, **meta_attributes})
113113

114114
@classmethod
115-
def _unflatten_distribution(cls, attributes: Dict[str, Any]):
115+
def _unflatten_distribution(cls, attributes: dict[str, Any]):
116116
"""
117117
Reconstruct a distribution from flattened tensor and metadata attributes.
118118

src/tensorcontainer/tensor_distribution/bernoulli.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from __future__ import annotations
22

3-
from typing import Any, Dict, Optional, Union
3+
from typing import Any
44

55
from torch import Size, Tensor
66
from torch.distributions import Bernoulli
7-
from torch.distributions.utils import broadcast_all
7+
from .utils import broadcast_all
88
from torch.types import Number
99

1010
from .base import TensorDistribution
@@ -14,14 +14,14 @@ class TensorBernoulli(TensorDistribution):
1414
"""Tensor-aware Bernoulli distribution."""
1515

1616
# Annotated tensor parameters
17-
_probs: Optional[Tensor]
18-
_logits: Optional[Tensor]
17+
_probs: Tensor | None
18+
_logits: Tensor | None
1919

2020
def __init__(
2121
self,
22-
probs: Optional[Union[Number, Tensor]] = None,
23-
logits: Optional[Union[Number, Tensor]] = None,
24-
validate_args: Optional[bool] = None,
22+
probs: Number | Tensor | None = None,
23+
logits: Number | Tensor | None = None,
24+
validate_args: bool | None = None,
2525
):
2626
if (probs is None) == (logits is None):
2727
raise ValueError(
@@ -43,7 +43,7 @@ def __init__(
4343
super().__init__(shape, device, validate_args)
4444

4545
@classmethod
46-
def _unflatten_distribution(cls, attributes: Dict[str, Any]) -> TensorBernoulli:
46+
def _unflatten_distribution(cls, attributes: dict[str, Any]) -> TensorBernoulli:
4747
"""Reconstruct distribution from tensor attributes."""
4848
return cls(
4949
probs=attributes.get("_probs"),
@@ -56,21 +56,6 @@ def dist(self) -> Bernoulli:
5656
probs=self._probs, logits=self._logits, validate_args=self._validate_args
5757
)
5858

59-
def log_prob(self, value: Tensor) -> Tensor:
60-
return self.dist().log_prob(value)
61-
62-
@property
63-
def mean(self) -> Tensor:
64-
return self.dist().mean
65-
66-
@property
67-
def variance(self) -> Tensor:
68-
return self.dist().variance
69-
70-
@property
71-
def mode(self) -> Tensor:
72-
return self.dist().mode
73-
7459
@property
7560
def logits(self) -> Tensor:
7661
return self.dist().logits
Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from __future__ import annotations
22

3-
from typing import Any, Dict, Optional
3+
from typing import Any
44

55
from torch import Tensor
66
from torch.distributions import Beta
7-
from torch.distributions.utils import broadcast_all
7+
from .utils import broadcast_all
88

99
from .base import TensorDistribution
1010

@@ -29,9 +29,9 @@ class TensorBeta(TensorDistribution):
2929

3030
def __init__(
3131
self,
32-
concentration1: Tensor,
33-
concentration0: Tensor,
34-
validate_args: Optional[bool] = None,
32+
concentration1: float | Tensor,
33+
concentration0: float | Tensor,
34+
validate_args: bool | None = None,
3535
):
3636
self._concentration1, self._concentration0 = broadcast_all(
3737
concentration1, concentration0
@@ -45,12 +45,12 @@ def __init__(
4545
@classmethod
4646
def _unflatten_distribution(
4747
cls,
48-
attributes: Dict[str, Any],
48+
attributes: dict[str, Any],
4949
) -> TensorBeta:
5050
"""Reconstruct distribution from tensor attributes."""
5151
return cls(
52-
concentration1=attributes.get("_concentration1"), # type: ignore
53-
concentration0=attributes.get("_concentration0"), # type: ignore
52+
concentration1=attributes["_concentration1"],
53+
concentration0=attributes["_concentration0"],
5454
validate_args=attributes.get("_validate_args"),
5555
)
5656

@@ -62,9 +62,6 @@ def dist(self) -> Beta:
6262
validate_args=self._validate_args,
6363
)
6464

65-
def log_prob(self, value: Tensor) -> Tensor:
66-
return self.dist().log_prob(value)
67-
6865
@property
6966
def concentration1(self) -> Tensor:
7067
"""Returns the concentration1 parameter of the distribution."""
@@ -74,23 +71,3 @@ def concentration1(self) -> Tensor:
7471
def concentration0(self) -> Tensor:
7572
"""Returns the concentration0 parameter of the distribution."""
7673
return self.dist().concentration0
77-
78-
@property
79-
def mean(self) -> Tensor:
80-
"""Returns the mean of the distribution."""
81-
return self.dist().mean
82-
83-
@property
84-
def variance(self) -> Tensor:
85-
"""Returns the variance of the distribution."""
86-
return self.dist().variance
87-
88-
@property
89-
def mode(self) -> Tensor:
90-
"""Returns the mode of the distribution."""
91-
return self.dist().mode
92-
93-
@property
94-
def stddev(self) -> Tensor:
95-
"""Returns the standard deviation of the distribution."""
96-
return self.dist().stddev
Lines changed: 18 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from __future__ import annotations
22

3-
from typing import Any, Dict, Optional, Union
3+
from typing import Any
44

5-
import torch
65
from torch import Size, Tensor
76
from torch.distributions import Binomial
8-
from torch.distributions.utils import broadcast_all
7+
from .utils import broadcast_all
98

109
from .base import TensorDistribution
1110

@@ -24,16 +23,16 @@ class TensorBinomial(TensorDistribution):
2423
"""
2524

2625
# Annotated tensor parameters
27-
_total_count: Union[int, Tensor]
28-
_probs: Optional[Tensor] = None
29-
_logits: Optional[Tensor] = None
26+
_total_count: Tensor
27+
_probs: Tensor | None = None
28+
_logits: Tensor | None = None
3029

3130
def __init__(
3231
self,
33-
total_count: Union[int, Tensor] = 1,
34-
probs: Optional[Tensor] = None,
35-
logits: Optional[Tensor] = None,
36-
validate_args: Optional[bool] = None,
32+
total_count: int | Tensor = 1,
33+
probs: Tensor | None = None,
34+
logits: Tensor | None = None,
35+
validate_args: bool | None = None,
3736
):
3837
if (probs is None) == (logits is None):
3938
raise ValueError(
@@ -42,36 +41,19 @@ def __init__(
4241

4342
if probs is not None:
4443
self._total_count, self._probs = broadcast_all(total_count, probs)
45-
param = self._probs
46-
assert param is not None
4744
else:
4845
self._total_count, self._logits = broadcast_all(total_count, logits)
49-
param = self._logits
50-
assert param is not None
5146

52-
# Ensure total_count has the same dtype as the parameter tensor if it's a Tensor
53-
if isinstance(self._total_count, Tensor):
54-
self._total_count = self._total_count.type_as(param)
55-
56-
shape = param.shape
57-
device = param.device
47+
shape = self._total_count.shape
48+
device = self._total_count.device
5849

5950
super().__init__(shape, device, validate_args)
6051

6152
@classmethod
62-
def _unflatten_distribution(cls, attributes: Dict[str, Any]):
63-
"""Reconstruct distribution from tensor attributes."""
53+
def _unflatten_distribution(cls, attributes: dict[str, Any]):
6454
total_count = attributes["_total_count"]
65-
if isinstance(total_count, Tensor):
66-
total_count = total_count.clone()
67-
68-
probs = attributes.get("_probs")
69-
if probs is not None:
70-
probs = probs.clone()
71-
72-
logits = attributes.get("_logits")
73-
if logits is not None:
74-
logits = logits.clone()
55+
probs = attributes["_probs"]
56+
logits = attributes["_logits"]
7557

7658
return cls(
7759
total_count=total_count,
@@ -81,63 +63,29 @@ def _unflatten_distribution(cls, attributes: Dict[str, Any]):
8163
)
8264

8365
def dist(self) -> Binomial:
84-
total_count = self._total_count
85-
if isinstance(total_count, int):
86-
# Convert int total_count to a tensor with the correct device and dtype
87-
# The device and dtype should match the probs/logits tensor
88-
if self._probs is not None:
89-
total_count = torch.tensor(
90-
total_count, device=self._probs.device, dtype=self._probs.dtype
91-
)
92-
elif self._logits is not None:
93-
total_count = torch.tensor(
94-
total_count, device=self._logits.device, dtype=self._logits.dtype
95-
)
96-
else:
97-
# Fallback if neither probs nor logits are set (should not happen with current init logic)
98-
total_count = torch.tensor(total_count)
99-
10066
return Binomial(
101-
total_count=total_count,
67+
total_count=self._total_count,
10268
probs=self._probs,
10369
logits=self._logits,
10470
validate_args=self._validate_args,
10571
)
10672

107-
def log_prob(self, value: Tensor) -> Tensor:
108-
return self.dist().log_prob(value)
109-
11073
@property
111-
def total_count(self) -> Union[int, Tensor]:
74+
def total_count(self) -> Tensor:
11275
"""Returns the total_count parameter of the distribution."""
11376
return self._total_count
11477

11578
@property
116-
def probs(self) -> Optional[Tensor]:
79+
def probs(self) -> Tensor | None:
11780
"""Returns the probs parameter of the distribution."""
11881
return self.dist().probs
11982

12083
@property
121-
def logits(self) -> Optional[Tensor]:
84+
def logits(self) -> Tensor | None:
12285
"""Returns the logits parameter of the distribution."""
12386
return self.dist().logits
12487

12588
@property
12689
def param_shape(self) -> Size:
12790
"""Returns the shape of the underlying parameter."""
12891
return self.dist().param_shape
129-
130-
@property
131-
def mean(self) -> Tensor:
132-
"""Returns the mean of the Binomial distribution."""
133-
return self.dist().mean
134-
135-
@property
136-
def variance(self) -> Tensor:
137-
"""Returns the variance of the Binomial distribution."""
138-
return self.dist().variance
139-
140-
@property
141-
def mode(self) -> Tensor:
142-
"""Returns the mode of the Binomial distribution."""
143-
return self.dist().mode

0 commit comments

Comments
 (0)