diff --git a/msgpackrpc/error.py b/msgpackrpc/error.py index ae2cd74..36d0831 100644 --- a/msgpackrpc/error.py +++ b/msgpackrpc/error.py @@ -23,6 +23,10 @@ class TransportError(RPCError): CODE = ".TransportError" pass +class SessionError(RPCError): + CODE = ".SessionError" + pass + class CallError(RPCError): CODE = ".NoMethodError" pass diff --git a/msgpackrpc/session.py b/msgpackrpc/session.py index e85da0e..23ccc68 100644 --- a/msgpackrpc/session.py +++ b/msgpackrpc/session.py @@ -3,7 +3,7 @@ from msgpackrpc.future import Future from msgpackrpc.transport import tcp from msgpackrpc.compat import iteritems -from msgpackrpc.error import TimeoutError +from msgpackrpc.error import TimeoutError, SessionError class Session(object): @@ -32,6 +32,7 @@ def __init__(self, address, timeout, loop=None, builder=tcp, reconnect_limit=5, self._transport = builder.ClientTransport(self, self._address, reconnect_limit, encodings=(pack_encoding, unpack_encoding)) self._generator = _NoSyncIDGenerator() self._request_table = {} + self._closing = False @property def address(self): @@ -45,6 +46,9 @@ def call_async(self, method, *args): def send_request(self, method, args): # need lock? + if self._closing: + raise SessionError('Session closed') + msgid = next(self._generator) future = Future(self._loop, self._timeout) self._request_table[msgid] = future @@ -58,6 +62,7 @@ def callback(): self._loop.start() def close(self): + self._closing = True if self._transport: self._transport.close() self._transport = None @@ -68,6 +73,7 @@ def on_connect_failed(self, reason): The callback called when the connection failed. Called by the transport layer. """ + self._closing = True # set error for all requests for msgid, future in iteritems(self._request_table): future.set_error(reason)