Skip to content

Commit 35b1e3f

Browse files
authored
Allow using device id (#91)
* Add the option to use device id (i.e., "cpu\gpu:<number>", instead of just "cpu/gpu")
1 parent 460529a commit 35b1e3f

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

redisai/command_builder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ def modelstore(
2525
) -> Sequence:
2626
if name is None:
2727
raise ValueError("Model name was not given")
28-
if device.upper() not in utils.allowed_devices:
28+
29+
# device format should be: "CPU | GPU [:<num>]"
30+
device_type = device.split(":")[0]
31+
if device_type.upper() not in utils.allowed_devices:
2932
raise ValueError(f"Device not allowed. Use any from {utils.allowed_devices}")
3033
if backend.upper() not in utils.allowed_backends:
3134
raise ValueError(f"Backend not allowed. Use any from {utils.allowed_backends}")

test/test.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def test_modelexecute_non_list_input_output(self):
342342
ret = con.modelexecute("m", ["a", "b"], "out")
343343
self.assertEqual(ret, "OK")
344344

345-
def test_nonasciichar(self):
345+
def test_non_ascii_char(self):
346346
nonascii = "ĉ"
347347
model_path = os.path.join(MODEL_DIR, tf_graph)
348348
model_pb = load_model(model_path)
@@ -363,6 +363,21 @@ def test_nonasciichar(self):
363363
tensor = con.tensorget("c" + nonascii)
364364
self.assertTrue((np.allclose(tensor, [4.0, 9.0])))
365365

366+
def test_device_with_id(self):
367+
model_path = os.path.join(MODEL_DIR, tf_graph)
368+
model_pb = load_model(model_path)
369+
con = self.get_client()
370+
ret = con.modelstore(
371+
"m",
372+
"tf",
373+
"cpu:1",
374+
model_pb,
375+
inputs=["a", "b"],
376+
outputs=["mul"],
377+
tag="v1.0",
378+
)
379+
self.assertEqual('OK', ret)
380+
366381
def test_run_tf_model(self):
367382
model_path = os.path.join(MODEL_DIR, tf_graph)
368383
bad_model_path = os.path.join(MODEL_DIR, torch_graph)

0 commit comments

Comments
 (0)