Skip to content

Commit da7252c

Browse files
author
Tim Joseph
committed
Fixed formatting
1 parent 2174a88 commit da7252c

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

tests/tensor_distribution/test_independent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_broadcasting_shapes(
4141
)
4242
assert td_independent.batch_shape == expected_batch_shape
4343
assert td_independent.dist().batch_shape == expected_batch_shape
44-
44+
4545
# Test that shape property matches expected_batch_shape (fixes bug with reinterpreted_batch_ndims=0)
4646
assert td_independent.shape == expected_batch_shape
4747

tests/tensor_distribution/test_tanh_normal.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@ def test_initialization_tensor(self):
2929

3030
def test_initialization_reinterpreted_batch_ndims(self):
3131
# Test with TensorIndependent wrapper for reinterpreted batch dimensions
32-
from src.tensorcontainer.tensor_distribution.independent import TensorIndependent
33-
32+
from src.tensorcontainer.tensor_distribution.independent import (
33+
TensorIndependent,
34+
)
35+
3436
loc = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
3537
scale = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
3638
base_dist = TensorTanhNormal(loc, scale)
3739
dist = TensorIndependent(base_dist, reinterpreted_batch_ndims=1)
38-
40+
3941
assert torch.equal(base_dist.loc, loc)
4042
assert torch.equal(base_dist.scale, scale)
4143
assert base_dist.shape == loc.shape

0 commit comments

Comments
 (0)