|
21 | 21 |
|
22 | 22 | import queue
|
23 | 23 | from threading import Lock
|
24 |
| -from typing import Generic, TypeVar, Dict, Optional, Union |
| 24 | +from typing import Generic, TypeVar, Dict, Optional |
25 | 25 | from uuid import UUID
|
26 | 26 |
|
27 |
| -from grpc import RpcError |
28 |
| - |
29 |
| -from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED |
| 27 | +from typedb.common.exception import TypeDBClientException, TRANSACTION_CLOSED, ILLEGAL_STATE, \ |
| 28 | + TRANSACTION_CLOSED_WITH_ERRORS |
30 | 29 |
|
31 | 30 | R = TypeVar('R')
|
32 | 31 |
|
33 | 32 |
|
34 | 33 | class ResponseCollector(Generic[R]):
|
35 | 34 |
|
36 | 35 | def __init__(self):
|
37 |
| - self._collectors: Dict[UUID, ResponseCollector.Queue[R]] = {} |
| 36 | + self._response_queues: Dict[UUID, ResponseCollector.Queue[R]] = {} |
38 | 37 | self._collectors_lock = Lock()
|
39 | 38 |
|
40 | 39 | def new_queue(self, request_id: UUID):
|
41 | 40 | with self._collectors_lock:
|
42 | 41 | collector: ResponseCollector.Queue[R] = ResponseCollector.Queue()
|
43 |
| - self._collectors[request_id] = collector |
| 42 | + self._response_queues[request_id] = collector |
44 | 43 | return collector
|
45 | 44 |
|
46 | 45 | def get(self, request_id: UUID) -> Optional["ResponseCollector.Queue[R]"]:
|
47 |
| - return self._collectors.get(request_id) |
| 46 | + return self._response_queues.get(request_id) |
| 47 | + |
| 48 | + def remove(self, request_id: UUID): |
| 49 | + with self._collectors_lock: |
| 50 | + del self._response_queues[request_id] |
48 | 51 |
|
49 |
| - def close(self, error: Optional[RpcError]): |
| 52 | + def close(self, error: Optional[TypeDBClientException]): |
50 | 53 | with self._collectors_lock:
|
51 |
| - for collector in self._collectors.values(): |
| 54 | + for collector in self._response_queues.values(): |
52 | 55 | collector.close(error)
|
53 | 56 |
|
54 |
| - def drain_errors(self) -> [RpcError]: |
| 57 | + def get_errors(self) -> [TypeDBClientException]: |
55 | 58 | errors = []
|
56 | 59 | with self._collectors_lock:
|
57 |
| - for collector in self._collectors.values(): |
58 |
| - errors.extend(collector.drain_errors()) |
| 60 | + for collector in self._response_queues.values(): |
| 61 | + error = collector.get_error() |
| 62 | + if error is not None: |
| 63 | + errors.append(error) |
59 | 64 | return errors
|
60 | 65 |
|
61 |
| - |
62 | 66 | class Queue(Generic[R]):
|
63 | 67 |
|
64 | 68 | def __init__(self):
|
65 |
| - self._response_queue: queue.Queue[Union[Response[R], Done]] = queue.Queue() |
| 69 | + self._response_queue: queue.Queue[Response] = queue.Queue() |
| 70 | + self._error: TypeDBClientException = None |
66 | 71 |
|
67 | 72 | def get(self, block: bool) -> R:
|
68 | 73 | response = self._response_queue.get(block=block)
|
69 |
| - if response.message: |
70 |
| - return response.message |
71 |
| - elif response.error: |
72 |
| - raise TypeDBClientException.of_rpc(response.error) |
73 |
| - else: |
| 74 | + if response.is_value(): |
| 75 | + return response.value |
| 76 | + elif response.is_done() and self._error is None: |
74 | 77 | raise TypeDBClientException.of(TRANSACTION_CLOSED)
|
| 78 | + elif response.is_done() and self._error is not None: |
| 79 | + raise TypeDBClientException.of(TRANSACTION_CLOSED_WITH_ERRORS, self._error) |
| 80 | + else: |
| 81 | + raise TypeDBClientException.of(ILLEGAL_STATE) |
75 | 82 |
|
76 | 83 | def put(self, response: R):
|
77 |
| - self._response_queue.put(Response(response)) |
| 84 | + self._response_queue.put(ValueResponse(response)) |
78 | 85 |
|
79 |
| - def close(self, error: Optional[RpcError]): |
80 |
| - self._response_queue.put(Done(error)) |
| 86 | + def close(self, error: Optional[TypeDBClientException]): |
| 87 | + self._error = error |
| 88 | + self._response_queue.put(DoneResponse()) |
81 | 89 |
|
82 |
| - def drain_errors(self) -> [RpcError]: |
83 |
| - errors = [] |
84 |
| - while not self._response_queue.empty(): |
85 |
| - response = self._response_queue.get(block = False) |
86 |
| - if response.error: |
87 |
| - errors.append(response.error) |
88 |
| - return errors |
| 90 | + def get_error(self) -> TypeDBClientException: |
| 91 | + return self._error |
89 | 92 |
|
90 | 93 |
|
| 94 | +class Response: |
91 | 95 |
|
92 |
| -class Response(Generic[R]): |
| 96 | + def is_value(self): |
| 97 | + return False |
| 98 | + |
| 99 | + def is_done(self): |
| 100 | + return False |
| 101 | + |
| 102 | + |
| 103 | +class ValueResponse(Response, Generic[R]): |
93 | 104 |
|
94 | 105 | def __init__(self, value: R):
|
95 |
| - self.message = value |
| 106 | + self.value = value |
96 | 107 |
|
| 108 | + def is_value(self): |
| 109 | + return True |
97 | 110 |
|
98 |
| -class Done: |
99 | 111 |
|
100 |
| - def __init__(self, error: Optional[RpcError]): |
101 |
| - self.error = error |
| 112 | +class DoneResponse(Response): |
| 113 | + |
| 114 | + def __init__(self): |
| 115 | + pass |
| 116 | + |
| 117 | + def is_done(self): |
| 118 | + return True |
0 commit comments