Skip to content

Commit 58f8698

Browse files
author
Tim Joseph
committed
feat(tensor_distribution): add unflatten distribution method
Add _unflatten_distribution class method to support deserialization of TransformedDistribution instances from attribute dictionaries. This enables proper reconstruction of distribution objects from serialized state. Comprehensive test suite added to verify copy functionality works correctly, ensuring tensor sharing is preserved while creating independent distribution instances.
1 parent d61df07 commit 58f8698

File tree

2 files changed

+101
-8
lines changed

2 files changed

+101
-8
lines changed

src/tensorcontainer/tensor_distribution/transformed_distribution.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from typing import Any
34

45
from torch.distributions import Distribution
56
from torch.distributions import TransformedDistribution as TorchTransformedDistribution
@@ -32,6 +33,16 @@ def __init__(
3233
base_distribution.batch_shape, base_distribution.device, validate_args
3334
)
3435

36+
@classmethod
37+
def _unflatten_distribution(
38+
cls, attributes: dict[str, Any]
39+
) -> TransformedDistribution:
40+
return cls(
41+
base_distribution=attributes["base_distribution"],
42+
transforms=attributes["transforms"],
43+
validate_args=attributes.get("_validate_args"),
44+
)
45+
3546
def dist(self) -> Distribution:
3647
"""
3748
Returns the underlying torch.distributions.Distribution instance.

tests/tensor_distribution/test_transformed_distribution.py

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch.distributions import (
44
TransformedDistribution as TorchTransformedDistribution,
55
)
6-
from torch.distributions.transforms import AffineTransform, ExpTransform
6+
from torch.distributions.transforms import AffineTransform, ExpTransform, Transform
77

88
from tensorcontainer.tensor_distribution.normal import TensorNormal
99
from tensorcontainer.tensor_distribution.transformed_distribution import (
@@ -33,7 +33,7 @@ def test_broadcasting_shapes(self, loc_shape, scale_shape, expected_batch_shape)
3333
loc = torch.randn(loc_shape)
3434
scale = torch.rand(scale_shape).exp()
3535
base_dist = TensorNormal(loc=loc, scale=scale)
36-
transforms = [ExpTransform()]
36+
transforms: list[Transform] = [ExpTransform()]
3737
td = TransformedDistribution(base_distribution=base_dist, transforms=transforms)
3838
assert td.batch_shape == expected_batch_shape
3939
assert td.dist().batch_shape == expected_batch_shape
@@ -59,7 +59,7 @@ def test_compile_compatibility(self, param_shape):
5959
loc = torch.randn(*param_shape)
6060
scale = torch.rand(*param_shape).exp()
6161
base_dist = TensorNormal(loc=loc, scale=scale)
62-
transforms = [ExpTransform()]
62+
transforms: list[Transform] = [ExpTransform()]
6363
td = TransformedDistribution(base_distribution=base_dist, transforms=transforms)
6464

6565
sample = td.sample()
@@ -81,26 +81,108 @@ def test_sample_log_prob(self):
8181
loc = torch.randn(3, 5)
8282
scale = torch.rand(3, 5).exp()
8383
base_dist = TensorNormal(loc=loc, scale=scale)
84-
transforms = [ExpTransform()]
84+
transforms: list[Transform] = [ExpTransform()]
8585
td = TransformedDistribution(base_distribution=base_dist, transforms=transforms)
8686

8787
torch_td = TorchTransformedDistribution(base_dist.dist(), transforms)
8888

89-
sample_shape = (2, 1)
89+
sample_shape = torch.Size([2, 1])
9090
sample = td.sample(sample_shape)
91-
assert sample.shape == torch_td.sample(sample_shape).shape
91+
torch_sample = torch_td.sample(sample_shape)
92+
assert torch_sample is not None
93+
assert sample.shape == torch_sample.shape
9294
assert torch.allclose(td.log_prob(sample), torch_td.log_prob(sample))
9395

9496
rsample = td.rsample(sample_shape)
95-
assert rsample.shape == torch_td.rsample(sample_shape).shape
97+
torch_rsample = torch_td.rsample(sample_shape)
98+
assert torch_rsample is not None
99+
assert rsample.shape == torch_rsample.shape
96100
assert torch.allclose(td.log_prob(rsample), torch_td.log_prob(rsample))
97101

98102

103+
class TestTransformedDistributionCopy:
104+
@pytest.fixture
105+
def base_distribution(self):
106+
"""Create a base normal distribution for testing."""
107+
loc = torch.randn(3, 5)
108+
scale = torch.rand(3, 5).exp()
109+
return TensorNormal(loc=loc, scale=scale)
110+
111+
@pytest.fixture
112+
def transforms(self):
113+
"""Create a list of transforms for testing."""
114+
return [
115+
ExpTransform(),
116+
AffineTransform(loc=1.0, scale=2.0),
117+
]
118+
119+
@pytest.fixture
120+
def original_dist(self, base_distribution, transforms):
121+
"""Create an original transformed distribution for testing."""
122+
return TransformedDistribution(
123+
base_distribution=base_distribution, transforms=transforms
124+
)
125+
126+
def test_copy_creates_new_object(self, original_dist):
127+
"""Test that copy creates a new object of the correct type."""
128+
copied_dist = original_dist.copy()
129+
130+
# Check that the copy is a different object but same type
131+
assert copied_dist is not original_dist
132+
assert isinstance(copied_dist, TransformedDistribution)
133+
134+
def test_copy_base_distribution_handling(self, original_dist):
135+
"""Test that the base distribution is handled correctly in copy."""
136+
copied_dist = original_dist.copy()
137+
138+
# Check that the base distribution is a different object
139+
assert copied_dist.base_distribution is not original_dist.base_distribution
140+
assert isinstance(copied_dist.base_distribution, TensorNormal)
141+
142+
# Check that tensor parameters are the same objects (identity)
143+
original_base = original_dist.base_distribution
144+
copied_base = copied_dist.base_distribution
145+
146+
assert original_base._loc is copied_base._loc
147+
assert original_base._scale is copied_base._scale
148+
149+
def test_copy_transforms_handling(self, original_dist):
150+
"""Test that transforms are handled correctly in copy."""
151+
copied_dist = original_dist.copy()
152+
153+
# Check that the transforms are the same objects (they're not tensors)
154+
assert copied_dist.transforms is original_dist.transforms
155+
156+
def test_copy_sampling_consistency(self, original_dist):
157+
"""Test that copied distribution produces consistent sampling results."""
158+
copied_dist = original_dist.copy()
159+
sample_shape = torch.Size([2, 1])
160+
161+
# Check that samples have the same shape
162+
original_sample = original_dist.sample(sample_shape)
163+
copied_sample = copied_dist.sample(sample_shape)
164+
assert original_sample.shape == copied_sample.shape
165+
166+
# Check that log_prob values are consistent for the same sample
167+
torch.testing.assert_close(
168+
original_dist.log_prob(original_sample),
169+
copied_dist.log_prob(original_sample),
170+
)
171+
172+
def test_copy_property_consistency(self, original_dist):
173+
"""Test that copied distribution has the same properties."""
174+
copied_dist = original_dist.copy()
175+
176+
# Check that the distributions have the same properties
177+
assert original_dist.batch_shape == copied_dist.batch_shape
178+
assert original_dist.device == copied_dist.device
179+
180+
99181
class TestTransformedDistributionAPIMatch:
100182
def test_properties_match(self):
101183
loc = torch.randn(3, 5)
102184
scale = torch.rand(3, 5).exp()
103185
base_dist = TensorNormal(loc=loc, scale=scale)
104-
transforms = [ExpTransform()]
186+
transforms: list[Transform] = [ExpTransform()]
105187
td = TransformedDistribution(base_distribution=base_dist, transforms=transforms)
106188
assert_property_values_match(td)

0 commit comments

Comments
 (0)