Skip to content

Commit 6a394f5

Browse files
authored
Add support for config command (#98)
* Add support for ai.config command * Sort postprocessing methods * Fix some warnings
1 parent 6e8ecb7 commit 6a394f5

File tree

4 files changed

+138
-59
lines changed

4 files changed

+138
-59
lines changed

redisai/client.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,44 @@ def loadbackend(self, identifier: AnyStr, path: AnyStr) -> str:
145145
'OK'
146146
"""
147147
args = builder.loadbackend(identifier, path)
148-
res = self.execute_command(*args)
148+
res = self.execute_command(args)
149149
return res if not self.enable_postprocess else processor.loadbackend(res)
150150

151+
def config(self, name: str, value: Union[str, int, None] = None) -> str:
152+
"""
153+
Get/Set configuration item. Current available configurations are: BACKENDSPATH and MODEL_CHUNK_SIZE.
154+
For more details, see: https://oss.redis.com/redisai/master/commands/#aiconfig.
155+
If value is given - the configuration under name will be overriten.
156+
157+
Parameters
158+
----------
159+
name: str
160+
RedisAI config item to retreive/override (BACKENDSPATH / MODEL_CHUNK_SIZE).
161+
value: Union[str, int]
162+
Value to set the config item with (if given).
163+
164+
Returns
165+
-------
166+
The current configuration value if value is None,
167+
'OK' if value was given and configuration overitten succeeded,
168+
raise an exception otherwise
169+
170+
171+
Example
172+
-------
173+
>>> con.config('MODEL_CHUNK_SIZE', 128 * 1024)
174+
'OK'
175+
>>> con.config('BACKENDSPATH', '/my/backends/path')
176+
'OK'
177+
>>> con.config('BACKENDSPATH')
178+
'/my/backends/path'
179+
>>> con.config('MODEL_CHUNK_SIZE')
180+
'131072'
181+
"""
182+
args = builder.config(name, value)
183+
res = self.execute_command(args)
184+
return res if not self.enable_postprocess or not isinstance(res, bytes) else processor.config(res)
185+
151186
def modelstore(
152187
self,
153188
key: AnyStr,
@@ -209,6 +244,7 @@ def modelstore(
209244
... inputs=['a', 'b'], outputs=['mul'], tag='v1.0')
210245
'OK'
211246
"""
247+
chunk_size = self.config('MODEL_CHUNK_SIZE')
212248
args = builder.modelstore(
213249
key,
214250
backend,
@@ -220,6 +256,7 @@ def modelstore(
220256
tag,
221257
inputs,
222258
outputs,
259+
chunk_size=chunk_size
223260
)
224261
res = self.execute_command(*args)
225262
return res if not self.enable_postprocess else processor.modelstore(res)

redisai/command_builder.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88

99

1010
def loadbackend(identifier: AnyStr, path: AnyStr) -> Sequence:
11-
return "AI.CONFIG LOADBACKEND", identifier, path
11+
return f'AI.CONFIG LOADBACKEND {identifier} {path}'
12+
13+
14+
def config(name: str, value: Union[str, int, None] = None) -> Sequence:
15+
if value is not None:
16+
return f'AI.CONFIG {name} {value}'
17+
return f'AI.CONFIG GET {name}'
1218

1319

1420
def modelstore(
@@ -22,6 +28,7 @@ def modelstore(
2228
tag: AnyStr,
2329
inputs: Union[AnyStr, List[AnyStr]],
2430
outputs: Union[AnyStr, List[AnyStr]],
31+
chunk_size: int = 500 * 1024 * 1024
2532
) -> Sequence:
2633
if name is None:
2734
raise ValueError("Model name was not given")
@@ -66,9 +73,7 @@ def modelstore(
6673
raise ValueError(
6774
"Inputs and outputs keywords should not be specified for this backend"
6875
)
69-
chunk_size = 500 * 1024 * 1024 # TODO: this should be configurable.
7076
data_chunks = [data[i: i + chunk_size] for i in range(0, len(data), chunk_size)]
71-
# TODO: need a test case for this
7277
args += ["BLOB", *data_chunks]
7378
return args
7479

redisai/postprocessor.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,28 @@ def tensorget(res, as_numpy, as_numpy_mutable, meta_only):
2727
rai_result = utils.list2dict(res)
2828
if meta_only is True:
2929
return rai_result
30-
elif as_numpy_mutable is True:
30+
if as_numpy_mutable is True:
3131
return utils.blob2numpy(
3232
rai_result["blob"],
3333
rai_result["shape"],
3434
rai_result["dtype"],
3535
mutable=True,
3636
)
37-
elif as_numpy is True:
37+
if as_numpy is True:
3838
return utils.blob2numpy(
3939
rai_result["blob"],
4040
rai_result["shape"],
4141
rai_result["dtype"],
4242
mutable=False,
4343
)
44+
45+
if rai_result["dtype"] == "STRING":
46+
def target(b):
47+
return b.decode()
4448
else:
45-
if rai_result["dtype"] == "STRING":
46-
def target(b):
47-
return b.decode()
48-
else:
49-
target = float if rai_result["dtype"] in ("FLOAT", "DOUBLE") else int
50-
utils.recursive_bytetransform(rai_result["values"], target)
51-
return rai_result
49+
target = float if rai_result["dtype"] in ("FLOAT", "DOUBLE") else int
50+
utils.recursive_bytetransform(rai_result["values"], target)
51+
return rai_result
5252

5353
@staticmethod
5454
def scriptget(res):
@@ -66,19 +66,20 @@ def infoget(res):
6666
# These functions are only doing decoding on the output from redis
6767
decoder = staticmethod(decoder)
6868
decoding_functions = (
69+
"config",
70+
"inforeset",
6971
"loadbackend",
70-
"modelstore",
71-
"modelset",
7272
"modeldel",
7373
"modelexecute",
7474
"modelrun",
75-
"tensorset",
76-
"scriptset",
77-
"scriptstore",
75+
"modelset",
76+
"modelstore",
7877
"scriptdel",
79-
"scriptrun",
8078
"scriptexecute",
81-
"inforeset",
79+
"scriptrun",
80+
"scriptset",
81+
"scriptstore",
82+
"tensorset",
8283
)
8384
for fn in decoding_functions:
8485
setattr(Processor, fn, decoder)

0 commit comments

Comments
 (0)