Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions neo4j/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def hydrate(cls, message=None, code=None, **metadata):
code = code or "Neo.DatabaseError.General.UnknownError"
try:
_, classification, category, title = code.split(".")
if code == "Neo.ClientError.Security.AuthorizationExpired":
classification = CLASSIFICATION_TRANSIENT
except ValueError:
classification = CLASSIFICATION_DATABASE
category = "General"
Expand Down Expand Up @@ -124,6 +126,9 @@ def _extract_error_class(cls, classification, code):
else:
return cls

def invalidates_all_connections(self):
return self.code == "Neo.ClientError.Security.AuthorizationExpired"

def __str__(self):
return "{{code: {code}}} {{message: {message}}}".format(code=self.code, message=self.message)

Expand Down
227 changes: 210 additions & 17 deletions neo4j/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
"check_supported_server_product",
]


import abc
from collections import deque
from logging import getLogger
from random import choice
Expand Down Expand Up @@ -69,6 +69,7 @@
from neo4j.addressing import Address
from neo4j.api import (
READ_ACCESS,
ServerInfo,
Version,
WRITE_ACCESS,
)
Expand All @@ -77,23 +78,36 @@
WorkspaceConfig,
)
from neo4j.exceptions import (
AuthError,
ClientError,
ConfigurationError,
DriverError,
IncompleteCommit,
ReadServiceUnavailable,
ServiceUnavailable,
SessionExpired,
UnsupportedServerProduct,
WriteServiceUnavailable,
)
from neo4j.io._common import (
CommitResponse,
Inbox,
InitResponse,
Outbox,
Response,
)
from neo4j.meta import get_user_agent
from neo4j.packstream import (
Packer,
Unpacker,
)
from neo4j.routing import RoutingTable


# Set up logger
log = getLogger("neo4j")


class Bolt:
class Bolt(abc.ABC):
""" Server connection for Bolt protocol.

A :class:`.Bolt` should be constructed following a
Expand All @@ -107,6 +121,69 @@ class Bolt:

PROTOCOL_VERSION = None

# The socket
in_use = False

# The socket
_closed = False

# The socket
_defunct = False

#: The pool of which this connection is a member
pool = None

def __init__(self, unresolved_address, sock, max_connection_lifetime, *, auth=None, user_agent=None, routing_context=None):
self.unresolved_address = unresolved_address
self.socket = sock
self.server_info = ServerInfo(Address(sock.getpeername()), self.PROTOCOL_VERSION)
self.outbox = Outbox()
self.inbox = Inbox(self.socket, on_error=self._set_defunct_read)
self.packer = Packer(self.outbox)
self.unpacker = Unpacker(self.inbox)
self.responses = deque()
self._max_connection_lifetime = max_connection_lifetime
self._creation_timestamp = perf_counter()
self._is_reset = True
self.routing_context = routing_context

# Determine the user agent
if user_agent:
self.user_agent = user_agent
else:
self.user_agent = get_user_agent()

# Determine auth details
if not auth:
self.auth_dict = {}
elif isinstance(auth, tuple) and 2 <= len(auth) <= 3:
from neo4j import Auth
self.auth_dict = vars(Auth("basic", *auth))
else:
try:
self.auth_dict = vars(auth)
except (KeyError, TypeError):
raise AuthError("Cannot determine auth details from %r" % auth)

# Check for missing password
try:
credentials = self.auth_dict["credentials"]
except KeyError:
pass
else:
if credentials is None:
raise AuthError("Password cannot be None")

@property
@abc.abstractmethod
def supports_multiple_results(self):
pass

@property
@abc.abstractmethod
def supports_multiple_databases(self):
pass

@classmethod
def protocol_handlers(cls, protocol_version=None):
""" Return a dictionary of available Bolt protocol handlers,
Expand Down Expand Up @@ -258,26 +335,31 @@ def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_
return connection

@property
@abc.abstractmethod
def encrypted(self):
raise NotImplementedError
pass

@property
@abc.abstractmethod
def der_encoded_server_certificate(self):
raise NotImplementedError
pass

@property
@abc.abstractmethod
def local_port(self):
raise NotImplementedError
pass

@abc.abstractmethod
def hello(self):
raise NotImplementedError
pass

def __del__(self):
try:
self.close()
except OSError:
pass

@abc.abstractmethod
def route(self, database=None, bookmarks=None):
""" Fetch a routing table from the server for the given
`database`. For Bolt 4.3 and above, this appends a ROUTE
Expand All @@ -290,7 +372,9 @@ def route(self, database=None, bookmarks=None):
transaction should begin
:return: dictionary of raw routing data
"""
pass

@abc.abstractmethod
def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
timeout=None, db=None, **handlers):
""" Appends a RUN message to the output stream.
Expand All @@ -305,7 +389,9 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
:param handlers: handler functions passed into the returned Response object
:return: Response object
"""
pass

