33from torch .distributions import (
44 TransformedDistribution as TorchTransformedDistribution ,
55)
6- from torch .distributions .transforms import AffineTransform , ExpTransform
6+ from torch .distributions .transforms import AffineTransform , ExpTransform , Transform
77
88from tensorcontainer .tensor_distribution .normal import TensorNormal
99from tensorcontainer .tensor_distribution .transformed_distribution import (
@@ -33,7 +33,7 @@ def test_broadcasting_shapes(self, loc_shape, scale_shape, expected_batch_shape)
3333 loc = torch .randn (loc_shape )
3434 scale = torch .rand (scale_shape ).exp ()
3535 base_dist = TensorNormal (loc = loc , scale = scale )
36- transforms = [ExpTransform ()]
36+ transforms : list [ Transform ] = [ExpTransform ()]
3737 td = TransformedDistribution (base_distribution = base_dist , transforms = transforms )
3838 assert td .batch_shape == expected_batch_shape
3939 assert td .dist ().batch_shape == expected_batch_shape
@@ -59,7 +59,7 @@ def test_compile_compatibility(self, param_shape):
5959 loc = torch .randn (* param_shape )
6060 scale = torch .rand (* param_shape ).exp ()
6161 base_dist = TensorNormal (loc = loc , scale = scale )
62- transforms = [ExpTransform ()]
62+ transforms : list [ Transform ] = [ExpTransform ()]
6363 td = TransformedDistribution (base_distribution = base_dist , transforms = transforms )
6464
6565 sample = td .sample ()
@@ -81,26 +81,108 @@ def test_sample_log_prob(self):
8181 loc = torch .randn (3 , 5 )
8282 scale = torch .rand (3 , 5 ).exp ()
8383 base_dist = TensorNormal (loc = loc , scale = scale )
84- transforms = [ExpTransform ()]
84+ transforms : list [ Transform ] = [ExpTransform ()]
8585 td = TransformedDistribution (base_distribution = base_dist , transforms = transforms )
8686
8787 torch_td = TorchTransformedDistribution (base_dist .dist (), transforms )
8888
89- sample_shape = ( 2 , 1 )
89+ sample_shape = torch . Size ([ 2 , 1 ] )
9090 sample = td .sample (sample_shape )
91- assert sample .shape == torch_td .sample (sample_shape ).shape
91+ torch_sample = torch_td .sample (sample_shape )
92+ assert torch_sample is not None
93+ assert sample .shape == torch_sample .shape
9294 assert torch .allclose (td .log_prob (sample ), torch_td .log_prob (sample ))
9395
9496 rsample = td .rsample (sample_shape )
95- assert rsample .shape == torch_td .rsample (sample_shape ).shape
97+ torch_rsample = torch_td .rsample (sample_shape )
98+ assert torch_rsample is not None
99+ assert rsample .shape == torch_rsample .shape
96100 assert torch .allclose (td .log_prob (rsample ), torch_td .log_prob (rsample ))
97101
98102
103+ class TestTransformedDistributionCopy :
104+ @pytest .fixture
105+ def base_distribution (self ):
106+ """Create a base normal distribution for testing."""
107+ loc = torch .randn (3 , 5 )
108+ scale = torch .rand (3 , 5 ).exp ()
109+ return TensorNormal (loc = loc , scale = scale )
110+
111+ @pytest .fixture
112+ def transforms (self ):
113+ """Create a list of transforms for testing."""
114+ return [
115+ ExpTransform (),
116+ AffineTransform (loc = 1.0 , scale = 2.0 ),
117+ ]
118+
119+ @pytest .fixture
120+ def original_dist (self , base_distribution , transforms ):
121+ """Create an original transformed distribution for testing."""
122+ return TransformedDistribution (
123+ base_distribution = base_distribution , transforms = transforms
124+ )
125+
126+ def test_copy_creates_new_object (self , original_dist ):
127+ """Test that copy creates a new object of the correct type."""
128+ copied_dist = original_dist .copy ()
129+
130+ # Check that the copy is a different object but same type
131+ assert copied_dist is not original_dist
132+ assert isinstance (copied_dist , TransformedDistribution )
133+
134+ def test_copy_base_distribution_handling (self , original_dist ):
135+ """Test that the base distribution is handled correctly in copy."""
136+ copied_dist = original_dist .copy ()
137+
138+ # Check that the base distribution is a different object
139+ assert copied_dist .base_distribution is not original_dist .base_distribution
140+ assert isinstance (copied_dist .base_distribution , TensorNormal )
141+
142+ # Check that tensor parameters are the same objects (identity)
143+ original_base = original_dist .base_distribution
144+ copied_base = copied_dist .base_distribution
145+
146+ assert original_base ._loc is copied_base ._loc
147+ assert original_base ._scale is copied_base ._scale
148+
149+ def test_copy_transforms_handling (self , original_dist ):
150+ """Test that transforms are handled correctly in copy."""
151+ copied_dist = original_dist .copy ()
152+
153+ # Check that the transforms are the same objects (they're not tensors)
154+ assert copied_dist .transforms is original_dist .transforms
155+
156+ def test_copy_sampling_consistency (self , original_dist ):
157+ """Test that copied distribution produces consistent sampling results."""
158+ copied_dist = original_dist .copy ()
159+ sample_shape = torch .Size ([2 , 1 ])
160+
161+ # Check that samples have the same shape
162+ original_sample = original_dist .sample (sample_shape )
163+ copied_sample = copied_dist .sample (sample_shape )
164+ assert original_sample .shape == copied_sample .shape
165+
166+ # Check that log_prob values are consistent for the same sample
167+ torch .testing .assert_close (
168+ original_dist .log_prob (original_sample ),
169+ copied_dist .log_prob (original_sample ),
170+ )
171+
172+ def test_copy_property_consistency (self , original_dist ):
173+ """Test that copied distribution has the same properties."""
174+ copied_dist = original_dist .copy ()
175+
176+ # Check that the distributions have the same properties
177+ assert original_dist .batch_shape == copied_dist .batch_shape
178+ assert original_dist .device == copied_dist .device
179+
180+
99181class TestTransformedDistributionAPIMatch :
100182 def test_properties_match (self ):
101183 loc = torch .randn (3 , 5 )
102184 scale = torch .rand (3 , 5 ).exp ()
103185 base_dist = TensorNormal (loc = loc , scale = scale )
104- transforms = [ExpTransform ()]
186+ transforms : list [ Transform ] = [ExpTransform ()]
105187 td = TransformedDistribution (base_distribution = base_dist , transforms = transforms )
106188 assert_property_values_match (td )
0 commit comments