Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit 56c990b

Browse files
Errors propagate through transaction (#247)
## What is the goal of this PR? We align with Client NodeJS version 2.6.1 (some of the work in typedb/typedb-driver-nodejs#197), which implements a better error propagation mechanism: when an exception occurs, we store it against all the transaction's active transmit queues to retrieve whenever the user tries to perform an operation in the transaction anywhere. ## What are the changes implemented in this PR? * store errors received from gRPC against each receive queue * return a new exception type, transaction is closed with errors, which throws the errors from all queues (note that this can be duplicate if there are multiple open transmit queues that have been given the same error) * we clean up queues that are no longer needed, so we minimise the number of times the user sees the same exception
1 parent 5acc0dd commit 56c990b

File tree

4 files changed

+66
-43
lines changed

4 files changed

+66
-43
lines changed

typedb/connection/transaction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
113113
return False
114114

115115
def _raise_transaction_closed(self):
116-
errors = self._bidirectional_stream.drain_errors()
116+
errors = self._bidirectional_stream.get_errors()
117117
if len(errors) == 0:
118118
raise TypeDBClientException.of(TRANSACTION_CLOSED)
119119
else:

typedb/stream/bidirectional_stream.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ def stream(self, req: transaction_proto.Transaction.Req) -> Iterator[transaction
6262
self._dispatcher.dispatch(req)
6363
return ResponsePartIterator(request_id, self, self._dispatcher)
6464

65+
def done(self, request_id: UUID):
66+
self._response_collector.remove(request_id)
67+
6568
def is_open(self) -> bool:
6669
return self._is_open.get()
6770

@@ -78,8 +81,9 @@ def fetch(self, request_id: UUID) -> Union[transaction_proto.Transaction.Res, tr
7881
raise TypeDBClientException.of(TRANSACTION_CLOSED)
7982
server_msg = next(self._response_iterator)
8083
except RpcError as e:
81-
self.close(e)
82-
raise TypeDBClientException.of_rpc(e)
84+
error = TypeDBClientException.of_rpc(e)
85+
self.close(error)
86+
raise error
8387
except StopIteration:
8488
self.close()
8589
raise TypeDBClientException.of(TRANSACTION_CLOSED)
@@ -100,10 +104,10 @@ def _collect(self, response: Union[transaction_proto.Transaction.Res, transactio
100104
else:
101105
raise TypeDBClientException.of(UNKNOWN_REQUEST_ID, request_id)
102106

103-
def drain_errors(self) -> List[RpcError]:
104-
return self._response_collector.drain_errors()
107+
def get_errors(self) -> List[TypeDBClientException]:
108+
return self._response_collector.get_errors()
105109

106-
def close(self, error: RpcError = None):
110+
def close(self, error: TypeDBClientException = None):
107111
if self._is_open.compare_and_set(True, False):
108112
self._response_collector.close(error)
109113
try:
@@ -127,7 +131,9 @@ def __init__(self, request_id: UUID, stream: "BidirectionalStream"):
127131
self._stream = stream
128132

129133
def get(self) -> T:
130-
return self._stream.fetch(self._request_id)
134+
value = self._stream.fetch(self._request_id)
135+
self._stream.done(self._request_id)
136+
return value
131137

132138

133139
class RequestIterator(Iterator[Union[transaction_proto.Transaction.Req, StopIteration]]):

typedb/stream/response_collector.py

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,81 +21,98 @@
2121

2222
import queue
2323
from threading import Lock
24-
from typing import Generic, TypeVar, Dict, Optional, Union
24+
from typing import Generic, TypeVar, Dict, Optional
2525
from uuid import UUID
2626

27-
from grpc import RpcError
28-
29-
from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED
27+
from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED, ILLEGAL_STATE, \
28+
TRANSACTION_CLOSED_WITH_ERRORS
3029

3130
R = TypeVar('R')
3231

3332

3433
class ResponseCollector(Generic[R]):
3534

3635
def __init__(self):
37-
self._collectors: Dict[UUID, ResponseCollector.Queue[R]] = {}
36+
self._response_queues: Dict[UUID, ResponseCollector.Queue[R]] = {}
3837
self._collectors_lock = Lock()
3938

4039
def new_queue(self, request_id: UUID):
4140
with self._collectors_lock:
4241
collector: ResponseCollector.Queue[R] = ResponseCollector.Queue()
43-
self._collectors[request_id] = collector
42+
self._response_queues[request_id] = collector
4443
return collector
4544

4645
def get(self, request_id: UUID) -> Optional["ResponseCollector.Queue[R]"]:
47-
return self._collectors.get(request_id)
46+
return self._response_queues.get(request_id)
47+
48+
def remove(self, request_id: UUID):
49+
with self._collectors_lock:
50+
del self._response_queues[request_id]
4851

49-
def close(self, error: Optional[RpcError]):
52+
def close(self, error: Optional[TypeDBClientException]):
5053
with self._collectors_lock:
51-
for collector in self._collectors.values():
54+
for collector in self._response_queues.values():
5255
collector.close(error)
5356

54-
def drain_errors(self) -> [RpcError]:
57+
def get_errors(self) -> [TypeDBClientException]:
5558
errors = []
5659
with self._collectors_lock:
57-
for collector in self._collectors.values():
58-
errors.extend(collector.drain_errors())
60+
for collector in self._response_queues.values():
61+
error = collector.get_error()
62+
if error is not None:
63+
errors.append(error)
5964
return errors
6065

61-
6266
class Queue(Generic[R]):
6367

6468
def __init__(self):
65-
self._response_queue: queue.Queue[Union[Response[R], Done]] = queue.Queue()
69+
self._response_queue: queue.Queue[Response] = queue.Queue()
70+
self._error: TypeDBClientException = None
6671

6772
def get(self, block: bool) -> R:
6873
response = self._response_queue.get(block=block)
69-
if response.message:
70-
return response.message
71-
elif response.error:
72-
raise TypeDBClientException.of_rpc(response.error)
73-
else:
74+
if response.is_value():
75+
return response.value
76+
elif response.is_done() and self._error is None:
7477
raise TypeDBClientException.of(TRANSACTION_CLOSED)
78+
elif response.is_done() and self._error is not None:
79+
raise TypeDBClientException.of(TRANSACTION_CLOSED_WITH_ERRORS, self._error)
80+
else:
81+
raise TypeDBClientException.of(ILLEGAL_STATE)
7582

7683
def put(self, response: R):
77-
self._response_queue.put(Response(response))
84+
self._response_queue.put(ValueResponse(response))
7885

79-
def close(self, error: Optional[RpcError]):
80-
self._response_queue.put(Done(error))
86+
def close(self, error: Optional[TypeDBClientException]):
87+
self._error = error
88+
self._response_queue.put(DoneResponse())
8189

82-
def drain_errors(self) -> [RpcError]:
83-
errors = []
84-
while not self._response_queue.empty():
85-
response = self._response_queue.get(block = False)
86-
if response.error:
87-
errors.append(response.error)
88-
return errors
90+
def get_error(self) -> TypeDBClientException:
91+
return self._error
8992

9093

94+
class Response:
9195

92-
class Response(Generic[R]):
96+
def is_value(self):
97+
return False
98+
99+
def is_done(self):
100+
return False
101+
102+
103+
class ValueResponse(Response, Generic[R]):
93104

94105
def __init__(self, value: R):
95-
self.message = value
106+
self.value = value
96107

108+
def is_value(self):
109+
return True
97110

98-
class Done:
99111

100-
def __init__(self, error: Optional[RpcError]):
101-
self.error = error
112+
class DoneResponse(Response):
113+
114+
def __init__(self):
115+
pass
116+
117+
def is_done(self):
118+
return True

typedb/stream/response_part_iterator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@
1818
# specific language governing permissions and limitations
1919
# under the License.
2020
#
21-
from enum import Enum
2221
from typing import Iterator, TYPE_CHECKING
2322
from uuid import UUID
2423

2524
import typedb_protocol.common.transaction_pb2 as transaction_proto
26-
25+
from enum import Enum
2726
from typedb.common.exception import TypeDBClientException, ILLEGAL_ARGUMENT, MISSING_RESPONSE, ILLEGAL_STATE
2827
from typedb.common.rpc.request_builder import transaction_stream_req
2928
from typedb.stream.request_transmitter import RequestTransmitter
@@ -78,6 +77,7 @@ def _has_next(self) -> bool:
7877

7978
def __next__(self) -> transaction_proto.Transaction.ResPart:
8079
if not self._has_next():
80+
self._bidirectional_stream.done(self._request_id)
8181
raise StopIteration
8282
self._state = ResponsePartIterator.State.EMPTY
8383
return self._next

0 commit comments

Comments
 (0)