Skip to content

Commit b503488

Browse files
author
Tim Joseph
committed
Fixed links
1 parent 34c3ddb commit b503488

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed
Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,58 @@
11
# TensorDistribution Development Guide
22

3-
This document outlines the design requirements and implementation patterns for [`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) and its subclasses in the [`src/tensorcontainer/tensor_distribution/`](src/tensorcontainer/tensor_distribution/) module.
3+
This document outlines the design requirements and implementation patterns for [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) and its subclasses in the [`src/tensorcontainer/tensor_distribution/`](/src/tensorcontainer/tensor_distribution/) module.
44

55
## Objective
66

7-
The [`tensor_distribution`](src/tensorcontainer/tensor_distribution/) module provides [`torch.distributions.Distribution`](https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution) functionality that enables direct application of tensor operations to probability distributions. The module maintains complete signature compatibility with [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) while extending [`TensorContainer`](src/tensorcontainer/tensor_container.py:1) functionality through [`TensorAnnotated`](src/tensorcontainer/tensor_annotated.py) inheritance.
7+
The [`tensor_distribution`](/src/tensorcontainer/tensor_distribution/) module provides `torch.distributions.Distribution` functionality that enables direct application of tensor operations to probability distributions. The module maintains complete signature compatibility with `torch.distributions` while extending [`TensorContainer`](/src/tensorcontainer/tensor_container.py) functionality through [`TensorAnnotated`](/src/tensorcontainer/tensor_annotated.py) inheritance.
88

99
## Architecture Requirements
1010

1111
### Signature Compatibility
1212

13-
All classes in [`tensorcontainer.tensor_distribution`](src/tensorcontainer/tensor_distribution/) must maintain exact signature compatibility with their corresponding [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) classes. This compatibility is enforced through automated testing using [`tests/tensor_distribution/conftest.py::assert_init_signatures_match`](tests/tensor_distribution/conftest.py).
13+
All classes in [`tensorcontainer.tensor_distribution`](/src/tensorcontainer/tensor_distribution/) must maintain exact signature compatibility with their corresponding `torch.distributions` classes. This compatibility is enforced through automated testing using [`tests/tensor_distribution/conftest.py::assert_init_signatures_match`](/tests/tensor_distribution/conftest.py).
1414

15-
**Implementation Requirement**: When [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) classes lack proper type annotations for [`__init__`](src/tensorcontainer/tensor_distribution/base.py:65) parameters, implementers **must** consult the class docstring to determine correct type hints.
15+
**Implementation Requirement**: When `torch.distributions` classes lack proper type annotations for `__init__` parameters, implementers **must** consult the class docstring to determine correct type hints.
1616

1717
### Distribution Delegation Pattern
1818

19-
[`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) subclasses **must not** implement distribution-specific logic. Instead, each subclass **must** implement a [`dist()`](src/tensorcontainer/tensor_distribution/base.py:143) method that constructs and returns the equivalent [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) object using the instance's parameters.
19+
[`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) subclasses **must not** implement distribution-specific logic. Instead, each subclass **must** implement a `dist()` method that constructs and returns the equivalent `torch.distributions` object using the instance's parameters.
2020

21-
**Implementation Requirement**: The [`dist()`](src/tensorcontainer/tensor_distribution/base.py:143) method **must** return the raw [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) instance, not a wrapped one (e.g., with [`Independent`](https://pytorch.org/docs/stable/distributions.html#torch.distributions.independent.Independent)).
21+
**Implementation Requirement**: The `dist()` method **must** return the raw `torch.distributions` instance, not a wrapped one (e.g., with `Independent`).
2222

23-
**Design Principle**: [`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) serves as a parameter management wrapper around [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html), delegating all distribution operations to the underlying implementation via [`self.dist()`](src/tensorcontainer/tensor_distribution/base.py:143) calls.
23+
**Design Principle**: [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) serves as a parameter management wrapper around `torch.distributions`, delegating all distribution operations to the underlying implementation via `self.dist()` calls.
2424

2525
### Parameter Broadcasting Requirements
2626

