Skip to content

Commit 824fa53

Browse files
authored
Helpers (#4)
Update helper methods
1 parent e7dfe39 commit 824fa53

File tree

2 files changed

+28
-15
lines changed

2 files changed

+28
-15
lines changed

image_dataset_viz/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
__version__ = '0.2.1'
2+
__version__ = '0.2.1.1'
33

44
import numpy as np
55

@@ -10,13 +10,13 @@ def bbox_to_points(bbox_xyxy):
1010
"""Helper method to convert bounding box as list/tuple of [x1, y1, x2, y2] into points `ndarray` of shape (N, 2)
1111
1212
Args:
13-
bbox_xyxy (list or tuple): bounding box as list/tuple of [x1, y1, x2, y2]
13+
bbox_xyxy (list or tuple or ndarray): bounding box as list/tuple of [x1, y1, x2, y2]
1414
1515
Returns:
1616
ndarray of points
1717
1818
"""
19-
assert isinstance(bbox_xyxy, (list, tuple)) and len(bbox_xyxy) == 4, \
19+
assert isinstance(bbox_xyxy, (list, tuple, np.ndarray)) and len(bbox_xyxy) == 4, \
2020
"Argument bbox_xyxy should be a list/tuple of [x1, y1, x2, y2]"
2121

2222
return np.array([
@@ -37,7 +37,7 @@ def xywh_to_xyxy(xywh):
3737
list [x1, y1, x2, y2]
3838
3939
"""
40-
assert isinstance(xywh, (list, tuple)) and len(xywh) == 4, \
40+
assert isinstance(xywh, (list, tuple, np.ndarray)) and len(xywh) == 4, \
4141
"Argument xywh should be a list/tuple of [x1, y1, width, height]"
4242
x1, y1 = xywh[0], xywh[1]
4343
x2 = x1 + max(0, xywh[2] - 1)
@@ -55,7 +55,7 @@ def xyxy_to_xywh(xyxy):
5555
list [x1, y1, width, height]
5656
5757
"""
58-
assert isinstance(xyxy, (list, tuple)) and len(xyxy) == 4, \
58+
assert isinstance(xyxy, (list, tuple, np.ndarray)) and len(xyxy) == 4, \
5959
"Argument xyxy should be a list/tuple of [x1, y1, x2, y2]"
6060
x1, y1 = xyxy[0], xyxy[1]
6161
w = xyxy[2] - x1 + 1

tests/test_image_dataset_viz.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,45 @@
77

88
class TestHelperMethods(TestCase):
99

10-
def test_bbox_to_points(self):
10+
def _test_func(self, input_, true_output, func):
11+
12+
output = func(list(input_))
13+
self.assertTrue(np.equal(true_output, output).all())
14+
15+
output = func(tuple(input_))
16+
self.assertTrue(np.equal(true_output, output).all())
17+
18+
output = func(np.array(input_))
19+
self.assertTrue(np.equal(true_output, output).all())
20+
21+
with self.assertRaises(AssertionError):
22+
func("1234")
23+
24+
with self.assertRaises(AssertionError):
25+
func([1, 2])
1126

27+
with self.assertRaises(AssertionError):
28+
func([1, 2, 3, 4, 5])
29+
30+
def test_bbox_to_points(self):
1231
bbox = (10, 12, 34, 45)
1332
true_points = np.array([
1433
[10, 12],
1534
[34, 12],
1635
[34, 45],
1736
[10, 45]
1837
])
19-
points = bbox_to_points(bbox)
20-
21-
self.assertTrue((true_points == points).all())
38+
self._test_func(bbox, true_points, bbox_to_points)
2239

2340
def test_xywh_to_xyxy(self):
24-
2541
xywh = (10, 12, 34, 45)
2642
true_xyxy = [10, 12, 43, 56]
27-
xyxy = xywh_to_xyxy(xywh)
28-
self.assertEqual(true_xyxy, xyxy)
43+
self._test_func(xywh, true_xyxy, xywh_to_xyxy)
2944

3045
def test_xyxy_to_xywh(self):
31-
3246
xyxy = [10, 12, 43, 56]
3347
true_xywh = [10, 12, 34, 45]
34-
xywh = xyxy_to_xywh(xyxy)
35-
self.assertEqual(true_xywh, xywh)
48+
self._test_func(xyxy, true_xywh, xyxy_to_xywh)
3649

3750

3851
if __name__ == "__main__":

0 commit comments

Comments
 (0)