Skip to content

Commit 8b6c8e4

Browse files
committed
lint
1 parent d3445e6 commit 8b6c8e4

File tree

4 files changed

+32
-20
lines changed

4 files changed

+32
-20
lines changed

test/test_compile.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import pytest
1515

1616
import torch
17+
18+
from _utils_internal import is_npu_available
1719
from packaging import version
1820

1921
from tensordict import (
@@ -39,8 +41,6 @@
3941

4042
from tensordict.tensorclass import TensorClass
4143

42-
from _utils_internal import is_npu_available
43-
4444
from torch.utils._pytree import SUPPORTED_NODES, tree_map
4545

4646
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
@@ -300,7 +300,8 @@ def test_to_device(td):
300300
assert td_device_c.device == torch.device(device)
301301

302302
@pytest.mark.skipif(
303-
is_npu_available(), reason="torch.device in torch.compile is not supported on NPU currently."
303+
is_npu_available(),
304+
reason="torch.device in torch.compile is not supported on NPU currently.",
304305
)
305306
def test_lock(self, mode):
306307
def locked_op(td):
@@ -593,7 +594,8 @@ def test_to_device(tc):
593594
assert tc_device_c.device == torch.device(device)
594595

595596
@pytest.mark.skipif(
596-
is_npu_available(), reason="torch.device in torch.compile is not supported on NPU currently."
597+
is_npu_available(),
598+
reason="torch.device in torch.compile is not supported on NPU currently.",
597599
)
598600
def test_tc_lock(self, mode):
599601
def locked_op(tc):

test/test_distributed.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
import pytest
1212
import torch
1313
from _pytest.fixtures import fixture
14+
from _utils_internal import is_npu_available
1415
from packaging import version
1516

1617
from packaging.version import parse
1718

1819
from tensordict import LazyStackedTensorDict, MemoryMappedTensor, TensorDict
1920
from tensordict.utils import logger as tdlogger
20-
from _utils_internal import is_npu_available
2121
from torch import distributed as dist, multiprocessing as mp, nn
2222
from torch.distributed._tensor import (
2323
DeviceMesh,
@@ -109,7 +109,8 @@ def test_fsdp_module(self, tmpdir):
109109

110110

111111
@pytest.mark.skipif(
112-
not is_npu_available() or not torch.npu.device_count() > 2, reason="not enough npu devices"
112+
not is_npu_available() or not torch.npu.device_count() > 2,
113+
reason="not enough npu devices",
113114
)
114115
class TestNPUFSDP:
115116
class MyDModule(nn.Module):
@@ -127,9 +128,7 @@ def forward(self, input):
127128
@classmethod
128129
def make_module(cls, device=None):
129130
with (
130-
torch.device(f"npu:{device}")
131-
if device is not None
132-
else torch.device("npu")
131+
torch.device(f"npu:{device}") if device is not None else torch.device("npu")
133132
):
134133
my_module = cls.MyDModule()
135134
my_sharded_module = FSDP(my_module, device_id=device)
@@ -307,8 +306,8 @@ def server(queue):
307306
},
308307
[2],
309308
)
310-
.expand(1, 2)
311-
.contiguous()
309+
.expand(1, 2)
310+
.contiguous()
312311
)
313312
td.gather_and_stack(0)
314313
assert (td != 0).all()
@@ -380,8 +379,8 @@ def server(queue, op, async_op, return_premature):
380379
},
381380
[2],
382381
)
383-
.expand(1, 2)
384-
.contiguous()
382+
.expand(1, 2)
383+
.contiguous()
385384
)
386385
out = td.reduce(0, op=op, async_op=async_op, return_premature=return_premature)
387386
if not async_op:
@@ -864,8 +863,8 @@ def server(cls, queue):
864863
},
865864
[2],
866865
)
867-
.expand(1, 2)
868-
.contiguous()
866+
.expand(1, 2)
867+
.contiguous()
869868
)
870869
td.init_remote(dst=1)
871870

test/test_nn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import pytest
1919
import torch
20+
from _utils_internal import is_npu_available
2021

2122
from tensordict import (
2223
is_tensor_collection,
@@ -57,7 +58,6 @@
5758
skip_existing,
5859
)
5960
from tensordict.tensorclass import TensorClass
60-
from _utils_internal import is_npu_available
6161

6262
from torch import distributions, nn
6363
from torch.distributions import Categorical, Normal
@@ -2186,6 +2186,7 @@ def test_module_buffer():
21862186
)
21872187
def test_to_context(original_device, new_device, tc):
21882188
if tc:
2189+
21892190
class MyTC(TensorClass):
21902191
x: torch.Tensor
21912192
y: torch.Tensor | None = None

