Skip to content

Commit dca85ba

Browse files
committed
more api improvements
1 parent 2502530 commit dca85ba

File tree

1 file changed

+34
-16
lines changed

1 file changed

+34
-16
lines changed

redisai/client.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from enum import Enum
22
from redis import StrictRedis
33
from ._util import to_string
4+
import six
45

56
try:
67
import numpy as np
78
except ImportError:
89
np = None
910

1011
try:
11-
from typing import Union, Any, AnyStr, ByteString, Collection
12+
from typing import Union, Any, AnyStr, ByteString, Collection, Type
1213
except ImportError:
1314
pass
1415

@@ -41,6 +42,12 @@ class DType(Enum):
4142
float64 = 'double'
4243

4344

45+
def _str_or_strlist(v):
46+
if isinstance(v, six.string_types):
47+
return [v]
48+
return v
49+
50+
4451
class Tensor(object):
4552
ARGNAME = 'VALUES'
4653

@@ -158,13 +165,13 @@ def modelset(self,
158165
name, # type: AnyStr
159166
backend, # type: Backend
160167
device, # type: Device
161-
inputs, # type: Collection[AnyStr]
162-
outputs, # type: Collection[AnyStr]
168+
input, # type: Union[AnyStr|Collection[AnyStr]]
169+
output, # type: Union[AnyStr|Collection[AnyStr]]
163170
data # type: ByteString
164171
):
165172
args = ['AI.MODELSET', name, backend.value, device.value, 'INPUTS']
166-
args += inputs
167-
args += ['OUTPUTS'] + outputs
173+
args += _str_or_strlist(input)
174+
args += ['OUTPUTS'] + _str_or_strlist(output)
168175
args += [data]
169176
return self.execute_command(*args)
170177

@@ -176,9 +183,14 @@ def modelget(self, name):
176183
'data': rv[2]
177184
}
178185

179-
def modelrun(self, name, inputs, outputs):
186+
def modelrun(self,
187+
name,
188+
input, # type: Union[AnyStr|Collection[AnyStr]]
189+
output # type: Union[AnyStr|Collection[AnyStr]]
190+
):
180191
args = ['AI.MODELRUN', name]
181-
args += ['INPUTS'] + inputs + ['OUTPUTS'] + outputs
192+
args += ['INPUTS'] + _str_or_strlist(input)
193+
args += ['OUTPUTS'] + _str_or_strlist(output)
182194
return self.execute_command(*args)
183195

184196
def tensorset(self, key, tensor):
@@ -196,22 +208,23 @@ def tensorset(self, key, tensor):
196208
args += tensor.value
197209
return self.execute_command(*args)
198210

199-
def tensorget(self, key, astype=Tensor, meta_only=False):
211+
def tensorget(self, key, as_type=Tensor, meta_only=False):
212+
# type: (AnyStr, Type[Tensor], bool) -> Tensor
200213
"""
201214
Retrieve the value of a tensor from the server
202215
:param key: the name of the tensor
203-
:param astype: the resultant tensor type
216+
:param as_type: the resultant tensor type
204217
:param meta_only: if true, then the value is not retrieved,
205218
only the shape and the type
206-
:return: an instance of astype
219+
:return: an instance of as_type
207220
"""
208-
argname = 'META' if meta_only else astype.ARGNAME
221+
argname = 'META' if meta_only else as_type.ARGNAME
209222
res = self.execute_command('AI.TENSORGET', key, argname)
210223
dtype, shape = to_string(res[0]), res[1]
211224
if meta_only:
212-
return astype(dtype, shape, [])
225+
return as_type(dtype, shape, [])
213226
else:
214-
return astype(dtype, shape, res[2])
227+
return as_type(dtype, shape, res[2])
215228

216229
def scriptset(self, name, device, script):
217230
return self.execute_command('AI.SCRIPTSET', name, device.value, script)
@@ -223,9 +236,14 @@ def scriptget(self, name):
223236
'script': to_string(r[1])
224237
}
225238

226-
def scriptrun(self, name, function, inputs, outputs):
239+
def scriptrun(self,
240+
name,
241+
function, # type: AnyStr
242+
input, # type: Union[AnyStr|Collection[AnyStr]]
243+
output # type: Union[AnyStr|Collection[AnyStr]]
244+
):
227245
args = ['AI.SCRIPTRUN', name, function, 'INPUTS']
228-
args += inputs
246+
args += _str_or_strlist(input)
229247
args += ['OUTPUTS']
230-
args += outputs
248+
args += _str_or_strlist(output)
231249
return self.execute_command(*args)

0 commit comments

Comments
 (0)