Skip to content

Commit 460529a

Browse files
authored
Boolean tensors support (#89)
* added support for dtype 'bool' * added test cases for boolean tensors * updated .gitignore file * improved tests * lint fix
1 parent 292d9ed commit 460529a

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
.pydevproject
33
*.pyc
44
.venv/
5+
venv/
56
redisai.egg-info
67
.idea
78
.mypy_cache/
89
build/
910
dist/
1011
docs/_build/
1112
.DS_Store
13+
.vscode

redisai/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"uint16": "UINT16",
1515
"uint32": "UINT32",
1616
"uint64": "UINT64",
17+
"bool": "BOOL",
1718
}
1819

1920
allowed_devices = {"CPU", "GPU"}

test/test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ def test_set_non_numpy_tensor(self):
111111
self.assertEqual([2, 3, 4, 5], result["values"])
112112
self.assertEqual([2, 2], result["shape"])
113113

114+
con.tensorset("x", (1, 1, 0, 0), dtype="bool", shape=(2, 2))
115+
result = con.tensorget("x", as_numpy=False)
116+
self.assertEqual([True, True, False, False], result["values"])
117+
self.assertEqual([2, 2], result["shape"])
118+
self.assertEqual("BOOL", result["dtype"])
119+
114120
with self.assertRaises(TypeError):
115121
con.tensorset("x", (2, 3, 4, 5), dtype="wrongtype", shape=(2, 2))
116122
con.tensorset("x", (2, 3, 4, 5), dtype="int8", shape=(2, 2))
@@ -144,6 +150,12 @@ def test_numpy_tensor(self):
144150
values = con.tensorget("x")
145151
self.assertEqual(values.dtype, np.float64)
146152

153+
input_array = np.array([True, False])
154+
con.tensorset("x", input_array)
155+
values = con.tensorget("x")
156+
self.assertEqual(values.dtype, "bool")
157+
self.assertTrue(np.array_equal(values, [True, False]))
158+
147159
input_array = np.array([2, 3])
148160
con.tensorset("x", input_array)
149161
values = con.tensorget("x")

0 commit comments

Comments
 (0)