Skip to content

Commit 35310d1

Browse files
author
WatcherBox
committed
add support for numpy arrays of atomic types
1 parent 2a78a49 commit 35310d1

File tree

4 files changed

+59
-4
lines changed

4 files changed

+59
-4
lines changed

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ packages = find:
2121
install_requires =
2222
PyYAML>=5.3.1
2323
pyserial>=3.4
24+
numpy>=1.19.4
2425

2526
[options.entry_points]
2627
console_scripts =

simple_rpc/io.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
from typing import Any, BinaryIO
23
from struct import calcsize, pack, unpack
34

@@ -43,9 +44,12 @@ def _write_basic(
4344
if basic_type == 's':
4445
stream.write(value + b'\0')
4546
return
47+
48+
elif isinstance(value, np.ndarray):
49+
stream.write(value.tobytes())
50+
return
4651

4752
full_type = (endianness + basic_type).encode('utf-8')
48-
4953
stream.write(pack(full_type, cast(basic_type)(value)))
5054

5155

@@ -82,9 +86,15 @@ def read(
8286
return [
8387
read(stream, endianness, size_t, item) for _ in range(length)
8488
for item in obj_type]
89+
8590
if isinstance(obj_type, tuple):
8691
return tuple(
8792
read(stream, endianness, size_t, item) for item in obj_type)
93+
94+
if isinstance(obj_type, np.ndarray):
95+
length = _read_basic(stream, endianness, size_t)
96+
return np.frombuffer(
97+
stream.read(length * obj_type.itemsize), obj_type.dtype)
8898
return _read_basic(stream, endianness, obj_type)
8999

90100

@@ -104,14 +114,21 @@ def write(
104114
:arg obj: Object of type {obj_type}.
105115
"""
106116
if isinstance(obj_type, list):
117+
# print(f" size_t: {size_t}, len:{len(obj) // len(obj_type)}")
107118
_write_basic(stream, endianness, size_t, len(obj) // len(obj_type))
108-
if isinstance(obj_type, list) or isinstance(obj_type, tuple):
119+
if isinstance(obj_type, np.ndarray):
120+
# print(f"writing array: {size_t}, {obj.size}, {obj.dtype}, obj_tpye: {obj_type}")
121+
_write_basic(stream, endianness, size_t, obj.size)
122+
_write_basic(stream, endianness, obj_type, obj)
123+
elif isinstance(obj_type, list) or isinstance(obj_type, tuple):
109124
for item_type, item in zip(obj_type * len(obj), obj):
110125
write(stream, endianness, size_t, item_type, item)
111126
else:
112127
_write_basic(stream, endianness, obj_type, obj)
113128

114129

130+
131+
115132
def until(
116133
condition: callable, f: callable, *args: Any, **kwargs: Any) -> None:
117134
"""Call {f(*args, **kwargs)} until {condition} is true.

simple_rpc/protocol.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,25 @@
1+
import numpy as np
2+
13
from typing import Any, BinaryIO
24

35
from .io import cast, read_byte_string
46

7+
dtype_map = {
8+
'b': np.int8,
9+
'B': np.uint8,
10+
'h': np.int16,
11+
'H': np.uint16,
12+
'i': np.int32,
13+
'I': np.uint32,
14+
'l': np.int32,
15+
'L': np.uint32,
16+
'q': np.int64,
17+
'Q': np.uint64,
18+
'f': np.float32,
19+
'd': np.float64,
20+
'?': np.bool_,
21+
'c': np.byte # Note: 'c' in struct is a single byte; for strings, consider np.bytes_ or np.chararray.
22+
}
523

624
def _parse_type(type_str: bytes) -> Any:
725
"""Parse a type definition string.
@@ -18,7 +36,12 @@ def _construct_type(tokens: tuple):
1836
obj_type.append(_construct_type(tokens))
1937
elif token == b'(':
2038
obj_type.append(tuple(_construct_type(tokens)))
21-
elif token in (b')', b']'):
39+
elif token == b'{':
40+
t = _construct_type(tokens)
41+
assert len(t) == 1, 'only atomic types allowed in np arrays'
42+
dtype = _get_dtype(t[0])
43+
obj_type.append(np.ndarray(dtype=dtype, shape=(1, 1)))
44+
elif token in (b')', b']', b'}'):
2245
break
2346
else:
2447
obj_type.append(token.decode())
@@ -33,6 +56,15 @@ def _construct_type(tokens: tuple):
3356
return ''
3457
return obj_type[0]
3558

59+
def _get_dtype(type_str: bytes) -> Any:
60+
"""Get the NumPy data type of a type definition string.
61+
62+
:arg type_str: Type definition string.
63+
64+
:returns: NumPy data type.
65+
"""
66+
return dtype_map.get(type_str, np.byte)
67+
3668

3769
def _type_name(obj_type: Any) -> str:
3870
"""Python type name of a C object type.
@@ -41,6 +73,8 @@ def _type_name(obj_type: Any) -> str:
4173
4274
:returns: Python type name.
4375
"""
76+
if isinstance(obj_type, np.ndarray):
77+
return '{' + ', '.join([_type_name(item) for item in obj_type]) + '}'
4478
if not obj_type:
4579
return ''
4680
if isinstance(obj_type, list):

simple_rpc/simple_rpc.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
2+
import numpy as np
3+
14
from functools import wraps
25
from time import sleep
36
from types import MethodType
@@ -184,7 +187,7 @@ def call_method(self: object, name: str, *args: Any) -> Any:
184187
self._write(parameter['fmt'], args[index])
185188

186189
# Read return value (if any).
187-
if method['return']['fmt']:
190+
if method['return']['fmt'] or isinstance(method['return']['fmt'], np.ndarray):
188191
return self._read(method['return']['fmt'])
189192
return None
190193

0 commit comments

Comments
 (0)