diff --git a/msgpackrpc/server.py b/msgpackrpc/server.py index 947fc11..9cbd84f 100644 --- a/msgpackrpc/server.py +++ b/msgpackrpc/server.py @@ -43,10 +43,13 @@ def on_notify(self, method, param): def dispatch(self, method, param, responder): try: method = force_str(method) - if not hasattr(self._dispatcher, method): - raise error.NoMethodError("'{0}' method not found".format(method)) + if callable(self._dispatcher): + result = self._dispatcher(method, param) + else: + if not hasattr(self._dispatcher, method): + raise error.NoMethodError("'{0}' method not found".format(method)) + result = getattr(self._dispatcher, method)(*param) - result = getattr(self._dispatcher, method)(*param) if isinstance(result, AsyncResult): result.set_responder(responder) else: diff --git a/test/test_msgpackrpc.py b/test/test_msgpackrpc.py index 3543eb5..3281f04 100644 --- a/test/test_msgpackrpc.py +++ b/test/test_msgpackrpc.py @@ -88,15 +88,15 @@ def _start_server(server): lock.acquire() # wait for the server to start self._client = msgpackrpc.Client(self._address, unpack_encoding='utf-8') - return self._client; + return self._client def tearDown(self): - self._client.close(); - self._server.stop(); - self._thread.join(); + self._client.close() + self._server.stop() + self._thread.join() def test_call(self): - client = self.setup_env(); + client = self.setup_env() result1 = client.call('hello') result2 = client.call('sum', 1, 2) @@ -107,7 +107,7 @@ def test_call(self): self.assertIsNone(result3, "'nil' result is incorrect") def test_call_userdefined_arg(self): - client = self.setup_env(); + client = self.setup_env() arg = TestMessagePackRPC.TestArg(0, 1, 2) arg2 = TestMessagePackRPC.TestArg(23, 3, -23) @@ -122,7 +122,7 @@ def test_call_userdefined_arg(self): self.assertEqual(result3, result1.add(result2)) def test_call_async(self): - client = self.setup_env(); + client = self.setup_env() future1 = client.call_async('hello') future2 = client.call_async('sum', 1, 2) @@ -136,7 +136,7 @@ def test_call_async(self): self.assertIsNone(future3.result, "'nil' result is incorrect in call_async") def test_notify(self): - client = self.setup_env(); + client = self.setup_env() result = True try: @@ -149,11 +149,11 @@ def test_notify(self): self.assertTrue(result) def test_raise_error(self): - client = self.setup_env(); + client = self.setup_env() self.assertRaises(error.RPCError, lambda: client.call('raise_error')) def test_unknown_method(self): - client = self.setup_env(); + client = self.setup_env() self.assertRaises(error.RPCError, lambda: client.call('unknown', True)) try: client.call('unknown', True) @@ -163,17 +163,17 @@ def test_unknown_method(self): self.assertEqual(message, "'unknown' method not found", "Error message mismatched") def test_async_result(self): - client = self.setup_env(); + client = self.setup_env() self.assertEqual(client.call('async_result'), "You are async!") def test_connect_failed(self): - client = self.setup_env(); + client = self.setup_env() port = helper.unused_port() client = msgpackrpc.Client(msgpackrpc.Address('localhost', port), unpack_encoding='utf-8') self.assertRaises(error.TransportError, lambda: client.call('hello')) def test_timeout(self): - client = self.setup_env(); + client = self.setup_env() if self.__class__.ENABLE_TIMEOUT_TEST: self.assertEqual(client.call('long_exec'), 'finish!', "'long_exec' result is incorrect") @@ -184,6 +184,66 @@ def test_timeout(self): print("Skip test_timeout") +class TestMessagePackRPCCustomDispatcher(unittest.TestCase): + + class TestDispatcher(object): + def __init__(self): + self.successCalls = 0 + self.failCalls = 0 + + def customDispatcher(self, method, args): + if method == 'testMethod' and len(args) == 1 and args[0] == 42: + self.successCalls += 1 + return 'success' + self.failCalls += 1 + return 'fail' + + + def setUp(self): + self._address = msgpackrpc.Address('localhost', helper.unused_port()) + + def setup_env(self): + def _on_started(): + self._server._loop.dettach_periodic_callback() + lock.release() + def _start_server(server): + server._loop.attach_periodic_callback(_on_started, 1) + server.start() + server.close() + + self.__dispatcher = TestMessagePackRPCCustomDispatcher.TestDispatcher() + self._server = msgpackrpc.Server(self.__dispatcher.customDispatcher) + self._server.listen(self._address) + self._thread = threading.Thread(target=_start_server, args=(self._server,)) + + lock = threading.Lock() + self._thread.start() + lock.acquire() + lock.acquire() # wait for the server to start + + self._client = msgpackrpc.Client(self._address, unpack_encoding='utf-8') + return self._client + + def tearDown(self): + self._client.close() + self._server.stop() + self._thread.join() + + def test_call(self): + client = self.setup_env() + + result1 = client.call('testMethod', 42) + result2 = client.call('testMethod', 1, 2) + result3 = client.call('someFunc', 12, 18) + + self.assertEqual(result1, "success", "'testMethod' result is incorrect") + self.assertEqual(result2, "fail", "'testMethod' result is incorrect") + self.assertEqual(result3, "fail", "'someFunc' result is incorrect") + + self.assertEqual(self.__dispatcher.successCalls, 1, "wrong number of success calls") + self.assertEqual(self.__dispatcher.failCalls, 2, "wrong number of fail calls") + + if __name__ == '__main__': import sys