
Commit e1cec97

Merge pull request #5 from mctigger/tanh_normal-properties
This pull request fixes a critical bug in `TensorIndependent`'s shape calculation and significantly improves the `TensorTanhNormal` and `SamplingDistribution` classes. Here's a breakdown of the key changes:

### Key Fixes and Improvements

* **`TensorIndependent` Shape Bug:** A bug where `reinterpreted_batch_ndims=0` resulted in an incorrect shape has been fixed. The shape calculation now correctly returns the full shape of the base distribution in this case.
* **`TensorTanhNormal` Simplification:** The `reinterpreted_batch_ndims` parameter has been removed from `TensorTanhNormal` to enforce a clearer separation of concerns. To achieve this functionality, users should now wrap `TensorTanhNormal` with `TensorIndependent`. The class also gains new statistical properties: **mean**, **variance**, and **standard deviation**.
* **`SamplingDistribution` Enhancements:** This class has been rewritten for better performance and reliability. It now uses `__slots__` for memory efficiency, caches samples and statistical properties, improves error handling, and adds input validation.

### Impact and Migration

This update introduces a **breaking change**: the `reinterpreted_batch_ndims` parameter is no longer available in `TensorTanhNormal`. To migrate, you must now explicitly wrap `TensorTanhNormal` with `TensorIndependent` to reinterpret batch dimensions, as shown in the sketch below. The pull request includes a migration guide with an example of the new usage.
2 parents 80763ca + da7252c commit e1cec97
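A minimal migration sketch. The import paths are assumptions inferred from the file layout in this diff (`src/tensorcontainer/tensor_distribution/...`), and `action_loc`/`action_scale` are illustrative tensors, not names from the PR:

```python
import torch

# Assumed import paths, inferred from src/tensorcontainer/tensor_distribution/;
# adjust if the package re-exports these classes elsewhere.
from tensorcontainer.tensor_distribution.independent import TensorIndependent
from tensorcontainer.tensor_distribution.tanh_normal import TensorTanhNormal

action_loc = torch.zeros(4, 3)   # batch of 4, event size 3 (illustrative)
action_scale = torch.ones(4, 3)

# Before (0.6.1): reinterpreted_batch_ndims was an argument of TensorTanhNormal.
# dist = TensorTanhNormal(action_loc, action_scale, reinterpreted_batch_ndims=1)

# After (0.6.2): wrap with TensorIndependent to reinterpret the trailing batch
# dimension as an event dimension.
dist = TensorIndependent(
    TensorTanhNormal(action_loc, action_scale),
    reinterpreted_batch_ndims=1,
)
```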

File tree

7 files changed (+458 / -74 lines changed)

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "tensorcontainer"
-version = "0.6.1"
+version = "0.6.2"
 description = "TensorDict-like functionality for PyTorch with PyTree compatibility and torch.compile support"
 authors = [{name="Tim Joseph", email="tim@mctigger.com"}]
 license = {text = "MIT"}
```
src/tensorcontainer/distributions/sampling.py

Lines changed: 112 additions & 27 deletions

```diff
@@ -1,50 +1,135 @@
 import torch
-from torch.distributions import Distribution
+from torch.distributions import Distribution, constraints
+from functools import cached_property
+from typing import Any, Optional
 
 
 class SamplingDistribution(Distribution):
-    def __init__(self, base_distribution: Distribution, n=100):
+    """
+    A wrapper for a PyTorch distribution that calculates statistics via sampling.
+
+    This distribution is useful when the analytical statistics of a base
+    distribution are not available or not desired. Instead, it computes
+    properties like mean, stddev, variance, and mode by drawing samples from the
+    base distribution.
+
+    To improve efficiency, it caches the generated samples and the computed
+    statistics, ensuring that repeated access to these properties does not
+    trigger redundant computations.
+
+    Args:
+        base_distribution (Distribution): The underlying distribution to sample from.
+        n (int, optional): The number of samples to draw for calculating
+            statistics. Defaults to 100.
+    """
+
+    __slots__ = ["base_dist", "n"]
+
+    def __init__(self, base_distribution: Distribution, *, n: int = 100):
+        if not isinstance(base_distribution, Distribution):
+            raise TypeError(
+                "base_distribution must be a torch.distributions.Distribution"
+            )
+        if not isinstance(n, int) or n <= 0:
+            raise ValueError("n must be a positive integer")
+
         self.base_dist = base_distribution
         self.n = n
 
-    def __getattr__(self, name):
+        super().__init__(
+            batch_shape=self.base_dist.batch_shape,
+            event_shape=self.base_dist.event_shape,
+            validate_args=False,  # We defer validation to the base distribution
+        )
+
+    def __repr__(self) -> str:
+        return f"SamplingDistribution(base_dist={self.base_dist}, n={self.n})"
+
+    def __getattr__(self, name: str) -> Any:
+        """Delegates attribute access to the base distribution."""
         return getattr(self.base_dist, name)
 
+    @cached_property
+    def _samples(self) -> torch.Tensor:
+        """
+        Cached samples from the base distribution.
+
+        Uses rsample if available for reparameterization-friendly gradients,
+        otherwise falls back to sample.
+        """
+        sample_shape = torch.Size((self.n,))
+        if self.base_dist.has_rsample:
+            return self.base_dist.rsample(sample_shape)
+        return self.base_dist.sample(sample_shape)
+
     @property
-    def has_rsample(self):
+    def has_rsample(self) -> bool:
         return self.base_dist.has_rsample
 
-    def rsample(self, sample_shape=torch.Size()):
+    def rsample(self, sample_shape: Any = torch.Size()) -> torch.Tensor:
+        """Delegates rsample to the base distribution."""
         return self.base_dist.rsample(sample_shape)
 
-    def sample(self, sample_shape=torch.Size()):
+    def sample(self, sample_shape: Any = torch.Size()) -> torch.Tensor:
+        """Delegates sample to the base distribution."""
         return self.base_dist.sample(sample_shape)
 
-    @property
-    def mean(self):
-        return self.base_dist.rsample((self.n,)).mean(0)
+    @cached_property
+    def mean(self) -> torch.Tensor:  # type: ignore
+        """Mean of the distribution, computed as the mean of cached samples."""
+        return self._samples.float().mean(0)
 
-    @property
-    def stddev(self):
-        return self.base_dist.rsample((self.n,)).std(0)
+    @cached_property
+    def stddev(self) -> torch.Tensor:  # type: ignore
+        """Standard deviation of the distribution, computed from cached samples."""
+        return self._samples.float().std(0)
 
-    @property
-    def variance(self):
-        return self.base_dist.rsample((self.n,)).var(0)
+    @cached_property
+    def variance(self) -> torch.Tensor:  # type: ignore
+        """Variance of the distribution, computed from cached samples."""
+        return self._samples.float().var(0)
 
-    @property
-    def mode(self):
-        samples = self.base_dist.sample((self.n,))
-        log_probs = self.base_dist.log_prob(samples).view(self.n, -1)
-        index = torch.argmax(log_probs, dim=0)
+    @cached_property
+    def mode(self) -> torch.Tensor:  # type: ignore
+        """
+        Mode of the distribution.
+
+        Tries to return the analytical mode if available. Otherwise, it computes
+        the mode via Monte Carlo approximation by finding the sample with the
+        highest log probability.
+        """
+        try:
+            return self.base_dist.mode
+        except (AttributeError, NotImplementedError):
+            pass  # Fall back to sampling
+
+        log_probs = self.base_dist.log_prob(self._samples)
+        max_indices = torch.argmax(log_probs, dim=0)
 
-        selected = torch.gather(samples.view(self.n, -1), 0, index.unsqueeze(0))
-        return selected
+        # Use advanced indexing to gather the modes efficiently
+        return self._samples.gather(
+            0, max_indices.reshape(1, *max_indices.shape, *(1,) * len(self.event_shape))
+        ).squeeze(0)
 
-    def entropy(self):
-        samples = self.base_dist.rsample((self.n,))
-        logprob = self.base_dist.log_prob(samples)
-        return -logprob.mean(0)
+    def entropy(self) -> torch.Tensor:
+        """
+        Entropy of the distribution, estimated via Monte Carlo.
 
-    def log_prob(self, value):
+        Calculates the negative mean of the log probabilities of cached samples.
+        """
+        log_prob = self.base_dist.log_prob(self._samples)
+        return -log_prob.mean(0)
+
+    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+        """Delegates log probability calculation to the base distribution."""
         return self.base_dist.log_prob(value)
+
+    @property
+    def support(self) -> Optional[constraints.Constraint]:
+        """Delegates support to the base distribution."""
+        return self.base_dist.support
+
+    @property
+    def arg_constraints(self) -> dict:
+        """Delegates argument constraints to the base distribution."""
+        return self.base_dist.arg_constraints
```
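A short usage sketch of the rewritten class. The import path is an assumption inferred from the relative import (`..distributions.sampling`) in tanh_normal.py below:

```python
import torch
from torch.distributions import Normal, TanhTransform, TransformedDistribution

# Assumed import path, inferred from the relative import in tanh_normal.py.
from tensorcontainer.distributions.sampling import SamplingDistribution

# A tanh-transformed Normal has no closed-form mean/stddev, so statistics
# must be estimated from samples.
base = TransformedDistribution(Normal(torch.zeros(3), torch.ones(3)), [TanhTransform()])
dist = SamplingDistribution(base, n=1000)  # n is keyword-only after this change

mean = dist.mean          # draws and caches 1,000 samples on first access
stddev = dist.stddev      # reuses the cached samples; no resampling
entropy = dist.entropy()  # Monte Carlo estimate: -log_prob(samples).mean(0)
```

One design consequence worth noting: because `mean`, `stddev`, `variance`, and `mode` are now `cached_property`s over one shared `_samples` tensor, the reported statistics are mutually consistent, which the old per-property `rsample` calls did not guarantee.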

src/tensorcontainer/tensor_distribution/independent.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -19,7 +19,11 @@ def __init__(
         self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
 
         super().__init__(
-            Size(base_distribution.shape[:-reinterpreted_batch_ndims]),
+            Size(
+                base_distribution.shape[:-reinterpreted_batch_ndims]
+                if reinterpreted_batch_ndims > 0
+                else base_distribution.shape
+            ),
             base_distribution.device,
         )
```
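The root cause is a Python slicing quirk: `-0` is simply `0`, so `shape[:-0]` yields an empty shape instead of the whole shape. A minimal sketch of the fixed edge case:

```python
from torch import Size

shape = Size((4, 3))
print(shape[:-1])  # torch.Size([4]) -- correct for reinterpreted_batch_ndims=1
print(shape[:-0])  # torch.Size([]) -- the old, buggy result for ndims=0
# The fix branches on the zero case and returns the full shape instead.
```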

src/tensorcontainer/tensor_distribution/tanh_normal.py

Lines changed: 62 additions & 21 deletions

```diff
@@ -1,19 +1,20 @@
 from __future__ import annotations
 
-from typing import Any, Dict, Optional, get_args
+from functools import cached_property
+from typing import Any, Dict, get_args
 
 import torch
 from torch import Tensor
 from torch.distributions import (
     Distribution,
-    Independent,
     Normal,
     TransformedDistribution,
     constraints,
 )
 from torch.distributions.utils import broadcast_all
 from torch.types import Number
 
+from ..distributions.sampling import SamplingDistribution
 from .base import TensorDistribution
 
 
@@ -49,25 +50,31 @@ def log_abs_det_jacobian(self, x, y):
 
 
 class TensorTanhNormal(TensorDistribution):
+    """Tensor-aware TanhNormal distribution.
+
+    Creates a transformed Normal distribution where the output is passed through
+    a hyperbolic tangent (tanh) function, constraining values to the interval (-1, 1).
+
+    Args:
+        loc: Location parameter of the underlying normal distribution.
+        scale: Scale parameter of the underlying normal distribution. Must be positive.
+
+    Note:
+        This distribution is commonly used in reinforcement learning for bounded
+        continuous action spaces. Use TensorIndependent to reinterpret batch dimensions
+        as event dimensions if needed.
+    """
+
     _loc: Tensor
     _scale: Tensor
-    _reinterpreted_batch_ndims: int
 
     def __init__(
         self,
         loc: Tensor,
         scale: Tensor,
-        reinterpreted_batch_ndims: Optional[int] = None,
     ):
         self._loc, self._scale = broadcast_all(loc, scale)
 
-        if reinterpreted_batch_ndims is None:
-            self._reinterpreted_batch_ndims = 0
-            if self._loc.ndim > 0:
-                self._reinterpreted_batch_ndims = 1
-        else:
-            self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
-
         if isinstance(loc, get_args(Number)) and isinstance(scale, get_args(Number)):
             shape = tuple()
         else:
@@ -86,25 +93,59 @@ def _unflatten_distribution(
         return cls(
             loc=attributes.get("_loc"),  # type: ignore
             scale=attributes.get("_scale"),  # type: ignore
-            reinterpreted_batch_ndims=attributes.get("_reinterpreted_batch_ndims"),  # type: ignore
         )
 
     def dist(self) -> Distribution:
-        return Independent(
-            TransformedDistribution(
-                Normal(self._loc.float(), self._scale.float(), validate_args=False),
-                [
-                    ClampedTanhTransform(),
-                ],
-                validate_args=False,
-            ),
-            self._reinterpreted_batch_ndims,
+        return TransformedDistribution(
+            Normal(self._loc.float(), self._scale.float(), validate_args=False),
+            [
+                ClampedTanhTransform(),
+            ],
+            validate_args=False,
         )
 
+    def log_prob(self, value: Tensor) -> Tensor:
+        """Compute log probability of value under the distribution."""
+        return self.dist().log_prob(value)
+
     @property
     def loc(self) -> Tensor:
+        """Returns the location parameter of the underlying normal distribution."""
         return self._loc
 
     @property
     def scale(self) -> Tensor:
+        """Returns the scale parameter of the underlying normal distribution."""
         return self._scale
+
+    @cached_property
+    def _sampling_dist(self) -> SamplingDistribution:
+        """Cached sampling distribution for consistent property calculations."""
+        return SamplingDistribution(self.dist(), n=10000)
+
+    @property
+    def mean(self) -> Tensor:
+        """Returns the mean of the distribution.
+
+        Note: For transformed distributions, this is computed via sampling
+        since the analytical mean may not be available.
+        """
+        return self._sampling_dist.mean
+
+    @property
+    def variance(self) -> Tensor:
+        """Returns the variance of the distribution.
+
+        Note: For transformed distributions, this is computed via sampling
+        since the analytical variance may not be available.
+        """
+        return self._sampling_dist.variance
+
+    @property
+    def stddev(self) -> Tensor:
+        """Returns the standard deviation of the distribution.
+
+        Note: For transformed distributions, this is computed via sampling
+        since the analytical standard deviation may not be available.
+        """
+        return self._sampling_dist.stddev
```
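A sketch of the new property surface. The import path is assumed (see the migration sketch near the top of this page), and the printed values are Monte Carlo estimates over the n=10000 cached samples, so exact numbers will vary:

```python
import torch

# Assumed import path; see the migration sketch above.
from tensorcontainer.tensor_distribution.tanh_normal import TensorTanhNormal

dist = TensorTanhNormal(loc=torch.zeros(2), scale=torch.ones(2))

print(dist.mean)      # approximately 0 for a symmetric tanh(Normal(0, 1))
print(dist.stddev)    # sample estimate from the shared cached samples
print(dist.variance)  # likewise; repeated access does not resample
```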

tests/tensor_distribution/test_independent.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -42,6 +42,9 @@ def test_broadcasting_shapes(
         assert td_independent.batch_shape == expected_batch_shape
         assert td_independent.dist().batch_shape == expected_batch_shape
 
+        # Test that shape property matches expected_batch_shape (fixes bug with reinterpreted_batch_ndims=0)
+        assert td_independent.shape == expected_batch_shape
+
 
 class TestTensorIndependentTensorContainerIntegration:
     @pytest.mark.parametrize(
```
