|
| 1 | +# TensorDistribution Development Guide |
| 2 | + |
| 3 | +This document outlines the design requirements and implementation patterns for [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) and its subclasses in the [`src/tensorcontainer/tensor_distribution/`](/src/tensorcontainer/tensor_distribution/) module. |
| 4 | + |
| 5 | +## Objective |
| 6 | + |
| 7 | +The [`tensor_distribution`](/src/tensorcontainer/tensor_distribution/) module provides `torch.distributions.Distribution` functionality that enables direct application of tensor operations to probability distributions. The module maintains complete signature compatibility with `torch.distributions` while extending [`TensorContainer`](/src/tensorcontainer/tensor_container.py) functionality through [`TensorAnnotated`](/src/tensorcontainer/tensor_annotated.py) inheritance. |
| 8 | + |
| 9 | +## Architecture Requirements |
| 10 | + |
| 11 | +### Signature Compatibility |
| 12 | + |
| 13 | +All classes in [`tensorcontainer.tensor_distribution`](/src/tensorcontainer/tensor_distribution/) must maintain exact signature compatibility with their corresponding `torch.distributions` classes. This compatibility is enforced through automated testing using [`tests/tensor_distribution/conftest.py::assert_init_signatures_match`](/tests/tensor_distribution/conftest.py). |
| 14 | + |
| 15 | +**Implementation Requirement**: When `torch.distributions` classes lack proper type annotations for `__init__` parameters, implementers **must** consult the class docstring to determine correct type hints. |
| 16 | + |
| 17 | +### Distribution Delegation Pattern |
| 18 | + |
| 19 | +[`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) subclasses **must not** implement distribution-specific logic. Instead, each subclass **must** implement a `dist()` method that constructs and returns the equivalent `torch.distributions` object using the instance's parameters. |
| 20 | + |
| 21 | +**Implementation Requirement**: The `dist()` method **must** return the raw `torch.distributions` instance, not a wrapped one (e.g., with `Independent`). |
| 22 | + |
| 23 | +**Design Principle**: [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) serves as a parameter management wrapper around `torch.distributions`, delegating all distribution operations to the underlying implementation via `self.dist()` calls. |
| 24 | + |
| 25 | +### Parameter Broadcasting Requirements |
| 26 | + |
| 27 | +Many `torch.distributions` constructors accept parameters of type `Union[Number, Tensor]` or any specialization of `Number` (e.g. `float`). However, [`TensorContainer`](/src/tensorcontainer/tensor_container.py) and [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) can only process `Union[Tensor, TensorContainer]` objects and require all parameters to have compatible shapes for broadcasting. |
| 28 | + |
| 29 | +**Implementation Rule**: When the constructor signature contains `Union[Number, Tensor]` or any specialization of `Number` parameters, implementations **must** use `torch.distributions.utils.broadcast_all` to: |
| 30 | +1. Convert scalar numbers to tensors |
| 31 | +2. Broadcast all parameters to a common shape |
| 32 | + |
| 33 | +This preprocessing ensures proper shape and device management within the [`TensorAnnotated`](/src/tensorcontainer/tensor_annotated.py) framework. |
| 34 | + |
| 35 | +**Decision Criterion**: If the constructor signature does not contain `Union[Number, Tensor]` parameters, simpler parameter handling approaches should be preferred. |
| 36 | + |
| 37 | +### Validation Strategy |
| 38 | + |
| 39 | +[`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) accepts a `validate_args` parameter during initialization and stores it as the `_validate_args` attribute of the base class. Subclasses must pass this value to the underlying `torch.distributions` object (if the constructor supports it). |
| 40 | + |
| 41 | +**Validation Policy**: Parameter validation for [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) subclasses is generally unnecessary because the `TensorDistribution.__init__` method constructs the underlying distribution once via `self.dist()`, triggering parameter validation in the `torch.distributions` implementation. |
| 42 | + |
| 43 | +**Exception Handling**: Implementations should only raise validation errors when required parameters needed for device and shape inference are missing or invalid. |
| 44 | + |
| 45 | +### Property Implementation Pattern |
| 46 | + |
| 47 | +Following the `torch.distributions.Distribution` pattern, basic distribution properties are provided through the [`TensorDistribution`](/src/tensorcontainer/tensor_distribution/base.py) base class via delegation to `self.dist()`. |
| 48 | + |
| 49 | +**Specialization Rule**: Distribution-specific properties **must** be implemented only in the corresponding subclass, maintaining the same delegation pattern to the underlying `torch.distributions` object. |
| 50 | + |
| 51 | +## Implementation Patterns |
| 52 | + |
| 53 | +### Annotated Attribute Pattern |
| 54 | + |
| 55 | +All tensor parameters must be declared as annotated class attributes to enable automatic transformation by [`TensorAnnotated`](/src/tensorcontainer/tensor_annotated.py) operations (e.g., `.to()`, `.expand()`). |
| 56 | + |
| 57 | +**Example Pattern**: |
| 58 | +```python |
| 59 | +class TensorNormal(TensorDistribution): |
| 60 | + _loc: Tensor |
| 61 | + _scale: Tensor |
| 62 | + |
| 63 | + def __init__(self, loc: Tensor, scale: Tensor, validate_args: Optional[bool] = None): |
| 64 | + self._loc = loc |
| 65 | + self._scale = scale |
| 66 | + super().__init__(loc.shape, loc.device, validate_args) |
| 67 | + |
| 68 | + def dist(self) -> Distribution: |
| 69 | + return Normal(self._loc, self._scale, validate_args=self._validate_args) |
| 70 | +``` |
| 71 | + |
| 72 | +Note: If parameters like `loc` and `scale` could be scalars in the constructor signature, apply the broadcasting rules described in the "Parameter Broadcasting" section before assignment to ensure proper tensor handling. |
| 73 | + |
| 74 | +### Lazy Distribution Creation |
| 75 | + |
| 76 | +The actual `torch.distributions.Distribution` instance is created on-demand through the `dist()` method. This lazy evaluation pattern enables efficient tensor operations without premature distribution instantiation. |
| 77 | + |
| 78 | +### Reconstruction Pattern |
| 79 | + |
| 80 | +The `_unflatten_distribution()` class method reconstructs distribution instances from serialized tensor and metadata attributes. This method is called by `_init_from_reconstructed()` during operations like `.to()` and `.expand()`. |
| 81 | + |
| 82 | +**Customization Requirement**: Subclasses with complex parameter relationships **must** override `_unflatten_distribution()` to implement appropriate reconstruction logic. |
| 83 | + |
| 84 | +**Example Implementation**: |
| 85 | +```python |
| 86 | +@classmethod |
| 87 | +def _unflatten_distribution(cls, attributes: Dict[str, Any]): |
| 88 | + """For TensorCategorical, extract _probs and _logits from attributes.""" |
| 89 | + return cls( |
| 90 | + probs=attributes.get("_probs"), |
| 91 | + logits=attributes.get("_logits"), |
| 92 | + validate_args=attributes.get("_validate_args"), |
| 93 | + ) |
0 commit comments