Skip to content

Commit 161e3fd

Browse files
author
Tim Joseph
committed
Fix ruff errors
1 parent 2c571e9 commit 161e3fd

File tree

5 files changed

+11
-17
lines changed

5 files changed

+11
-17
lines changed

src/tensorcontainer/distributions/sampling.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Dict
2-
31
import torch
42
from torch.distributions import Distribution
53

@@ -47,7 +45,6 @@ def entropy(self):
4745
samples = self.base_dist.rsample((self.n,))
4846
logprob = self.base_dist.log_prob(samples)
4947
return -logprob.mean(0)
50-
48+
5149
def log_prob(self, value):
5250
return self.base_dist.log_prob(value)
53-

tests/tensor_dataclass/test_detach.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def test_basic_detach(self, simple_tensor_data_instance, compile_mode):
3434
assert not detached_instance.b.requires_grad
3535

3636
# Check that original tensors still have gradients (should be False for fixture)
37-
assert test_instance.a.requires_grad == False
38-
assert test_instance.b.requires_grad == False
37+
assert not test_instance.a.requires_grad
38+
assert not test_instance.b.requires_grad
3939

4040
# Check that data is preserved
4141
assert torch.equal(detached_instance.a, test_instance.a)

tests/tensor_dict/test_recompilations.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import pytest
22
import torch
3-
4-
from tensorcontainer.tensor_dict import TensorDict
5-
from tests.conftest import skipif_no_compile
63
import torch._dynamo
74
import torch._dynamo.utils
85
from torch._dynamo.testing import CompileCounter # New import
96

7+
from tensorcontainer.tensor_dict import TensorDict
8+
from tests.conftest import skipif_no_compile
109

1110
keys = ["a", "b", "c", "d", "e", "f", "g"]
1211

@@ -26,12 +25,10 @@ def test_getitem_recompilation(key):
2625
torch._dynamo.utils.counters.clear()
2726

2827
# --- Sanity check with a simple lambda using CompileCounter ---
29-
simple_lambda = lambda x: x + 1
30-
3128
lambda_compile_counter = CompileCounter()
3229
# Using fullgraph=True as it was in the original attempts
3330
compiled_lambda = torch.compile(
34-
simple_lambda, backend=lambda_compile_counter, fullgraph=True
31+
lambda x: x + 1, backend=lambda_compile_counter, fullgraph=True
3532
)
3633

3734
_ = compiled_lambda(torch.randn(1))

tests/tensor_dict/test_stack.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from torch._dynamo import exc as dynamo_exc
44

55
from tensorcontainer.tensor_dict import TensorDict
6+
from tests.compile_utils import run_and_compare_compiled
67
from tests.conftest import skipif_no_compile
78
from tests.tensor_dict import common
89
from tests.tensor_dict.common import compare_nested_dict, compute_stack_shape
9-
from tests.compile_utils import run_and_compare_compiled
1010

1111

1212
@pytest.fixture(autouse=True)
@@ -143,7 +143,7 @@ def stack_operation(tensor_dict_instance, stack_dimension):
143143
)
144144

145145
compiled_stack_op = torch.compile(stack_operation, fullgraph=True)
146-
with pytest.raises(dynamo_exc.Unsupported) as excinfo:
146+
with pytest.raises(dynamo_exc.Unsupported):
147147
compiled_stack_op(td, dim)
148148

149149

tests/test_unsafe_construction.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_unsafe_construction_nested_contexts(sample_tensors):
138138
"""Test that nested unsafe construction contexts work correctly."""
139139
with TensorContainer.unsafe_construction():
140140
# Should work in outer context
141-
container1 = SampleTensorDataClass(
141+
SampleTensorDataClass(
142142
features=sample_tensors["incompatible_features"],
143143
labels=sample_tensors["labels"],
144144
shape=(4,),
@@ -210,7 +210,7 @@ def test_thread(thread_id, use_unsafe):
210210
with TensorContainer.unsafe_construction():
211211
# Add small delay to test concurrency
212212
time.sleep(0.01)
213-
container = SampleTensorDataClass(
213+
SampleTensorDataClass(
214214
features=sample_tensors["incompatible_features"],
215215
labels=sample_tensors["incompatible_labels"],
216216
shape=(4,),
@@ -219,7 +219,7 @@ def test_thread(thread_id, use_unsafe):
219219
results[thread_id] = "success"
220220
else:
221221
try:
222-
container = SampleTensorDataClass(
222+
SampleTensorDataClass(
223223
features=sample_tensors["incompatible_features"],
224224
labels=sample_tensors["incompatible_labels"],
225225
shape=(4,),

0 commit comments

Comments
 (0)