27-
Many [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) constructors accept parameters of type `Union[Number, Tensor]` or any specialization of `Number` (e.g. `float`). However, [`TensorContainer`](src/tensorcontainer/tensor_container.py) and [`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) can only process `Union[Tensor, TensorContainer]` objects and require all parameters to have compatible shapes for broadcasting.
27+
Many `torch.distributions` constructors accept parameters of type `Union[Number, Tensor]` or any specialization of `Number` (e.g. `float`). However, [`TensorContainer`](/src/tensorcontainer/tensor_container.py) and [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) can only process `Union[Tensor, TensorContainer]` objects and require all parameters to have compatible shapes for broadcasting.
2828

29-
**Implementation Rule**: When the constructor signature contains `Union[Number, Tensor]` or any specialization of `Number` parameters, implementations **must** use [`torch.distributions.utils.broadcast_all`](https://pytorch.org/docs/stable/distributions.html#torch.distributions.utils.broadcast_all) to:
29+
**Implementation Rule**: When the constructor signature contains `Union[Number, Tensor]` or any specialization of `Number` parameters, implementations **must** use `torch.distributions.utils.broadcast_all` to:
3030
1. Convert scalar numbers to tensors
3131
2. Broadcast all parameters to a common shape
3232

33-
This preprocessing ensures proper shape and device management within the [`TensorAnnotated`](src/tensorcontainer/tensor_annotated.py) framework.
33+
This preprocessing ensures proper shape and device management within the [`TensorAnnotated`](/src/tensorcontainer/tensor_annotated.py) framework.
3434

3535
**Decision Criterion**: If the constructor signature does not contain `Union[Number, Tensor]` parameters, simpler parameter handling approaches should be preferred.
3636

3737
### Validation Strategy
3838

39-
[`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) accepts a [`validate_args`](src/tensorcontainer/tensor_distribution/base.py:69) parameter during initialization and stores it as the [`_validate_args`](src/tensorcontainer/tensor_distribution/base.py:63) attribute of the base class. Subclasses must pass this value to the underlying [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) object (if the constructor supports it).
39+
[`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) accepts a `validate_args` parameter during initialization and stores it as the `_validate_args` attribute of the base class. Subclasses must pass this value to the underlying `torch.distributions` object (if the constructor supports it).
4040

41-
**Validation Policy**: Parameter validation for [`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) subclasses is generally unnecessary because the [`TensorDistribution.__init__`](src/tensorcontainer/tensor_distribution/base.py:65) method constructs the underlying distribution once via [`self.dist()`](src/tensorcontainer/tensor_distribution/base.py:86), triggering parameter validation in the [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) implementation.
41+
**Validation Policy**: Parameter validation for [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) subclasses is generally unnecessary because the `TensorDistribution.__init__` method constructs the underlying distribution once via `self.dist()`, triggering parameter validation in the `torch.distributions` implementation.
4242

4343
**Exception Handling**: Implementations should only raise validation errors when required parameters needed for device and shape inference are missing or invalid.
4444

4545
### Property Implementation Pattern
4646

47-
Following the [`torch.distributions.Distribution`](https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution) pattern, basic distribution properties are provided through the [`TensorDistribution`](src/tensorcontainer/tensor_distribution/base.py:14) base class via delegation to [`self.dist()`](src/tensorcontainer/tensor_distribution/base.py:143).
47+
Following the `torch.distributions.Distribution` pattern, basic distribution properties are provided through the [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) base class via delegation to `self.dist()`.
4848

49-
**Specialization Rule**: Distribution-specific properties **must** be implemented only in the corresponding subclass, maintaining the same delegation pattern to the underlying [`torch.distributions`](https://pytorch.org/docs/stable/distributions.html) object.
49+
**Specialization Rule**: Distribution-specific properties **must** be implemented only in the corresponding subclass, maintaining the same delegation pattern to the underlying `torch.distributions` object.
5050

5151
## Implementation Patterns
5252

5353
### Annotated Attribute Pattern
5454

55-
All tensor parameters must be declared as annotated class attributes to enable automatic transformation by [`TensorAnnotated`](src/tensorcontainer/tensor_annotated.py) operations (e.g., [`.to()`](src/tensorcontainer/tensor_annotated.py), [`.expand()`](src/tensorcontainer/tensor_annotated.py)).
55+
All tensor parameters must be declared as annotated class attributes to enable automatic transformation by [`TensorAnnotated`](/src/tensorcontainer/tensor_annotated.py) operations (e.g., `.to()`, `.expand()`).
5656

5757
**Example Pattern**:
5858
```python
@@ -73,13 +73,13 @@ Note: If parameters like `loc` and `scale` could be scalars in the constructor s
7373

7474
### Lazy Distribution Creation
7575

76-
The actual [`torch.distributions.Distribution`](https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution) instance is created on-demand through the [`dist()`](src/tensorcontainer/tensor_distribution/base.py:143) method. This lazy evaluation pattern enables efficient tensor operations without premature distribution instantiation.
76+
The actual `torch.distributions.Distribution` instance is created on-demand through the `dist()` method. This lazy evaluation pattern enables efficient tensor operations without premature distribution instantiation.
7777

7878
### Reconstruction Pattern
7979

80-
The [`_unflatten_distribution()`](src/tensorcontainer/tensor_distribution/base.py:115) class method reconstructs distribution instances from serialized tensor and metadata attributes. This method is called by [`_init_from_reconstructed()`](src/tensorcontainer/tensor_distribution/base.py:90) during operations like [`.to()`](src/tensorcontainer/tensor_annotated.py) and [`.expand()`](src/tensorcontainer/tensor_annotated.py).
80+
The `_unflatten_distribution()` class method reconstructs distribution instances from serialized tensor and metadata attributes. This method is called by `_init_from_reconstructed()` during operations like `.to()` and `.expand()`.
8181

82-
**Customization Requirement**: Subclasses with complex parameter relationships **must** override [`_unflatten_distribution()`](src/tensorcontainer/tensor_distribution/base.py:115) to implement appropriate reconstruction logic.
82+
**Customization Requirement**: Subclasses with complex parameter relationships **must** override `_unflatten_distribution()` to implement appropriate reconstruction logic.
8383

8484
**Example Implementation**:
8585
```python
@@ -91,4 +91,3 @@ def _unflatten_distribution(cls, attributes: Dict[str, Any]):
9191
logits=attributes.get("_logits"),
9292
validate_args=attributes.get("_validate_args"),
9393
)
94-
```

0 commit comments

Comments
 (0)