
Commit 637c29f

Author: Tim Joseph (committed)
Merge branch 'main' into documentation
2 parents: 00ce70f + a8d546e

3 files changed: 377 additions, 0 deletions

src/tensorcontainer/tensor_distribution/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -35,6 +35,7 @@
 from .relaxed_one_hot_categorical import TensorRelaxedOneHotCategorical
 from .soft_bernoulli import TensorSoftBernoulli
 from .student_t import TensorStudentT
+from .symlog import TensorSymLog
 from .tanh_normal import TensorTanhNormal
 from .transformed_distribution import TransformedDistribution
 from .truncated_normal import TensorTruncatedNormal
@@ -81,6 +82,7 @@
     "TensorRelaxedOneHotCategorical",
     "TensorSoftBernoulli",
     "TensorStudentT",
+    "TensorSymLog",
     "TensorTanhNormal",
     "TransformedDistribution",
     "TensorTruncatedNormal",
src/tensorcontainer/tensor_distribution/symlog.py (new file)

Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
from __future__ import annotations

from typing import Any

from torch import Tensor
from torch.distributions import Distribution
from torch.types import Number

from .base import TensorDistribution
from .utils import broadcast_all


class TensorSymLog(TensorDistribution):
    """Tensor-aware SymLog distribution.

    Creates a SymLog distribution parameterized by `loc` (mean) and `scale` (standard deviation).
    This distribution transforms a Normal distribution with a symexp transform, which is useful
    for modeling data with a wide dynamic range where the data can be both positive and negative.

    Args:
        loc: Mean of the base Normal distribution.
        scale: Standard deviation of the base Normal distribution. Must be positive.
        validate_args: Whether to validate the arguments. Defaults to None.

    Note:
        The SymLog distribution is useful for modeling data with a wide dynamic range,
        where the data can be both positive and negative, and can have values close to zero.
        The symlog transform compresses large values and expands small values, making the
        distribution more stable for optimization.
    """

    # Annotated tensor parameters
    _loc: Tensor
    _scale: Tensor

    def __init__(
        self,
        loc: Number | Tensor,
        scale: Number | Tensor,
        validate_args: bool | None = None,
    ) -> None:
        self._loc, self._scale = broadcast_all(loc, scale)

        shape = self._loc.shape
        device = self._loc.device

        super().__init__(shape, device, validate_args)

    @classmethod
    def _unflatten_distribution(
        cls,
        attributes: dict[str, Any],
    ) -> TensorSymLog:
        return cls(
            loc=attributes["_loc"],
            scale=attributes["_scale"],
            validate_args=attributes.get("_validate_args"),
        )

    def dist(self) -> Distribution:
        """Return the underlying SymLogDistribution instance."""
        from tensorcontainer.distributions.symlog import SymLogDistribution

        return SymLogDistribution(
            loc=self._loc,
            scale=self._scale,
            validate_args=self._validate_args,
        )

    @property
    def loc(self) -> Tensor:
        """Returns the location parameter of the distribution."""
        return self._loc

    @property
    def scale(self) -> Tensor:
        """Returns the scale parameter of the distribution."""
        return self._scale
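The class delegates the distribution math to SymLogDistribution (imported lazily inside dist()), which lives in tensorcontainer.distributions.symlog and is not part of this diff. For context, a minimal sketch of the symlog/symexp pair the docstring refers to, assuming the standard definitions; the helper names symlog and symexp are illustrative and not taken from this commit:

    import torch

    def symlog(x: torch.Tensor) -> torch.Tensor:
        # Compresses large magnitudes: symlog(x) = sign(x) * log(1 + |x|)
        return torch.sign(x) * torch.log1p(torch.abs(x))

    def symexp(x: torch.Tensor) -> torch.Tensor:
        # Inverse of symlog: symexp(x) = sign(x) * (exp(|x|) - 1)
        return torch.sign(x) * torch.expm1(torch.abs(x))

    # The two transforms invert each other across the whole real line.
    x = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
    assert torch.allclose(symexp(symlog(x)), x)

Under these definitions, a SymLog distribution is a Normal(loc, scale) pushed through symexp: samples can span several orders of magnitude in both signs while the base Normal stays well-conditioned.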
tests/tensor_distribution/test_symlog.py (new file)

Lines changed: 297 additions & 0 deletions

@@ -0,0 +1,297 @@
"""
Tests for TensorSymLog distribution.

This module contains test classes that verify:
- TensorSymLog initialization and parameter validation
- Core distribution operations (sample, rsample, log_prob)
- TensorContainer integration (view, reshape, device operations)
- Distribution-specific properties and edge cases
"""

import pytest
import torch

from tensorcontainer.distributions.symlog import SymLogDistribution
from tensorcontainer.tensor_distribution.symlog import TensorSymLog
from tests.compile_utils import run_and_compare_compiled
from tests.tensor_distribution.conftest import (
    assert_init_signatures_match,
    assert_properties_signatures_match,
    assert_property_values_match,
)


class TestTensorSymLogInitialization:
    @pytest.mark.parametrize(
        "loc_shape, scale_shape, expected_batch_shape",
        [
            ((), (), ()),
            ((5,), (), (5,)),
            ((), (5,), (5,)),
            ((3, 5), (5,), (3, 5)),
            ((5,), (3, 5), (3, 5)),
            ((2, 4, 5), (5,), (2, 4, 5)),
            ((5,), (2, 4, 5), (2, 4, 5)),
            ((2, 4, 5), (2, 4, 5), (2, 4, 5)),
        ],
    )
    def test_broadcasting_shapes(self, loc_shape, scale_shape, expected_batch_shape):
        """Test that batch_shape is correctly determined by broadcasting."""
        loc = torch.randn(loc_shape)
        scale = torch.rand(scale_shape).exp()  # scale must be positive
        td_symlog = TensorSymLog(loc=loc, scale=scale)
        assert td_symlog.batch_shape == expected_batch_shape
        assert td_symlog.dist().batch_shape == expected_batch_shape

    def test_initialization_with_scalars(self):
        """Test initialization with scalar parameters."""
        td_symlog = TensorSymLog(loc=0.0, scale=1.0)
        assert td_symlog.batch_shape == torch.Size(())
        assert td_symlog.loc.shape == torch.Size(())
        assert td_symlog.scale.shape == torch.Size(())

    def test_initialization_with_tensors(self):
        """Test initialization with tensor parameters."""
        loc = torch.tensor([1.0, -2.0, 0.0])
        scale = torch.tensor([0.5, 1.0, 2.0])
        td_symlog = TensorSymLog(loc=loc, scale=scale)
        assert td_symlog.batch_shape == torch.Size([3])
        assert torch.allclose(td_symlog.loc, loc)
        assert torch.allclose(td_symlog.scale, scale)


class TestTensorSymLogTensorContainerIntegration:
    @pytest.mark.parametrize("param_shape", [(5,), (3, 5), (2, 4, 5)])
    def test_compile_compatibility(self, param_shape):
        """Core operations should be compatible with torch.compile."""
        loc = torch.randn(*param_shape)
        scale = torch.rand(*param_shape).exp()  # scale must be positive
        td_symlog = TensorSymLog(loc=loc, scale=scale)

        sample = td_symlog.sample()

        def sample_fn(td):
            return td.sample()

        def rsample_fn(td):
            return td.rsample()

        def log_prob_fn(td, s):
            return td.log_prob(s)

        run_and_compare_compiled(sample_fn, td_symlog, fullgraph=False)
        run_and_compare_compiled(rsample_fn, td_symlog, fullgraph=False)
        run_and_compare_compiled(log_prob_fn, td_symlog, sample, fullgraph=False)

    def test_device_compatibility(self):
        """Test that the distribution works on different devices."""
        loc = torch.tensor([1.0, -2.0, 0.0])
        scale = torch.tensor([0.5, 1.0, 2.0])
        td_symlog = TensorSymLog(loc=loc, scale=scale)

        # Test CPU
        sample = td_symlog.sample()
        assert sample.device == loc.device

        # Test GPU if available
        if torch.cuda.is_available():
            loc_gpu = loc.cuda()
            scale_gpu = scale.cuda()
            td_symlog_gpu = TensorSymLog(loc=loc_gpu, scale=scale_gpu)
            sample_gpu = td_symlog_gpu.sample()
            assert sample_gpu.is_cuda


class TestTensorSymLogAPIMatch:
    """
    Tests that the TensorSymLog API matches the SymLogDistribution API.
    """

    def test_init_signatures_match(self):
        """Tests that the __init__ signature of TensorSymLog matches SymLogDistribution."""
        assert_init_signatures_match(TensorSymLog, SymLogDistribution)

    def test_properties_match(self):
        """Tests that the properties of TensorSymLog match SymLogDistribution."""
        assert_properties_signatures_match(TensorSymLog, SymLogDistribution)

    def test_property_values_match(self):
        """Tests that the property values of TensorSymLog match SymLogDistribution."""
        loc = torch.tensor([1.0, -2.0, 0.0])
        scale = torch.tensor([0.5, 1.0, 2.0])
        td_symlog = TensorSymLog(loc=loc, scale=scale)
        assert_property_values_match(td_symlog)

    def test_distribution_equivalence(self):
        """Tests that TensorSymLog produces the same results as SymLogDistribution."""
        loc = torch.tensor([1.0, -2.0, 0.0])
        scale = torch.tensor([0.5, 1.0, 2.0])

        # Create both distributions
        td_symlog = TensorSymLog(loc=loc, scale=scale)
        symlog_dist = SymLogDistribution(loc=loc, scale=scale)

        # Test sampling
        torch.manual_seed(42)
        td_sample = td_symlog.sample(torch.Size([100]))
        torch.manual_seed(42)
        symlog_sample = symlog_dist.sample(torch.Size([100]))
        assert torch.allclose(td_sample, symlog_sample, rtol=1e-5, atol=1e-5)

        # Test log_prob
        test_values = torch.tensor([0.0, 1.0, -1.0])
        td_log_prob = td_symlog.log_prob(test_values)
        symlog_log_prob = symlog_dist.log_prob(test_values)
        assert torch.allclose(td_log_prob, symlog_log_prob, rtol=1e-5, atol=1e-5)

        # Test mean and mode
        assert torch.allclose(td_symlog.mean, symlog_dist.mean, rtol=1e-5, atol=1e-5)
        assert torch.allclose(td_symlog.mode, symlog_dist.mode, rtol=1e-5, atol=1e-5)


class TestTensorSymLogFunctionality:
    """Test specific functionality of TensorSymLog."""

    @pytest.fixture
    def sample_params(self):
        """Common test parameters."""
        return {
            "loc": torch.tensor([1.0, -2.0, 0.0]),
            "scale": torch.tensor([0.5, 1.0, 2.0]),
        }

    @pytest.fixture
    def sample_distribution(self, sample_params):
        """Common test distribution."""
        return TensorSymLog(sample_params["loc"], sample_params["scale"])

    def test_sampling(self, sample_distribution):
        """Test sampling functionality."""
        # Test default sampling
        sample = sample_distribution.sample()
        assert sample.shape == sample_distribution.batch_shape
        assert torch.all(torch.isfinite(sample))

        # Test sampling with specific shape
        sample_shape = torch.Size([10, 2])
        samples = sample_distribution.sample(sample_shape)
        assert samples.shape == sample_shape + sample_distribution.batch_shape
        assert torch.all(torch.isfinite(samples))

    def test_reparameterized_sampling(self, sample_distribution):
        """Test reparameterized sampling functionality."""
        # Check if the distribution supports rsample
        assert sample_distribution.has_rsample == sample_distribution.dist().has_rsample

        # SymLogDistribution claims to support rsample but doesn't actually
        # provide gradients due to the non-affine transform
        assert sample_distribution.has_rsample

        # Test rsample - it should work but without gradients
        rsample = sample_distribution.rsample()
        assert rsample.shape == sample_distribution.batch_shape
        assert not rsample.requires_grad

        # Test rsample with specific shape
        sample_shape = torch.Size([10, 2])
        rsamples = sample_distribution.rsample(sample_shape)
        assert rsamples.shape == sample_shape + sample_distribution.batch_shape
        assert not rsamples.requires_grad

    def test_log_prob(self, sample_distribution):
        """Test log probability computation."""
        # Test log_prob at the mode (should have high probability density)
        mode_log_prob = sample_distribution.log_prob(sample_distribution.mode)
        assert mode_log_prob.shape == sample_distribution.batch_shape
        assert torch.all(torch.isfinite(mode_log_prob))

        # Test log_prob for arbitrary values
        test_values = torch.tensor([0.0, 1.0, -1.0])
        log_probs = sample_distribution.log_prob(test_values)
        assert log_probs.shape == test_values.shape
        assert torch.all(torch.isfinite(log_probs))

    def test_entropy(self, sample_distribution):
        """Test entropy computation."""
        # SymLogDistribution doesn't implement entropy
        with pytest.raises(NotImplementedError):
            sample_distribution.entropy()

    def test_mean_and_variance(self, sample_distribution):
        """Test mean and variance properties."""
        mean = sample_distribution.mean

        assert mean.shape == sample_distribution.batch_shape
        assert torch.all(torch.isfinite(mean))

        # SymLogDistribution doesn't implement variance
        with pytest.raises(NotImplementedError):
            sample_distribution.variance

    def test_mode_property(self, sample_distribution):
        """Test mode property."""
        mode = sample_distribution.mode
        assert mode.shape == sample_distribution.batch_shape
        assert torch.all(torch.isfinite(mode))

    def test_batch_and_event_shape(self, sample_distribution):
        """Test batch_shape and event_shape properties."""
        assert sample_distribution.batch_shape == sample_distribution.loc.shape
        assert sample_distribution.event_shape == torch.Size()  # scalar event

    def test_support_property(self, sample_distribution):
        """Test support property."""
        support = sample_distribution.support
        # SymLogDistribution has real support
        assert support is not None

    def test_cdf_and_icdf(self, sample_distribution):
        """Test CDF and inverse CDF functionality."""
        # Test CDF
        test_values = torch.tensor([0.0, 1.0, -1.0])
        cdf_values = sample_distribution.cdf(test_values)

        assert cdf_values.shape == test_values.shape
        assert torch.all((cdf_values >= 0) & (cdf_values <= 1))
        assert torch.all(torch.isfinite(cdf_values))

        # Test ICDF
        prob_values = torch.tensor([0.1, 0.5, 0.9])
        icdf_values = sample_distribution.icdf(prob_values)

        assert icdf_values.shape == prob_values.shape
        assert torch.all(torch.isfinite(icdf_values))

        # Test CDF/ICDF inverse relationship
        reconstructed_probs = sample_distribution.cdf(icdf_values)
        assert torch.allclose(reconstructed_probs, prob_values, atol=1e-5)

    def test_unflatten_distribution(self, sample_params):
        """Test _unflatten_distribution class method."""
        td_symlog = TensorSymLog(sample_params["loc"], sample_params["scale"])

        # Get attributes
        attributes = {
            "_loc": td_symlog._loc,
            "_scale": td_symlog._scale,
            "_validate_args": td_symlog._validate_args,
        }

        # Reconstruct distribution
        reconstructed = TensorSymLog._unflatten_distribution(attributes)

        # Check that the reconstructed distribution is equivalent
        assert torch.allclose(reconstructed.loc, td_symlog.loc)
        assert torch.allclose(reconstructed.scale, td_symlog.scale)
        assert reconstructed._validate_args == td_symlog._validate_args
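Taken together, the tests pin down the intended usage; a minimal sketch, assuming the import paths exercised above:

    import torch
    from tensorcontainer.tensor_distribution import TensorSymLog

    # loc and scale broadcast against each other (see test_broadcasting_shapes)
    dist = TensorSymLog(loc=torch.zeros(3), scale=torch.ones(3))

    samples = dist.sample(torch.Size([10]))  # shape (10, 3), always finite
    log_p = dist.log_prob(samples)           # shape (10, 3)
    mode = dist.mode                         # shape (3,)
    # entropy() and variance raise NotImplementedError (see tests above)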