@abc.abstractmethod
def discard(self, n=-1, qid=-1, **handlers):
""" Appends a DISCARD message to the output stream.

Expand All @@ -314,7 +400,9 @@ def discard(self, n=-1, qid=-1, **handlers):
:param handlers: handler functions passed into the returned Response object
:return: Response object
"""
pass

@abc.abstractmethod
def pull(self, n=-1, qid=-1, **handlers):
""" Appends a PULL message to the output stream.

Expand All @@ -323,7 +411,9 @@ def pull(self, n=-1, qid=-1, **handlers):
:param handlers: handler functions passed into the returned Response object
:return: Response object
"""
pass

@abc.abstractmethod
def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
""" Appends a BEGIN message to the output stream.

Expand All @@ -335,42 +425,139 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None,
:param handlers: handler functions passed into the returned Response object
:return: Response object
"""
pass

@abc.abstractmethod
def commit(self, **handlers):
raise NotImplementedError
pass

@abc.abstractmethod
def rollback(self, **handlers):
raise NotImplementedError
pass

@abc.abstractmethod
def reset(self):
""" Add a RESET message to the outgoing queue, send
it and consume all remaining messages.
"""
raise NotImplementedError
pass

def _append(self, signature, fields=(), response=None):
""" Add a message to the outgoing queue.

:arg signature: the signature of the message
:arg fields: the fields of the message as a tuple
:arg response: a response object to handle callbacks
"""
self.packer.pack_struct(signature, fields)
self.outbox.chunk()
self.outbox.chunk()
self.responses.append(response)

def _send_all(self):
data = self.outbox.view()
if data:
try:
self.socket.sendall(data)
except OSError as error:
self._set_defunct_write(error)
self.outbox.clear()

def send_all(self):
""" Send all queued messages to the server.
"""
raise NotImplementedError
if self.closed():
raise ServiceUnavailable("Failed to write to closed connection {!r} ({!r})".format(
self.unresolved_address, self.server_info.address))

if self.defunct():
raise ServiceUnavailable("Failed to write to defunct connection {!r} ({!r})".format(
self.unresolved_address, self.server_info.address))

self._send_all()

@abc.abstractmethod
def fetch_message(self):
""" Receive at least one message from the server, if available.

:return: 2-tuple of number of detail messages and number of summary
messages fetched
"""
raise NotImplementedError

def timedout(self):
raise NotImplementedError
pass

def fetch_all(self):
""" Fetch all outstanding messages.

:return: 2-tuple of number of detail messages and number of summary
messages fetched
"""
raise NotImplementedError
detail_count = summary_count = 0
while self.responses:
response = self.responses[0]
while not response.complete:
detail_delta, summary_delta = self.fetch_message()
detail_count += detail_delta
summary_count += summary_delta
return detail_count, summary_count

def _set_defunct_read(self, error=None, silent=False):
message = "Failed to read from defunct connection {!r} ({!r})".format(
self.unresolved_address, self.server_info.address
)
self._set_defunct(message, error=error, silent=silent)

def _set_defunct_write(self, error=None, silent=False):
message = "Failed to write data to connection {!r} ({!r})".format(
self.unresolved_address, self.server_info.address
)
self._set_defunct(message, error=error, silent=silent)

def _set_defunct(self, message, error=None, silent=False):
direct_driver = isinstance(self.pool, BoltPool)

if error:
log.error(str(error))
log.error(message)
# We were attempting to receive data but the connection
# has unexpectedly terminated. So, we need to close the
# connection from the client side, and remove the address
# from the connection pool.
self._defunct = True
self.close()
if self.pool:
self.pool.deactivate(address=self.unresolved_address)
# Iterate through the outstanding responses, and if any correspond
# to COMMIT requests then raise an error to signal that we are
# unable to confirm that the COMMIT completed successfully.
if silent:
return
for response in self.responses:
if isinstance(response, CommitResponse):
if error:
raise IncompleteCommit(message) from error
else:
raise IncompleteCommit(message)

if direct_driver:
if error:
raise ServiceUnavailable(message) from error
else:
raise ServiceUnavailable(message)
else:
if error:
raise SessionExpired(message) from error
else:
raise SessionExpired(message)

def stale(self):
return (self._stale
or (0 <= self._max_connection_lifetime
<= perf_counter()- self._creation_timestamp))

_stale = False

def set_stale(self):
self._stale = True

def close(self):
""" Close the connection.
Expand Down Expand Up @@ -430,7 +617,7 @@ def time_remaining():
while True:
# try to find a free connection in pool
for connection in list(connections):
if connection.closed() or connection.defunct() or connection.timedout():
if connection.closed() or connection.defunct() or connection.stale():
connections.remove(connection)
continue
if not connection.in_use:
Expand Down Expand Up @@ -497,6 +684,12 @@ def in_use_connection_count(self, address):
else:
return sum(1 if connection.in_use else 0 for connection in connections)

def mark_all_stale(self):
with self.lock:
for address in self.connections:
for connection in self.connections[address]:
connection.set_stale()

def deactivate(self, address):
""" Deactivate an address from the connection pool, if present, closing
all idle connection to that address
Expand Down
Loading