You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This document outlines the design requirements and implementation patterns for [`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) and its subclasses in the [`src/tensorcontainer/tensor_distribution/`](src/tensorcontainer/tensor_distribution/) module.
3
+
This document outlines the design requirements and implementation patterns for [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) and its subclasses in the [`src/tensorcontainer/tensor_distribution/`](/src/tensorcontainer/tensor_distribution/) module.
4
4
5
5
## Objective
6
6
7
-
The [`tensor_distribution`](src/tensorcontainer/tensor_distribution/) module provides [`torch.distributions.Distribution`](https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution) functionality that enables direct application of tensor operations to probability distributions. The module maintains complete signature compatibility with [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) while extending [`TensorContainer`](src/tensorcontainer/tensor_container.py:1) functionality through [`TensorAnnotated`](src/tensorcontainer/tensor_annotated.py) inheritance.
7
+
The [`tensor_distribution`](/src/tensorcontainer/tensor_distribution/) module provides `torch.distributions.Distribution` functionality that enables direct application of tensor operations to probability distributions. The module maintains complete signature compatibility with `torch.distributions` while extending [`TensorContainer`](/src/tensorcontainer/tensor_container.py) functionality through [`TensorAnnotated`](/src/tensorcontainer/tensor_annotated.py) inheritance.
8
8
9
9
## Architecture Requirements
10
10
11
11
### Signature Compatibility
12
12
13
-
All classes in [`tensorcontainer.tensor_distribution`](src/tensorcontainer/tensor_distribution/) must maintain exact signature compatibility with their corresponding [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) classes. This compatibility is enforced through automated testing using [`tests/tensor_distribution/conftest.py::assert_init_signatures_match`](tests/tensor_distribution/conftest.py).
13
+
All classes in [`tensorcontainer.tensor_distribution`](/src/tensorcontainer/tensor_distribution/) must maintain exact signature compatibility with their corresponding `torch.distributions` classes. This compatibility is enforced through automated testing using [`tests/tensor_distribution/conftest.py::assert_init_signatures_match`](/tests/tensor_distribution/conftest.py).
14
14
15
-
**Implementation Requirement**: When [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) classes lack proper type annotations for [`__init__`](src/tensorcontainer/tensor_distribution/base.py:65) parameters, implementers **must** consult the class docstring to determine correct type hints.
15
+
**Implementation Requirement**: When `torch.distributions` classes lack proper type annotations for `__init__` parameters, implementers **must** consult the class docstring to determine correct type hints.
16
16
17
17
### Distribution Delegation Pattern
18
18
19
-
[`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) subclasses **must not** implement distribution-specific logic. Instead, each subclass **must** implement a [`dist()`](src/tensorcontainer/tensor_distribution/base.py:143) method that constructs and returns the equivalent [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) object using the instance's parameters.
19
+
[`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) subclasses **must not** implement distribution-specific logic. Instead, each subclass **must** implement a `dist()` method that constructs and returns the equivalent `torch.distributions` object using the instance's parameters.
20
20
21
-
**Implementation Requirement**: The [`dist()`](src/tensorcontainer/tensor_distribution/base.py:143) method **must** return the raw [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) instance, not a wrapped one (e.g., with [`Independent`](https://pytorch.org/docs/stable/distributions.html#torch.distributions.independent.Independent)).
21
+
**Implementation Requirement**: The `dist()` method **must** return the raw `torch.distributions` instance, not a wrapped one (e.g., with `Independent`).
22
22
23
-
**Design Principle**: [`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) serves as a parameter management wrapper around [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html), delegating all distribution operations to the underlying implementation via [`self.dist()`](src/tensorcontainer/tensor_distribution/base.py:143) calls.
23
+
**Design Principle**: [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) serves as a parameter management wrapper around `torch.distributions`, delegating all distribution operations to the underlying implementation via `self.dist()` calls.
24
24
25
25
### Parameter Broadcasting Requirements
26
26
27
-
Many [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) constructors accept parameters of type `Union[Number, Tensor]` or any specialization of `Number` (e.g. `float`). However, [`TensorContainer`](src/tensorcontainer/tensor_container.py) and [`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) can only process `Union[Tensor, TensorContainer]` objects and require all parameters to have compatible shapes for broadcasting.
27
+
Many `torch.distributions` constructors accept parameters of type `Union[Number, Tensor]` or any specialization of `Number` (e.g. `float`). However, [`TensorContainer`](/src/tensorcontainer/tensor_container.py) and [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) can only process `Union[Tensor, TensorContainer]` objects and require all parameters to have compatible shapes for broadcasting.
28
28
29
-
**Implementation Rule**: When the constructor signature contains `Union[Number, Tensor]` or any specialization of `Number` parameters, implementations **must** use [`torch.distributions.utils.broadcast_all`](https://pytorch.org/docs/stable/distributions.html#torch.distributions.utils.broadcast_all) to:
29
+
**Implementation Rule**: When the constructor signature contains `Union[Number, Tensor]` or any specialization of `Number` parameters, implementations **must** use `torch.distributions.utils.broadcast_all` to:
30
30
1. Convert scalar numbers to tensors
31
31
2. Broadcast all parameters to a common shape
32
32
33
-
This preprocessing ensures proper shape and device management within the [`TensorAnnotated`](src/tensorcontainer/tensor_annotated.py) framework.
33
+
This preprocessing ensures proper shape and device management within the [`TensorAnnotated`](/src/tensorcontainer/tensor_annotated.py) framework.
34
34
35
35
**Decision Criterion**: If the constructor signature does not contain `Union[Number, Tensor]` parameters, simpler parameter handling approaches should be preferred.
36
36
37
37
### Validation Strategy
38
38
39
-
[`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) accepts a [`validate_args`](src/tensorcontainer/tensor_distribution/base.py:69) parameter during initialization and stores it as the [`_validate_args`](src/tensorcontainer/tensor_distribution/base.py:63) attribute of the base class. Subclasses must pass this value to the underlying [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) object (if the constructor supports it).
39
+
[`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) accepts a `validate_args` parameter during initialization and stores it as the `_validate_args` attribute of the base class. Subclasses must pass this value to the underlying `torch.distributions` object (if the constructor supports it).
40
40
41
-
**Validation Policy**: Parameter validation for [`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) subclasses is generally unnecessary because the [`TensorDistribution.__init__`](src/tensorcontainer/tensor_distribution/base.py:65) method constructs the underlying distribution once via [`self.dist()`](src/tensorcontainer/tensor_distribution/base.py:86), triggering parameter validation in the [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) implementation.
41
+
**Validation Policy**: Parameter validation for [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) subclasses is generally unnecessary because the `TensorDistribution.__init__` method constructs the underlying distribution once via `self.dist()`, triggering parameter validation in the `torch.distributions` implementation.
42
42
43
43
**Exception Handling**: Implementations should only raise validation errors when required parameters needed for device and shape inference are missing or invalid.
44
44
45
45
### Property Implementation Pattern
46
46
47
-
Following the [`torch.distributions.Distribution`](https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution) pattern, basic distribution properties are provided through the [`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) base class via delegation to [`self.dist()`](src/tensorcontainer/tensor_distribution/base.py:143).
47
+
Following the `torch.distributions.Distribution` pattern, basic distribution properties are provided through the [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) base class via delegation to `self.dist()`.
48
48
49
-
**Specialization Rule**: Distribution-specific properties **must** be implemented only in the corresponding subclass, maintaining the same delegation pattern to the underlying [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) object.
49
+
**Specialization Rule**: Distribution-specific properties **must** be implemented only in the corresponding subclass, maintaining the same delegation pattern to the underlying `torch.distributions` object.
50
50
51
51
## Implementation Patterns
52
52
53
53
### Annotated Attribute Pattern
54
54
55
-
All tensor parameters must be declared as annotated class attributes to enable automatic transformation by [`TensorAnnotated`](src/tensorcontainer/tensor_annotated.py) operations (e.g., [`.to()`](src/tensorcontainer/tensor_annotated.py), [`.expand()`](src/tensorcontainer/tensor_annotated.py)).
55
+
All tensor parameters must be declared as annotated class attributes to enable automatic transformation by [`TensorAnnotated`](/src/tensorcontainer/tensor_annotated.py) operations (e.g., `.to()`, `.expand()`).
56
56
57
57
**Example Pattern**:
58
58
```python
@@ -73,13 +73,13 @@ Note: If parameters like `loc` and `scale` could be scalars in the constructor s
73
73
74
74
### Lazy Distribution Creation
75
75
76
-
The actual [`torch.distributions.Distribution`](https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution) instance is created on-demand through the [`dist()`](src/tensorcontainer/tensor_distribution/base.py:143) method. This lazy evaluation pattern enables efficient tensor operations without premature distribution instantiation.
76
+
The actual `torch.distributions.Distribution` instance is created on-demand through the `dist()` method. This lazy evaluation pattern enables efficient tensor operations without premature distribution instantiation.
77
77
78
78
### Reconstruction Pattern
79
79
80
-
The [`_unflatten_distribution()`](src/tensorcontainer/tensor_distribution/base.py:115) class method reconstructs distribution instances from serialized tensor and metadata attributes. This method is called by [`_init_from_reconstructed()`](src/tensorcontainer/tensor_distribution/base.py:90) during operations like [`.to()`](src/tensorcontainer/tensor_annotated.py) and [`.expand()`](src/tensorcontainer/tensor_annotated.py).
80
+
The `_unflatten_distribution()` class method reconstructs distribution instances from serialized tensor and metadata attributes. This method is called by `_init_from_reconstructed()` during operations like `.to()` and `.expand()`.
81
81
82
-
**Customization Requirement**: Subclasses with complex parameter relationships **must** override [`_unflatten_distribution()`](src/tensorcontainer/tensor_distribution/base.py:115) to implement appropriate reconstruction logic.
82
+
**Customization Requirement**: Subclasses with complex parameter relationships **must** override `_unflatten_distribution()` to implement appropriate reconstruction logic.
0 commit comments