1
1
from enum import Enum
2
2
from redis import StrictRedis
3
3
from ._util import to_string
4
+ import six
4
5
5
6
try :
6
7
import numpy as np
7
8
except ImportError :
8
9
np = None
9
10
10
11
try :
11
- from typing import Union , Any , AnyStr , ByteString , Collection
12
+ from typing import Union , Any , AnyStr , ByteString , Collection , Type
12
13
except ImportError :
13
14
pass
14
15
@@ -41,6 +42,12 @@ class DType(Enum):
41
42
float64 = 'double'
42
43
43
44
45
+ def _str_or_strlist (v ):
46
+ if isinstance (v , six .string_types ):
47
+ return [v ]
48
+ return v
49
+
50
+
44
51
class Tensor (object ):
45
52
ARGNAME = 'VALUES'
46
53
@@ -158,13 +165,13 @@ def modelset(self,
158
165
name , # type: AnyStr
159
166
backend , # type: Backend
160
167
device , # type: Device
161
- inputs , # type: Collection[AnyStr]
162
- outputs , # type: Collection[AnyStr]
168
+ input , # type: Union[AnyStr| Collection[AnyStr] ]
169
+ output , # type: Union[AnyStr| Collection[AnyStr] ]
163
170
data # type: ByteString
164
171
):
165
172
args = ['AI.MODELSET' , name , backend .value , device .value , 'INPUTS' ]
166
- args += inputs
167
- args += ['OUTPUTS' ] + outputs
173
+ args += _str_or_strlist ( input )
174
+ args += ['OUTPUTS' ] + _str_or_strlist ( output )
168
175
args += [data ]
169
176
return self .execute_command (* args )
170
177
@@ -176,9 +183,14 @@ def modelget(self, name):
176
183
'data' : rv [2 ]
177
184
}
178
185
179
- def modelrun (self , name , inputs , outputs ):
186
+ def modelrun (self ,
187
+ name ,
188
+ input , # type: Union[AnyStr|Collection[AnyStr]]
189
+ output # type: Union[AnyStr|Collection[AnyStr]]
190
+ ):
180
191
args = ['AI.MODELRUN' , name ]
181
- args += ['INPUTS' ] + inputs + ['OUTPUTS' ] + outputs
192
+ args += ['INPUTS' ] + _str_or_strlist (input )
193
+ args += ['OUTPUTS' ] + _str_or_strlist (output )
182
194
return self .execute_command (* args )
183
195
184
196
def tensorset (self , key , tensor ):
@@ -196,22 +208,23 @@ def tensorset(self, key, tensor):
196
208
args += tensor .value
197
209
return self .execute_command (* args )
198
210
199
- def tensorget (self , key , astype = Tensor , meta_only = False ):
211
+ def tensorget (self , key , as_type = Tensor , meta_only = False ):
212
+ # type: (AnyStr, Type[Tensor], bool) -> Tensor
200
213
"""
201
214
Retrieve the value of a tensor from the server
202
215
:param key: the name of the tensor
203
- :param astype : the resultant tensor type
216
+ :param as_type : the resultant tensor type
204
217
:param meta_only: if true, then the value is not retrieved,
205
218
only the shape and the type
206
- :return: an instance of astype
219
+ :return: an instance of as_type
207
220
"""
208
- argname = 'META' if meta_only else astype .ARGNAME
221
+ argname = 'META' if meta_only else as_type .ARGNAME
209
222
res = self .execute_command ('AI.TENSORGET' , key , argname )
210
223
dtype , shape = to_string (res [0 ]), res [1 ]
211
224
if meta_only :
212
- return astype (dtype , shape , [])
225
+ return as_type (dtype , shape , [])
213
226
else :
214
- return astype (dtype , shape , res [2 ])
227
+ return as_type (dtype , shape , res [2 ])
215
228
216
229
def scriptset (self , name , device , script ):
217
230
return self .execute_command ('AI.SCRIPTSET' , name , device .value , script )
@@ -223,9 +236,14 @@ def scriptget(self, name):
223
236
'script' : to_string (r [1 ])
224
237
}
225
238
226
- def scriptrun (self , name , function , inputs , outputs ):
239
+ def scriptrun (self ,
240
+ name ,
241
+ function , # type: AnyStr
242
+ input , # type: Union[AnyStr|Collection[AnyStr]]
243
+ output # type: Union[AnyStr|Collection[AnyStr]]
244
+ ):
227
245
args = ['AI.SCRIPTRUN' , name , function , 'INPUTS' ]
228
- args += inputs
246
+ args += _str_or_strlist ( input )
229
247
args += ['OUTPUTS' ]
230
- args += outputs
248
+ args += _str_or_strlist ( output )
231
249
return self .execute_command (* args )
0 commit comments