Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions boltstub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from .parsing import (
parse_file,
Script,
ScriptDeviation,
ScriptFailure,
)
from .wiring import (
ReadWakeup,
Expand Down Expand Up @@ -122,7 +122,7 @@ def handle(self):
log.info("[#%04X>#%04X] S: <EXIT> %s",
self.client_address.port_number,
self.server_address.port_number, e)
except ScriptDeviation as e:
except ScriptFailure as e:
e.script = script
service.exceptions.append(e)
except Exception as e:
Expand Down
11 changes: 5 additions & 6 deletions boltstub/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@

from argparse import ArgumentParser
from logging import getLogger, INFO

from . import (
BoltStubService,
ScriptDeviation,
from . import BoltStubService
from .parsing import (
parse_file,
ScriptFailure,
)
from .parsing import parse_file
from .watcher import watch


Expand Down Expand Up @@ -115,7 +114,7 @@ def _main():
extra = ""
if hasattr(error, 'script') and error.script.filename:
extra += " in {!r}".format(error.script.filename)
if isinstance(error, ScriptDeviation):
if isinstance(error, ScriptFailure):
print("Script mismatch{}:\n{}\n".format(extra, error))
else:
print("Error{}:\n{}\n".format(extra, error))
Expand Down
75 changes: 30 additions & 45 deletions boltstub/bolt_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def __eq__(self, other):
class BoltProtocol:
protocol_version = None
version_aliases = set()
# allow the server to negotiate other bolt versions
equivalent_versions = set()

messages = {
"C": {},
Expand Down Expand Up @@ -119,6 +121,8 @@ class Bolt1Protocol(BoltProtocol):

protocol_version = (1, 0)
version_aliases = {(1,), (3, 0), (3, 1), (3, 2), (3, 3)}
# allow the server to negotiate other bolt versions
equivalent_versions = set()

messages = {
"C": {
Expand All @@ -141,7 +145,8 @@ class Bolt1Protocol(BoltProtocol):

@classmethod
def decode_versions(cls, b):
# ignore minor versions and ranges
# only major version is supported
# ignore all but last byte
masked = bytes(0 if i % 4 != 3 else b[i]
for i in range(len(b)))
return BoltProtocol.decode_versions(masked)
Expand All @@ -160,6 +165,8 @@ class Bolt2Protocol(Bolt1Protocol):

protocol_version = (2, 0)
version_aliases = {(2,), (3, 4)}
# allow the server to negotiate other bolt versions
equivalent_versions = set()

server_agent = "Neo4j/3.4.0"

Expand All @@ -177,6 +184,8 @@ class Bolt3Protocol(Bolt2Protocol):

protocol_version = (3, 0)
version_aliases = {(3,), (3, 5), (3, 6)}
# allow the server to negotiate other bolt versions
equivalent_versions = set()

messages = {
"C": {
Expand Down Expand Up @@ -215,6 +224,8 @@ class Bolt4x0Protocol(Bolt3Protocol):

protocol_version = (4, 0)
version_aliases = {(4,)}
# allow the server to negotiate other bolt versions
equivalent_versions = set()

messages = {
"C": {
Expand All @@ -240,10 +251,9 @@ class Bolt4x0Protocol(Bolt3Protocol):

@classmethod
def decode_versions(cls, b):
# minor version were introduced
# range support was backported from 4.3
# ignore reserved byte
masked = bytes(0 if i % 4 == 0 else b[i]
# minor version was introduced
# ignore first two bytes
masked = bytes(0 if i % 4 <= 1 else b[i]
for i in range(len(b)))
return BoltProtocol.decode_versions(masked)

Expand All @@ -261,6 +271,9 @@ def get_auto_response(cls, request: TranslatedStructure):
class Bolt4x1Protocol(Bolt4x0Protocol):

protocol_version = (4, 1)
version_aliases = set()
# allow the server to negotiate other bolt versions
equivalent_versions = set()

messages = {
"C": {
Expand Down Expand Up @@ -299,44 +312,19 @@ def get_auto_response(cls, request: TranslatedStructure):
class Bolt4x2Protocol(Bolt4x1Protocol):

protocol_version = (4, 2)

messages = {
"C": {
b"\x01": "HELLO",
b"\x02": "GOODBYE",
b"\x0F": "RESET",
b"\x10": "RUN",
b"\x11": "BEGIN",
b"\x12": "COMMIT",
b"\x13": "ROLLBACK",
b"\x2F": "DISCARD",
b"\x3F": "PULL",
},
"S": {
b"\x70": "SUCCESS",
b"\x71": "RECORD",
b"\x7E": "IGNORED",
b"\x7F": "FAILURE",
},
}
version_aliases = set()
# allow the server to negotiate other bolt versions
equivalent_versions = {(4, 1)}

server_agent = "Neo4j/4.2.0"

@classmethod
def get_auto_response(cls, request: TranslatedStructure):
if request.tag == b"\x01":
return TranslatedStructure("SUCCESS", b"\x70", {
"connection_id": "bolt-0",
"server": cls.server_agent,
"routing": None,
})
else:
return TranslatedStructure("SUCCESS", b"\x70", {})


class Bolt4x3Protocol(Bolt4x2Protocol):

protocol_version = (4, 3)
version_aliases = set()
# allow the server to negotiate other bolt versions
equivalent_versions = set()

messages = {
"C": {
Expand All @@ -362,12 +350,9 @@ class Bolt4x3Protocol(Bolt4x2Protocol):
server_agent = "Neo4j/4.3.0"

@classmethod
def get_auto_response(cls, request: TranslatedStructure):
if request.tag == b"\x01":
return TranslatedStructure("SUCCESS", b"\x70", {
"connection_id": "bolt-0",
"server": cls.server_agent,
"routing": None,
})
else:
return TranslatedStructure("SUCCESS", b"\x70", {})
def decode_versions(cls, b):
# minor version ranges were introduced
# ignore first byte
masked = bytes(0 if i % 4 == 0 else b[i]
for i in range(len(b)))
return BoltProtocol.decode_versions(masked)
31 changes: 22 additions & 9 deletions boltstub/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .bolt_protocol import get_bolt_protocol
from .errors import ServerExit
from .packstream import PackStream
from .parsing import ServerLine
from .parsing import ScriptFailure
from .util import hex_repr


Expand Down Expand Up @@ -37,19 +37,32 @@ def version_handshake(self):
# Check that the server protocol version is among the ones supported
# by the driver.
supported_version = self.bolt_protocol.protocol_version
requested_versions = self.bolt_protocol.decode_versions(request)
requested_versions = set(
self.bolt_protocol.decode_versions(request)
)
if supported_version in requested_versions:
response = bytes(
(0, 0, supported_version[1], supported_version[0])
)
else:
self.wire.write(b"\x00\x00\x00\x00")
self.wire.send()
raise ServerExit(
"Failed handshake, stub server talks protocol {}. "
"Driver sent handshake: {}".format(supported_version,
hex_repr(request))
)
fallback_versions = (requested_versions
& self.bolt_protocol.equivalent_versions)
if fallback_versions:
version = sorted(fallback_versions, reverse=True)[0]
response = bytes((0, 0, version[1], version[0]))
else:
try:
self._log("S: <HANDSHAKE> %s",
hex_repr(b"\x00\x00\x00\x00"))
self.wire.write(b"\x00\x00\x00\x00")
self.wire.send()
except OSError:
pass
raise ScriptFailure(
"Failed handshake, stub server talks protocol {}. "
"Driver sent handshake: {}".format(supported_version,
hex_repr(request))
)
self.wire.write(response)
self.wire.send()
self._log("S: <HANDSHAKE> %s", hex_repr(response))
Expand Down
6 changes: 5 additions & 1 deletion boltstub/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,11 @@ def server_lines(self):
yield from block.server_lines


class ScriptDeviation(RuntimeError):
class ScriptFailure(RuntimeError):
pass


class ScriptDeviation(ScriptFailure):
def __init__(self, expected_lines: List[Line], received: Line):
assert expected_lines
self.expected_lines = expected_lines
Expand Down
106 changes: 61 additions & 45 deletions boltstub/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import pytest

from .. import BoltStubService
from ..parsing import parse
from ..parsing import (
parse,
ScriptFailure,
)
from ..util import hex_repr
from ._common import (
ALL_REQUESTS_PER_VERSION,
Expand Down Expand Up @@ -149,49 +152,57 @@ def test_magic_bytes(server_version, magic_bytes, server_factory, fail,
assert not server.service.exceptions


@pytest.mark.parametrize(("client_version", "server_version", "matches"), [
[b"\x00\x00\x00\x01", (1,), True],
[b"\x00\x00\x00\x01", (1, 0), True],
[b"\x00\x00\x00\x01\x00\x00\x00\x02", (1, 0), True],
[b"\x00\x00\x00\x01\x00\x00\x00\x02", (2, 0), True],
[b"\x00\x00\x00\x01\x00\x00\x00\x03", (2,), False],
[b"\x00\x00\x00\x03", (3,), True],
[b"\x00\x00\x00\x04", (4,), True],
[b"\x00\x00\x01\x04", (4, 1), True],
[b"\x00\x00\x02\x04", (4, 2), True],
[b"\x00\x00\x03\x04\x00\x00\x02\x04\x00\x00\x01\x04", (4, 2), True],
[b"\x00\x00\x03\x04", (4, 3), True],
[b"\x00\x02\x03\x04", (4, 3), True],
@pytest.mark.parametrize(("client_version", "server_version",
"negotiated_version"), [
[b"\x00\x00\x00\x01", (1,), (1,)],
[b"\x00\x00\x00\x01", (1, 0), (1, 0)],
[b"\x00\x00\x00\x01\x00\x00\x00\x02", (1, 0), (1, 0)],
[b"\x00\x00\x00\x01\x00\x00\x00\x02", (2, 0), (2, 0)],
[b"\x00\x00\x00\x01\x00\x00\x00\x03", (2,), None],
[b"\x00\x00\x00\x03", (3,), (3,)],
[b"\x00\x00\x00\x04", (4,), (4,)],
[b"\x00\x00\x01\x04", (4, 1), (4, 1)],
[b"\x00\x00\x02\x04", (4, 2), (4, 2)],
[b"\x00\x00\x03\x04\x00\x00\x02\x04\x00\x00\x01\x04", (4, 2), (4, 2)],
[b"\x00\x00\x03\x04", (4, 3), (4, 3)],
[b"\x00\x02\x03\x04", (4, 3), (4, 3)],
[b"\x00\x03\x03\x04\x00\x00\x02\x04\x00\x00\x01\x04\x00\x00\x00\x03",
(4, 0), True],
(4, 0), None],
[b"\x00\x01\x03\x04\x00\x00\x01\x04\x00\x00\x00\x04\x00\x00\x00\x03",
(4, 0), (4, 0)],
# ignore minor versions until 4.0
[b"\x00\x00\x10\x01", (1,), True],
[b"\x00\x00\x10\x02", (2,), True],
[b"\x00\x00\x10\x03", (3,), True],
[b"\x00\x00\x10\x04", (4, 0), False],
[b"\x00\x00\x10\x04", (4, 1), False],
[b"\x00\x00\x10\x04", (4, 2), False],
[b"\x00\x00\x10\x04", (4, 3), False],
# ignore version ranges until 4.0
[b"\x00\x09\x0A\x01", (1,), True],
[b"\x00\x0A\x0A\x01", (1,), True],
[b"\x00\x09\x0A\x02", (2,), True],
[b"\x00\x0A\x0A\x02", (2,), True],
[b"\x00\x09\x0A\x03", (3,), True],
[b"\x00\x0A\x0A\x03", (3,), True],
[b"\x00\x0A\x0A\x04", (4, 0), True],
[b"\x00\x09\x0A\x04", (4, 0), False],
[b"\x00\x0A\x0A\x04", (4, 1), True],
[b"\x00\x09\x0A\x04", (4, 1), True],
[b"\x00\x08\x0A\x04", (4, 1), False],
[b"\x00\x09\x0A\x04", (4, 2), True],
[b"\x00\x08\x0A\x04", (4, 2), True],
[b"\x00\x07\x0A\x04", (4, 2), False],
[b"\x00\x08\x0A\x04", (4, 3), True],
[b"\x00\x07\x0A\x04", (4, 3), True],
[b"\x00\x06\x0A\x04", (4, 3), False],
[b"\x00\x00\x10\x01", (1,), (1,)],
[b"\x00\x00\x10\x02", (2,), (2,)],
[b"\x00\x00\x10\x03", (3,), (3,)],
[b"\x00\x00\x10\x04", (4, 0), None],
[b"\x00\x00\x10\x04", (4, 1), None],
[b"\x00\x00\x10\x04", (4, 2), None],
[b"\x00\x00\x10\x04", (4, 3), None],
# ignore version ranges until 4.3
[b"\x00\x09\x0A\x01", (1,), (1,)],
[b"\x00\x0A\x0A\x01", (1,), (1,)],
[b"\x00\x09\x0A\x02", (2,), (2,)],
[b"\x00\x0A\x0A\x02", (2,), (2,)],
[b"\x00\x09\x0A\x03", (3,), (3,)],
[b"\x00\x0A\x0A\x03", (3,), (3,)],
[b"\x00\x0A\x0A\x04", (4, 0), None],
[b"\x00\x09\x0A\x04", (4, 0), None],
[b"\x00\x0A\x0A\x04", (4, 1), None],
[b"\x00\x09\x0A\x04", (4, 1), None],
[b"\x00\x08\x0A\x04", (4, 1), None],
[b"\x00\x09\x0A\x04", (4, 2), None],
[b"\x00\x08\x0A\x04", (4, 2), None],
[b"\x00\x07\x0A\x04", (4, 2), None],
[b"\x00\x08\x0A\x04", (4, 3), (4, 3)],
[b"\x00\x07\x0A\x04", (4, 3), (4, 3)],
[b"\x00\x06\x0A\x04", (4, 3), None],
# special backwards compatibility
# (4.2 server allows to fall back to equivalent 4.1 protocol)
[b"\x00\x00\x01\x04", (4, 2), (4, 1)],
[b"\x00\x00\x02\x04", (4, 1), None],
[b"\x00\x00\x02\x04", (4, 3), None],
])
def test_handshake_auto(client_version, server_version, matches,
def test_handshake_auto(client_version, server_version, negotiated_version,
server_factory, connection_factory):
client_version = client_version + b"\x00" * (16 - len(client_version))

Expand All @@ -205,13 +216,18 @@ def test_handshake_auto(client_version, server_version, matches,
con = connection_factory("localhost", 7687)
con.write(b"\x60\x60\xb0\x17")
con.write(client_version)
if matches:
assert con.read(4) == server_version_to_version_response(server_version)
if negotiated_version is not None:
assert (con.read(4)
== server_version_to_version_response(negotiated_version))
else:
assert con.read(4) == b"\x00" * 4
with pytest.raises(BrokenSocket):
print(con.read(1))
assert not server.service.exceptions
if negotiated_version is None:
assert len(server.service.exceptions) == 1
assert isinstance(server.service.exceptions[0], ScriptFailure)
else:
assert not server.service.exceptions


@pytest.mark.parametrize("custom_handshake", [b"\x00\x00\xFF\x00", b"foobar"])
Expand Down Expand Up @@ -264,7 +280,7 @@ def test_auto_replies(server_version, request_tag, request_name,
con.read(1)
assert not server.service.exceptions

I = 5 # [I:(I + 1)]

@pytest.mark.parametrize(("server_version", "request_tag", "request_name",
"field_rep", "field_bin"), (
*((version, tag, name, rep, fields)
Expand Down
Loading