test/test_tensorclass.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import pytest
2727
import tensordict.utils
2828
import torch
29+
from _utils_internal import is_npu_available
2930

3031
from tensordict import (
3132
assert_allclose_td,
@@ -45,7 +46,6 @@
4546
from tensordict._td import lazy_stack
4647
from tensordict.base import _GENERIC_NESTED_ERR
4748
from tensordict.tensorclass import from_dataclass
48-
from _utils_internal import is_npu_available
4949

5050
from torch import Tensor
5151

@@ -739,6 +739,7 @@ def test_disallowed_attributes(self):
739739
AttributeError,
740740
match="Attribute name reshape can't be used with @tensorclass",
741741
):
742+
742743
@tensorclass
743744
class MyInvalidClass:
744745
x: torch.Tensor
@@ -1100,6 +1101,7 @@ class MyDataParent:
11001101
@pytest.mark.parametrize("list_to_stack", [True, False])
11011102
def test_indexing(self, list_to_stack):
11021103
with set_list_to_stack(list_to_stack):
1104+
11031105
@tensorclass
11041106
class MyDataNested:
11051107
X: torch.Tensor
@@ -1436,8 +1438,8 @@ class MyDataNested(TensorClass):
14361438
assert (
14371439
repeated.X
14381440
== X.repeat_interleave(
1439-
torch.tensor([2, 3, 4, 5], device=data.device), dim=1
1440-
)
1441+
torch.tensor([2, 3, 4, 5], device=data.device), dim=1
1442+
)
14411443
).all()
14421444

14431445
def test_reshape(self):
@@ -2888,20 +2890,23 @@ class FuncAutoCast:
28882890
class TestShadow:
28892891
def test_no_shadow(self):
28902892
with pytest.raises(AttributeError):
2893+
28912894
@tensorclass
28922895
class MyClass:
28932896
x: str
28942897
y: int
28952898
batch_size: Any
28962899

28972900
with pytest.raises(AttributeError):
2901+
28982902
@tensorclass
28992903
class MyClass: # noqa: F811
29002904
x: str
29012905
y: int
29022906
names: Any
29032907

29042908
with pytest.raises(AttributeError):
2909+
29052910
@tensorclass
29062911
class MyClass: # noqa: F811
29072912
x: str
@@ -3099,7 +3104,7 @@ class MyClass:
30993104
_ = c / 1
31003105
_ = 1 / c
31013106

3102-
_ = c ** 1
3107+
_ = c**1
31033108
# not implemented
31043109
# 1 ** c
31053110

@@ -3299,13 +3304,15 @@ class TensorOnly:
32993304
c: torch.Tensor | None = None
33003305

33013306
with pytest.raises(TypeError, match="tensor_only"):
3307+
33023308
@tensorclass(tensor_only=True, nocast=True)
33033309
class TensorOnlyNocast:
33043310
a: torch.Tensor
33053311
b: torch.Tensor
33063312
c: torch.Tensor | None = None
33073313

33083314
with pytest.raises(TypeError, match="tensor_only"):
3315+
33093316
@tensorclass(tensor_only=True, autocast=True)
33103317
class TensorOnlyAutocast:
33113318
a: torch.Tensor
@@ -3330,6 +3337,7 @@ class TensorOnly(TensorClass["tensor_only"]):
33303337
TypeError,
33313338
match="tensor_only requires types to be Tensor, Tensor-subtrypes or None",
33323339
):
3340+
33333341
class TensorOnlyAny(TensorClass["tensor_only"]):
33343342
a: torch.Tensor
33353343
b: Any
@@ -3339,6 +3347,7 @@ class TensorOnlyAny(TensorClass["tensor_only"]):
33393347
TypeError,
33403348
match="tensor_only requires types to be Tensor, Tensor-subtrypes or None",
33413349
):
3350+
33423351
class TensorOnlyStr(TensorClass["tensor_only"]):
33433352
a: torch.Tensor
33443353
b: torch.Tensor | str
@@ -3348,6 +3357,7 @@ class TensorOnlyStr(TensorClass["tensor_only"]):
33483357
TypeError,
33493358
match="tensor_only requires types to be Tensor, Tensor-subtrypes or None",
33503359
):
3360+
33513361
class TensorOnlyStrUnion(TensorClass["tensor_only"]):
33523362
a: torch.Tensor
33533363
b: torch.Tensor

0 commit comments

Comments
 (0)