Skip to content

Commit efd24ec

Browse files
authored
Merge pull request #152 from sovrasov/upd_tests
Update tests and linters
2 parents fb3aa21 + 994faeb commit efd24ec

File tree

7 files changed

+24
-13
lines changed

7 files changed

+24
-13
lines changed

.github/workflows/main.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ jobs:
4242
pip install .[dev]
4343
- name: Testing with pytest
4444
run: |
45-
python -m pytest . -s
45+
python -m pytest ./tests -s -v
4646
- name: Linting with flake8
4747
run: |
48-
python -m flake8 .
49-
python -m isort -rc --check-only --diff ./ptflops ./tests
48+
python -m flake8 ./ptflops ./tests ./samples
49+
python -m isort -rc --check-only --diff ./ptflops ./tests ./samples

.isort.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[isort]
2-
line_length = 79
2+
line_length = 89
33
multi_line_output = 0
44
known_standard_library = setuptools
55
known_first_party = ptflops

ptflops/aten_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def f(*args):
5454

5555
def exit_module(self, name):
5656
def f(*args):
57-
assert(self.parents[-1] == name)
57+
assert (self.parents[-1] == name)
5858
self.parents.pop()
5959
return f
6060

@@ -138,7 +138,7 @@ def get_flops_aten(model, input_res,
138138

139139
except Exception as e:
140140
print("Flops estimation was not finished successfully because of"
141-
f" the following exception:\n{type(e)} : {e}")
141+
f" the following exception: \n{type(e)}: {e}")
142142
traceback.print_exc()
143143

144144
return None, None

ptflops/pytorch_engine.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
import torch.nn as nn
1616
import torch.nn.functional as F
1717

18-
from .pytorch_ops import (CUSTOM_MODULES_MAPPING, FUNCTIONAL_MAPPING,
19-
MODULES_MAPPING, TENSOR_OPS_MAPPING)
18+
from .pytorch_ops import (CUSTOM_MODULES_MAPPING, FUNCTIONAL_MAPPING, MODULES_MAPPING,
19+
TENSOR_OPS_MAPPING)
2020
from .utils import flops_to_string, params_to_string
2121

2222

@@ -72,7 +72,7 @@ def reset_environment():
7272

7373
except Exception as e:
7474
print("Flops estimation was not finished successfully because of"
75-
f" the following exception:\n{type(e)} : {e}")
75+
f" the following exception: \n{type(e)}: {e}")
7676
traceback.print_exc()
7777
reset_environment()
7878

ptflops/pytorch_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ def bn_flops_counter_hook(module, input, output):
5757
module.__flops__ += int(batch_flops)
5858

5959

60-
def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops=0, transpose=False):
60+
def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops=0,
61+
transpose=False):
6162
# Can have multiple inputs, getting the first one
6263
input = input[0]
6364

@@ -84,8 +85,7 @@ def conv_flops_counter_hook(conv_module, input, output, extra_per_position_flops
8485
bias_flops = 0
8586

8687
if conv_module.bias is not None:
87-
88-
bias_flops = out_channels * active_elements_count
88+
bias_flops = batch_size * int(np.prod(list(output.shape[1:]), dtype=np.int64))
8989

9090
overall_flops = overall_conv_flops + bias_flops
9191

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers = [
2626

2727
[project.optional-dependencies]
2828
dev = [
29-
"flake8==3.8.1",
29+
"flake8==5.0.1",
3030
"flake8-import-order==0.18.1",
3131
"isort==4.3.21",
3232
"torchvision>=0.5.0",

tests/common_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ def test_conv(self, default_input_image_size, backend: FLOPS_BACKEND):
3333
assert params == 3 * 3 * 2 * 3 + 2
3434
assert macs == 2759904
3535

36+
@pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
37+
def test_conv_t(self, default_input_image_size, backend: FLOPS_BACKEND):
38+
net = nn.ConvTranspose2d(3, 2, 3, stride=(2, 2), bias=True)
39+
macs, params = get_model_complexity_info(net, default_input_image_size,
40+
as_strings=False,
41+
print_per_layer_stat=False,
42+
backend=backend)
43+
44+
assert params == 3 * 3 * 2 * 3 + 2
45+
assert macs == 3112706
46+
3647
@pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
3748
def test_fc(self, backend: FLOPS_BACKEND):
3849
net = nn.Sequential(nn.Linear(3, 2, bias=True))

0 commit comments

Comments
 (0)