32
32
from select import select
33
33
from socket import socket , SOL_SOCKET , SO_KEEPALIVE , SHUT_RDWR , error as SocketError , timeout as SocketTimeout , AF_INET , AF_INET6
34
34
from struct import pack as struct_pack , unpack as struct_unpack
35
- from threading import RLock
35
+ from threading import RLock , Condition
36
36
37
37
from neo4j .addressing import SocketAddress , is_ip_address
38
38
from neo4j .bolt .cert import KNOWN_HOSTS
39
39
from neo4j .bolt .response import InitResponse , AckFailureResponse , ResetResponse
40
40
from neo4j .compat .ssl import SSL_AVAILABLE , HAS_SNI , SSLError
41
- from neo4j .exceptions import ProtocolError , SecurityError , ServiceUnavailable
41
+ from neo4j .exceptions import ClientError , ProtocolError , SecurityError , ServiceUnavailable
42
42
from neo4j .meta import version
43
43
from neo4j .packstream import Packer , Unpacker
44
44
from neo4j .util import import_best as _import_best
45
+ from time import clock
45
46
46
47
ChunkedInputBuffer = _import_best ("neo4j.bolt._io" , "neo4j.bolt.io" ).ChunkedInputBuffer
47
48
ChunkedOutputBuffer = _import_best ("neo4j.bolt._io" , "neo4j.bolt.io" ).ChunkedOutputBuffer
48
49
49
50
51
+ INFINITE = - 1
52
+ DEFAULT_MAX_CONNECTION_LIFETIME = INFINITE
53
+ DEFAULT_MAX_CONNECTION_POOL_SIZE = INFINITE
50
54
DEFAULT_CONNECTION_TIMEOUT = 5.0
55
+ DEFAULT_CONNECTION_ACQUISITION_TIMEOUT = 60
51
56
DEFAULT_PORT = 7687
52
57
DEFAULT_USER_AGENT = "neo4j-python/%s" % version
53
58
@@ -178,6 +183,8 @@ def __init__(self, address, sock, error_handler, **config):
178
183
self .packer = Packer (self .output_buffer )
179
184
self .unpacker = Unpacker ()
180
185
self .responses = deque ()
186
+ self ._max_connection_lifetime = config .get ("max_connection_lifetime" , DEFAULT_MAX_CONNECTION_LIFETIME )
187
+ self ._creation_timestamp = clock ()
181
188
182
189
# Determine the user agent and ensure it is a Unicode value
183
190
user_agent = config .get ("user_agent" , DEFAULT_USER_AGENT )
@@ -201,6 +208,7 @@ def __init__(self, address, sock, error_handler, **config):
201
208
# Pick up the server certificate, if any
202
209
self .der_encoded_server_certificate = config .get ("der_encoded_server_certificate" )
203
210
211
+ def Init (self ):
204
212
response = InitResponse (self )
205
213
self .append (INIT , (self .user_agent , self .auth_dict ), response = response )
206
214
self .sync ()
@@ -360,6 +368,9 @@ def _unpack(self):
360
368
more = False
361
369
return details , summary_signature , summary_metadata
362
370
371
+ def timedout (self ):
372
+ return 0 <= self ._max_connection_lifetime <= clock () - self ._creation_timestamp
373
+
363
374
def sync (self ):
364
375
""" Send and fetch all outstanding messages.
365
376
@@ -396,11 +407,14 @@ class ConnectionPool(object):
396
407
397
408
_closed = False
398
409
399
- def __init__ (self , connector , connection_error_handler ):
410
+ def __init__ (self , connector , connection_error_handler , ** config ):
400
411
self .connector = connector
401
412
self .connection_error_handler = connection_error_handler
402
413
self .connections = {}
403
414
self .lock = RLock ()
415
+ self .cond = Condition (self .lock )
416
+ self ._max_connection_pool_size = config .get ("max_connection_pool_size" , DEFAULT_MAX_CONNECTION_POOL_SIZE )
417
+ self ._connection_acquisition_timeout = config .get ("connection_acquisition_timeout" , DEFAULT_CONNECTION_ACQUISITION_TIMEOUT )
404
418
405
419
def __enter__ (self ):
406
420
return self
@@ -424,23 +438,42 @@ def acquire_direct(self, address):
424
438
connections = self .connections [address ]
425
439
except KeyError :
426
440
connections = self .connections [address ] = deque ()
427
- for connection in list (connections ):
428
- if connection .closed () or connection .defunct ():
429
- connections .remove (connection )
430
- continue
431
- if not connection .in_use :
432
- connection .in_use = True
433
- return connection
434
- try :
435
- connection = self .connector (address , self .connection_error_handler )
436
- except ServiceUnavailable :
437
- self .remove (address )
438
- raise
439
- else :
440
- connection .pool = self
441
- connection .in_use = True
442
- connections .append (connection )
443
- return connection
441
+
442
+ connection_acquisition_start_timestamp = clock ()
443
+ while True :
444
+ # try to find a free connection in pool
445
+ for connection in list (connections ):
446
+ if connection .closed () or connection .defunct () or connection .timedout ():
447
+ connections .remove (connection )
448
+ continue
449
+ if not connection .in_use :
450
+ connection .in_use = True
451
+ return connection
452
+ # all connections in pool are in-use
453
+ can_create_new_connection = self ._max_connection_pool_size == INFINITE or len (connections ) < self ._max_connection_pool_size
454
+ if can_create_new_connection :
455
+ try :
456
+ connection = self .connector (address , self .connection_error_handler )
457
+ except ServiceUnavailable :
458
+ self .remove (address )
459
+ raise
460
+ else :
461
+ connection .pool = self
462
+ connection .in_use = True
463
+ connections .append (connection )
464
+ return connection
465
+
466
+ # failed to obtain a connection from pool because the pool is full and no free connection in the pool
467
+ span_timeout = self ._connection_acquisition_timeout - (clock () - connection_acquisition_start_timestamp )
468
+ if span_timeout > 0 :
469
+ self .cond .wait (span_timeout )
470
+ # if timed out, then we throw error. This time computation is needed, as with python 2.7, we cannot
471
+ # tell if the condition is notified or timed out when we come to this line
472
+ if self ._connection_acquisition_timeout <= (clock () - connection_acquisition_start_timestamp ):
473
+ raise ClientError ("Failed to obtain a connection from pool within {!r}s" .format (
474
+ self ._connection_acquisition_timeout ))
475
+ else :
476
+ raise ClientError ("Failed to obtain a connection from pool within {!r}s" .format (self ._connection_acquisition_timeout ))
444
477
445
478
def acquire (self , access_mode = None ):
446
479
""" Acquire a connection to a server that can satisfy a set of parameters.
@@ -454,6 +487,7 @@ def release(self, connection):
454
487
"""
455
488
with self .lock :
456
489
connection .in_use = False
490
+ self .cond .notify_all ()
457
491
458
492
def in_use_connection_count (self , address ):
459
493
""" Count the number of connections currently in use to a given
@@ -600,8 +634,10 @@ def connect(address, ssl_context=None, error_handler=None, **config):
600
634
s .shutdown (SHUT_RDWR )
601
635
s .close ()
602
636
elif agreed_version == 1 :
603
- return Connection (address , s , der_encoded_server_certificate = der_encoded_server_certificate ,
637
+ connection = Connection (address , s , der_encoded_server_certificate = der_encoded_server_certificate ,
604
638
error_handler = error_handler , ** config )
639
+ connection .Init ()
640
+ return connection
605
641
elif agreed_version == 0x48545450 :
606
642
log_error ("S: [CLOSE]" )
607
643
s .close ()
0 commit comments