
Commit 5a9fac3

Merge pull request #9 from mctigger/distribution-validate_args
2 parents bda0102 + 6cabfed commit 5a9fac3

File tree

9 files changed (+136, -301 lines)

Lines changed: 93 additions & 0 deletions
# TensorDistribution Development Guide

This document outlines the design requirements and implementation patterns for [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) and its subclasses in the [`src/tensorcontainer/tensor_distribution/`](/src/tensorcontainer/tensor_distribution/) module.

## Objective

The [`tensor_distribution`](/src/tensorcontainer/tensor_distribution/) module wraps `torch.distributions.Distribution` so that tensor operations can be applied directly to probability distributions. The module maintains complete signature compatibility with `torch.distributions` while extending [`TensorContainer`](/src/tensorcontainer/tensor_container.py) functionality through [`TensorAnnotated`](/src/tensorcontainer/tensor_annotated.py) inheritance.
## Architecture Requirements

### Signature Compatibility

All classes in [`tensorcontainer.tensor_distribution`](/src/tensorcontainer/tensor_distribution/) must maintain exact signature compatibility with their corresponding `torch.distributions` classes. This compatibility is enforced through automated testing using [`tests/tensor_distribution/conftest.py::assert_init_signatures_match`](/tests/tensor_distribution/conftest.py).

**Implementation Requirement**: When `torch.distributions` classes lack proper type annotations for `__init__` parameters, implementers **must** consult the class docstring to determine the correct type hints.
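For orientation, here is a minimal sketch of the kind of check such a helper performs, comparing `__init__` parameter names with `inspect.signature`. The import path for `TensorNormal` is hypothetical, and the real helper in `conftest.py` may compare more (defaults, annotations):

```python
import inspect

from torch.distributions import Normal

# Hypothetical import path; adjust to the actual module layout.
from tensorcontainer.tensor_distribution.normal import TensorNormal


def assert_init_signatures_match(tensor_cls, torch_cls) -> None:
    """Assert that both __init__ methods expose the same parameter names."""
    tensor_params = list(inspect.signature(tensor_cls.__init__).parameters)
    torch_params = list(inspect.signature(torch_cls.__init__).parameters)
    assert tensor_params == torch_params, (
        f"{tensor_cls.__name__}{tensor_params} != "
        f"{torch_cls.__name__}{torch_params}"
    )


assert_init_signatures_match(TensorNormal, Normal)
```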
### Distribution Delegation Pattern

[`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) subclasses **must not** implement distribution-specific logic. Instead, each subclass **must** implement a `dist()` method that constructs and returns the equivalent `torch.distributions` object using the instance's parameters.

**Implementation Requirement**: The `dist()` method **must** return the raw `torch.distributions` instance, not a wrapped one (e.g., with `Independent`).

**Design Principle**: [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) serves as a parameter management wrapper around `torch.distributions`, delegating all distribution operations to the underlying implementation via `self.dist()` calls.
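A minimal sketch of the raw-return rule; `self._loc`, `self._scale`, and `self._validate_args` follow the annotated attribute pattern shown later in this guide:

```python
from torch.distributions import Distribution, Independent, Normal

from .base import TensorDistribution


class TensorNormal(TensorDistribution):
    # Correct: dist() returns the raw torch.distributions instance.
    def dist(self) -> Distribution:
        return Normal(self._loc, self._scale, validate_args=self._validate_args)

    # Incorrect: wrapping the instance, e.g. in Independent, violates the rule.
    # def dist(self) -> Distribution:
    #     return Independent(
    #         Normal(self._loc, self._scale, validate_args=self._validate_args), 1
    #     )
```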
### Parameter Broadcasting Requirements

Many `torch.distributions` constructors accept parameters of type `Union[Number, Tensor]` or a specialization of `Number` (e.g., `float`). However, [`TensorContainer`](/src/tensorcontainer/tensor_container.py) and [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) can only process `Union[Tensor, TensorContainer]` objects and require all parameters to have compatible shapes for broadcasting.

**Implementation Rule**: When the constructor signature contains `Union[Number, Tensor]` or any specialization of `Number` parameters, implementations **must** use `torch.distributions.utils.broadcast_all` to:

1. Convert scalar numbers to tensors
2. Broadcast all parameters to a common shape

This preprocessing ensures proper shape and device management within the [`TensorAnnotated`](/src/tensorcontainer/tensor_annotated.py) framework.

**Decision Criterion**: If the constructor signature does not contain `Union[Number, Tensor]` parameters, prefer simpler parameter handling.
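A quick illustration of what `broadcast_all` does (a standalone sketch, not part of any subclass):

```python
import torch
from torch.distributions.utils import broadcast_all

# A Python scalar and a tensor: the scalar is converted to a tensor and
# both arguments are broadcast to a common shape on the tensor's device.
loc, scale = broadcast_all(0.0, torch.ones(2, 3))
print(loc.shape, scale.shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```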
### Validation Strategy

[`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) accepts a `validate_args` parameter during initialization and stores it as the `_validate_args` attribute of the base class. Subclasses must pass this value on to the underlying `torch.distributions` object (if its constructor supports it).

**Validation Policy**: Explicit parameter validation in [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) subclasses is generally unnecessary: `TensorDistribution.__init__` constructs the underlying distribution once via `self.dist()`, which triggers the parameter validation built into the `torch.distributions` implementation.

**Exception Handling**: Implementations should raise validation errors only when parameters required for device and shape inference are missing or invalid.
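A sketch of the validation a subclass inherits by forwarding `validate_args`: the check below is what fires when `dist()` constructs the underlying distribution with invalid parameters.

```python
import torch
from torch.distributions import Normal

loc = torch.zeros(3)
scale = torch.tensor([1.0, -1.0, 1.0])  # invalid: scale must be positive

# The same check runs inside a TensorDistribution subclass the moment
# its dist() constructs the underlying Normal with validate_args=True.
try:
    Normal(loc, scale, validate_args=True)
except ValueError as err:
    print(f"caught: {err}")
```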
### Property Implementation Pattern

Following the `torch.distributions.Distribution` pattern, basic distribution properties are provided through the [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) base class via delegation to `self.dist()`.

**Specialization Rule**: Distribution-specific properties **must** be implemented only in the corresponding subclass, maintaining the same delegation pattern to the underlying `torch.distributions` object.
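A minimal sketch of the two levels of delegation; the class and property names are illustrative, and the actual base-class property set lives in `base.py`:

```python
from torch import Tensor
from torch.distributions import Normal


class TensorDistributionSketch:
    """Illustrative base class: generic properties delegate to dist()."""

    def dist(self):
        raise NotImplementedError

    @property
    def mean(self) -> Tensor:
        return self.dist().mean


class TensorNormalSketch(TensorDistributionSketch):
    """Distribution-specific properties live only in the subclass."""

    def __init__(self, loc: Tensor, scale: Tensor):
        self._loc, self._scale = loc, scale

    def dist(self) -> Normal:
        return Normal(self._loc, self._scale)

    @property
    def scale(self) -> Tensor:  # Normal-specific parameter
        return self.dist().scale
```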
## Implementation Patterns

### Annotated Attribute Pattern

All tensor parameters must be declared as annotated class attributes to enable automatic transformation by [`TensorAnnotated`](/src/tensorcontainer/tensor_annotated.py) operations (e.g., `.to()`, `.expand()`).

**Example Pattern**:

```python
from typing import Optional

from torch import Tensor
from torch.distributions import Distribution, Normal

from .base import TensorDistribution


class TensorNormal(TensorDistribution):
    _loc: Tensor
    _scale: Tensor

    def __init__(self, loc: Tensor, scale: Tensor, validate_args: Optional[bool] = None):
        self._loc = loc
        self._scale = scale
        super().__init__(loc.shape, loc.device, validate_args)

    def dist(self) -> Distribution:
        return Normal(self._loc, self._scale, validate_args=self._validate_args)
```

Note: If parameters like `loc` and `scale` could be scalars in the constructor signature, apply the broadcasting rules described in the "Parameter Broadcasting Requirements" section before assignment to ensure proper tensor handling.
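Hypothetical usage of the example class; the guide names `.to()` and `.expand()` as the relevant operations, while the exact call conventions come from `TensorContainer`/`TensorAnnotated`:

```python
import torch

td = TensorNormal(loc=torch.zeros(4), scale=torch.ones(4))

# The annotated tensor attributes (_loc, _scale) are transformed in lockstep.
moved = td.to(torch.float64)   # assuming Tensor.to-style arguments
expanded = td.expand(2, 4)     # assuming Tensor.expand-style arguments
```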
### Lazy Distribution Creation

The actual `torch.distributions.Distribution` instance is created on demand through the `dist()` method. This lazy evaluation pattern enables efficient tensor operations without premature distribution instantiation.
### Reconstruction Pattern

The `_unflatten_distribution()` class method reconstructs distribution instances from serialized tensor and metadata attributes. This method is called by `_init_from_reconstructed()` during operations like `.to()` and `.expand()`.

**Customization Requirement**: Subclasses with complex parameter relationships **must** override `_unflatten_distribution()` to implement appropriate reconstruction logic.

**Example Implementation**:

```python
# Method of a TensorCategorical-style subclass; requires
# `from typing import Any, Dict` at module level.
@classmethod
def _unflatten_distribution(cls, attributes: Dict[str, Any]):
    """For TensorCategorical, extract _probs and _logits from attributes."""
    return cls(
        probs=attributes.get("_probs"),
        logits=attributes.get("_logits"),
        validate_args=attributes.get("_validate_args"),
    )
```

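A sketch of the round trip this enables, assuming a `TensorCategorical` subclass with `_probs`/`_logits` annotated attributes:

```python
import torch

tc = TensorCategorical(logits=torch.randn(8, 5))

# .to() moves/casts the annotated tensors, then rebuilds the instance via
# _init_from_reconstructed() -> _unflatten_distribution().
tc64 = tc.to(torch.float64)
assert tc64.dist().logits.dtype == torch.float64
```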
refactored.md

Lines changed: 0 additions & 175 deletions
This file was deleted.

src/tensorcontainer/tensor_distribution/beta.py

Lines changed: 2 additions & 9 deletions
```diff
@@ -1,11 +1,10 @@
 from __future__ import annotations
 
-from typing import Any, Dict, Optional, get_args
+from typing import Any, Dict, Optional
 
 from torch import Tensor
 from torch.distributions import Beta
 from torch.distributions.utils import broadcast_all
-from torch.types import Number
 
 from .base import TensorDistribution
 
@@ -37,13 +36,7 @@ def __init__(
         self._concentration1, self._concentration0 = broadcast_all(
             concentration1, concentration0
         )
-
-        if isinstance(concentration1, get_args(Number)) and isinstance(
-            concentration0, get_args(Number)
-        ):
-            shape = tuple()
-        else:
-            shape = self._concentration1.shape
+        shape = self._concentration1.shape
 
         device = self._concentration1.device
 
```
src/tensorcontainer/tensor_distribution/continuous_bernoulli.py

Lines changed: 12 additions & 24 deletions
```diff
@@ -1,8 +1,9 @@
-from typing import Optional, Tuple, Union, get_args
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
 from torch.distributions import ContinuousBernoulli as TorchContinuousBernoulli
+from torch.distributions.utils import broadcast_all
 from torch.types import Number
 
 from .base import TensorDistribution
@@ -21,34 +22,21 @@ def __init__(
         validate_args: Optional[bool] = None,
     ) -> None:
         self._lims = lims
-
-        if probs is not None and logits is not None:
+        if (probs is None) == (logits is None):
             raise ValueError(
                 "Either `probs` or `logits` must be specified, but not both."
             )
-        elif probs is None and logits is None:
-            raise ValueError("Either `probs` or `logits` must be specified.")
-
-        if probs is not None and isinstance(probs, get_args(Number)):
-            self._probs = torch.tensor(probs)
-        else:
-            self._probs = probs
 
-        if logits is not None and isinstance(logits, get_args(Number)):
-            self._logits = torch.tensor(logits)
+        if probs is not None:
+            (self._probs,) = broadcast_all(probs)
+            self._logits = None
         else:
-            self._logits = logits
+            (self._logits,) = broadcast_all(logits)
+            self._probs = None
 
-        if self._probs is not None:
-            batch_shape = self._probs.shape
-            device = self._probs.device
-        elif self._logits is not None:
-            batch_shape = self._logits.shape
-            device = self._logits.device
-        else:
-            # This case should ideally not be reached due to the checks above,
-            # but as a fallback for type inference or future changes.
-            raise ValueError("Either `probs` or `logits` must be specified.")
+        data = self._probs if self._probs is not None else self._logits
+        batch_shape = data.shape  # type: ignore
+        device = data.device  # type: ignore
 
         super().__init__(shape=batch_shape, device=device, validate_args=validate_args)
 
@@ -63,7 +51,7 @@ def dist(self) -> TorchContinuousBernoulli:
     @classmethod
     def _unflatten_distribution(
         cls,
-        attributes: dict,
+        attributes: Dict[str, Any],
     ) -> "TensorContinuousBernoulli":
         return cls(
             probs=attributes.get("_probs"),
```

0 commit comments
