Skip to content

Commit c592437

Browse files
committed
Couple of session hooks
1 parent 066b458 commit c592437

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

neo4j/bolt/connection.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -343,12 +343,18 @@ def send(self):
343343
self.channel.send()
344344

345345
def fetch(self):
346-
""" Receive exactly one message from the server.
346+
""" Receive exactly one message from the server
347+
(if one is available).
348+
349+
:return: number of messages fetched (zero or one)
347350
"""
348351
if self.closed:
349352
raise ServiceUnavailable("Failed to read from closed connection %r" % (self.server.address,))
350353
if self.defunct:
351354
raise ServiceUnavailable("Failed to read from defunct connection %r" % (self.server.address,))
355+
if not self.responses:
356+
return 0
357+
352358
try:
353359
message_data = self.buffering_socket.read_message()
354360
except ProtocolError:
@@ -388,6 +394,8 @@ def fetch(self):
388394
else:
389395
raise ProtocolError("Unexpected response message with signature %02X" % signature)
390396

397+
return 1
398+
391399
def sync(self):
392400
""" Send and fetch all outstanding messages.
393401
"""

neo4j/v1/session.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ class Session(object):
3333
method.
3434
"""
3535

36+
response_class = Response
37+
38+
result_class = None
39+
3640
transaction = None
3741

3842
last_bookmark = None
@@ -66,9 +70,9 @@ def run(self, statement, parameters=None, **kwparameters):
6670
statement = _norm_statement(statement)
6771
parameters = _norm_parameters(parameters, **kwparameters)
6872

69-
run_response = Response(self.connection)
70-
pull_all_response = Response(self.connection)
71-
result = StatementResult(self, run_response, pull_all_response)
73+
run_response = self.response_class(self.connection)
74+
pull_all_response = self.response_class(self.connection)
75+
result = self.result_class(self, run_response, pull_all_response)
7276
result.statement = statement
7377
result.parameters = parameters
7478

@@ -79,6 +83,9 @@ def run(self, statement, parameters=None, **kwparameters):
7983
return result
8084

8185
def fetch(self):
86+
""" Fetch the next message if available and return
87+
the number of messages fetched (one or zero).
88+
"""
8289
try:
8390
return self.connection.fetch()
8491
except ServiceUnavailable as cause:
@@ -331,6 +338,9 @@ def peek(self):
331338
raise ResultError("End of stream")
332339

333340

341+
Session.result_class = StatementResult
342+
343+
334344
class Record(object):
335345
""" Record is an ordered collection of fields.
336346

0 commit comments

Comments
 (0)