diff --git a/AUTHORS b/AUTHORS index 129e9f3..a758051 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,2 +1,3 @@ Masahiro Nakagawa INADA Naoki +Harish Vishwanath \ No newline at end of file diff --git a/README.md b/README.md index 7394e6d..92c7d02 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,39 @@ +# Unix Domain Socket support +Unix domain socket support is now available for msgpack-rpc. Sample examples below. + +## UDS examples + +### Server + +```python +import msgpackrpc.udsaddress +from msgpackrpc.transport import uds +class SumServer(object): + def sum(self, x, y): + return x + y + +# Use builder as uds. default builder is tcp which creates tcp sockets +server = msgpackrpc.Server(SumServer(), builder=uds) +# Use UDSAddress instead of msgpackrpc.Address +server.listen(msgpackrpc.udsaddress.UDSAddress('/tmp/exrpc')) +server.start() +``` + +### Client +```python +import msgpackrpc.udsaddress +from msgpackrpc.transport import uds + +#Use UDSAddress instead of default Address object +client = msgpackrpc.Client(msgpackrpc.udsaddress.UDSAddress("/tmp/exrpc"), builder=uds) +result = client.call('sum', 1, 2) # = > +print "Sum of 1 and 2 : %d" % result +``` + +Go through the below sections for general usage of Message Pack RPC Library # MessagePack RPC for Python diff --git a/example/uds_simpleclient.py b/example/uds_simpleclient.py new file mode 100644 index 0000000..c12b624 --- /dev/null +++ b/example/uds_simpleclient.py @@ -0,0 +1,11 @@ +''' +@author: hvishwanath | harish.shastry@gmail.com +''' + +import msgpackrpc.udsaddress +from msgpackrpc.transport import uds + +#Use UDSAddress instead of default Address object +client = msgpackrpc.Client(msgpackrpc.udsaddress.UDSAddress("/tmp/exrpc"), builder=uds) +result = client.call('sum', 1, 2) # = > +print "Sum of 1 and 2 : %d" % result \ No newline at end of file diff --git a/example/uds_simpleserver.py b/example/uds_simpleserver.py new file mode 100644 index 0000000..3f019a4 --- /dev/null +++ b/example/uds_simpleserver.py @@ -0,0 +1,15 @@ +''' +@author: hvishwanath | harish.shastry@gmail.com +''' + +import msgpackrpc.udsaddress +from msgpackrpc.transport import uds +class SumServer(object): + def sum(self, x, y): + return x + y + +# Use builder as uds. default builder is tcp which creates tcp sockets +server = msgpackrpc.Server(SumServer(), builder=uds) +# Use UDSAddress instead of msgpackrpc.Address +server.listen(msgpackrpc.udsaddress.UDSAddress('/tmp/exrpc')) +server.start() diff --git a/msgpackrpc/__init__.py b/msgpackrpc/__init__.py index 502ba89..045d892 100644 --- a/msgpackrpc/__init__.py +++ b/msgpackrpc/__init__.py @@ -5,3 +5,4 @@ from msgpackrpc.client import Client from msgpackrpc.server import Server from msgpackrpc.address import Address +from msgpackrpc.udsaddress import UDSAddress \ No newline at end of file diff --git a/msgpackrpc/transport/uds.py b/msgpackrpc/transport/uds.py new file mode 100644 index 0000000..be7d38e --- /dev/null +++ b/msgpackrpc/transport/uds.py @@ -0,0 +1,53 @@ +''' +@author: hvishwanath | harish.shastry@gmail.com +''' + +import msgpackrpc.transport +from tornado.netutil import bind_unix_socket +from tornado import tcpserver +from tornado.iostream import IOStream + +# Much of the implementation will be same as that of tcp module +# Changes required for unix domain socket support are done in this module +# Rest will be automatically used from tcp + +# Create namespace equals +BaseSocket = msgpackrpc.transport.tcp.BaseSocket +ClientSocket = msgpackrpc.transport.tcp.ClientSocket +ClientTransport = msgpackrpc.transport.tcp.ClientTransport + +ServerSocket = msgpackrpc.transport.tcp.ServerSocket +ServerTransport = msgpackrpc.transport.tcp.ServerTransport + + +class UDSServer(tcpserver.TCPServer): + """Define a Unix domain socket server. + Instead of binding to TCP/IP socket, bind to UDS socket and listen""" + + def __init__(self, io_loop=None, ssl_options=None): + tcpserver.TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options) + + def listen(self, port, address=""): + """Bind to a unix domain socket and add to self. + Note that port in our case actually contains the uds file name""" + + # Create a Unix domain socket and bind + socket = bind_unix_socket(port) + + # Add to self + self.add_socket(socket) + +class MessagePackServer(UDSServer): + """The MessagePackServer inherits from UDSServer + instead of tornado's TCP Server""" + + def __init__(self, transport, io_loop=None, encodings=None): + self._transport = transport + self._encodings = encodings + UDSServer.__init__(self, io_loop=io_loop) + + def handle_stream(self, stream, address): + ServerSocket(stream, self._transport, self._encodings) + +#Monkey patch the MessagePackServer +msgpackrpc.transport.tcp.MessagePackServer = MessagePackServer \ No newline at end of file diff --git a/msgpackrpc/udsaddress.py b/msgpackrpc/udsaddress.py new file mode 100644 index 0000000..91a82e8 --- /dev/null +++ b/msgpackrpc/udsaddress.py @@ -0,0 +1,40 @@ +''' +@author: hvishwanath | harish.shastry@gmail.com +''' + +import socket +from tornado.platform.auto import set_close_exec + +class UDSAddress(object): + """This class abstracts Unix domain socket address. + For compatibility with other code in the library, port is always equal to host""" + + def __init__(self, host, port=None): + self._host = host + + # Passed value for port is ignored. + # Port is also made equal to host. + # This is because some of the code in transport.tcp uses address._port to connect. + # For a unix socket, there is no port. Hence if port = host, that code should work. + self._port = host + + @property + def host(self): + return self._host + + @property + def port(self): + return self._port + + def unpack(self): + # Return only the host + return self._host + + def socket(self, family=socket.AF_UNSPEC): + """Return a Unix domain socket instead of tcp socket""" + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + set_close_exec(sock.fileno()) + sock.setblocking(0) + + return sock \ No newline at end of file diff --git a/test/test_uds_msgpackrpc.py b/test/test_uds_msgpackrpc.py new file mode 100644 index 0000000..f4a9bea --- /dev/null +++ b/test/test_uds_msgpackrpc.py @@ -0,0 +1,202 @@ +''' +@author: hvishwanath | harish.shastry@gmail.com +''' + +from msgpackrpc.transport import uds +from time import sleep +import threading +try: + import unittest2 as unittest +except ImportError: + import unittest + +import helper +import msgpackrpc +from msgpackrpc import error + +class TestMessagePackRPC(unittest.TestCase): + ENABLE_TIMEOUT_TEST = False + + class TestArg: + ''' this class must know completely how to deserialize ''' + def __init__(self, a, b, c): + self.a = a + self.b = b + self.c = c + + def to_msgpack(self): + return (self.a, self.b, self.c) + + def add(self, rhs): + self.a += rhs.a + self.b -= rhs.b + self.c *= rhs.c + return self + + def __eq__(self, rhs): + return (self.a == rhs.a and self.b == rhs.b and self.c == rhs.c) + + @staticmethod + def from_msgpack(arg): + return TestMessagePackRPC.TestArg(arg[0], arg[1], arg[2]) + + class TestServer(object): + def hello(self): + return "world" + + def sum(self, x, y): + return x + y + + def nil(self): + return None + + def add_arg(self, arg0, arg1): + lhs = TestMessagePackRPC.TestArg.from_msgpack(arg0) + rhs = TestMessagePackRPC.TestArg.from_msgpack(arg1) + return lhs.add(rhs) + + def raise_error(self): + raise Exception('error') + + def long_exec(self): + sleep(3) + return 'finish!' + + def async_result(self): + ar = msgpackrpc.server.AsyncResult() + def do_async(): + sleep(2) + ar.set_result("You are async!") + threading.Thread(target=do_async).start() + return ar + + def setUp(self): + # Create UDSAddress + self._address = msgpackrpc.UDSAddress('/tmp/unusedsocket') + + def setup_env(self): + def _on_started(): + self._server._loop.dettach_periodic_callback() + lock.release() + def _start_server(server): + server._loop.attach_periodic_callback(_on_started, 1) + server.start() + server.close() + + # Use builder=uds + self._server = msgpackrpc.Server(TestMessagePackRPC.TestServer(), builder=uds) + self._server.listen(self._address) + self._thread = threading.Thread(target=_start_server, args=(self._server,)) + + lock = threading.Lock() + self._thread.start() + lock.acquire() + lock.acquire() # wait for the server to start + + self._client = msgpackrpc.Client(self._address, unpack_encoding='utf-8') + return self._client; + + def tearDown(self): + self._client.close(); + self._server.stop(); + self._thread.join(); + + def test_call(self): + client = self.setup_env(); + + result1 = client.call('hello') + result2 = client.call('sum', 1, 2) + result3 = client.call('nil') + + self.assertEqual(result1, "world", "'hello' result is incorrect") + self.assertEqual(result2, 3, "'sum' result is incorrect") + self.assertIsNone(result3, "'nil' result is incorrect") + + def test_call_userdefined_arg(self): + client = self.setup_env(); + + arg = TestMessagePackRPC.TestArg(0, 1, 2) + arg2 = TestMessagePackRPC.TestArg(23, 3, -23) + + result1 = TestMessagePackRPC.TestArg.from_msgpack(client.call('add_arg', arg, arg2)) + self.assertEqual(result1, arg.add(arg2)) + + result2 = TestMessagePackRPC.TestArg.from_msgpack(client.call('add_arg', arg2, arg)) + self.assertEqual(result2, arg2.add(arg)) + + result3 = TestMessagePackRPC.TestArg.from_msgpack(client.call('add_arg', result1, result2)) + self.assertEqual(result3, result1.add(result2)) + + def test_call_async(self): + client = self.setup_env(); + + future1 = client.call_async('hello') + future2 = client.call_async('sum', 1, 2) + future3 = client.call_async('nil') + future1.join() + future2.join() + future3.join() + + self.assertEqual(future1.result, "world", "'hello' result is incorrect in call_async") + self.assertEqual(future2.result, 3, "'sum' result is incorrect in call_async") + self.assertIsNone(future3.result, "'nil' result is incorrect in call_async") + + def test_notify(self): + client = self.setup_env(); + + result = True + try: + client.notify('hello') + client.notify('sum', 1, 2) + client.notify('nil') + except: + result = False + + self.assertTrue(result) + + def test_raise_error(self): + client = self.setup_env(); + self.assertRaises(error.RPCError, lambda: client.call('raise_error')) + + def test_unknown_method(self): + client = self.setup_env(); + self.assertRaises(error.RPCError, lambda: client.call('unknown', True)) + try: + client.call('unknown', True) + self.assertTrue(False) + except error.RPCError as e: + message = e.args[0] + self.assertEqual(message, "'unknown' method not found", "Error message mismatched") + + def test_async_result(self): + client = self.setup_env(); + self.assertEqual(client.call('async_result'), "You are async!") + + def test_connect_failed(self): + client = self.setup_env(); + port = helper.unused_port() + client = msgpackrpc.Client(msgpackrpc.Address('localhost', port), unpack_encoding='utf-8') + self.assertRaises(error.TransportError, lambda: client.call('hello')) + + def test_timeout(self): + client = self.setup_env(); + + if self.__class__.ENABLE_TIMEOUT_TEST: + self.assertEqual(client.call('long_exec'), 'finish!', "'long_exec' result is incorrect") + + client = msgpackrpc.Client(self._address, timeout=1, unpack_encoding='utf-8') + self.assertRaises(error.TimeoutError, lambda: client.call('long_exec')) + else: + print("Skip test_timeout") + + +if __name__ == '__main__': + import sys + + try: + sys.argv.remove('--timeout-test') + TestMessagePackRPC.ENABLE_TIMEOUT_TEST = True + except: + pass + + unittest.main() diff --git a/test/udsclient.py b/test/udsclient.py new file mode 100644 index 0000000..8bb22af --- /dev/null +++ b/test/udsclient.py @@ -0,0 +1,55 @@ +''' +Created on Mar 25, 2013 + +@author: hvishwanath +''' + +import socket +import datetime +import threading +import msgpack + +MANUAL = False +ITER = 50 +THREADS = 10 +SLEEP = 2 + +def random_word(size = 10): + abets = 'abcdefghijklmnopqrstuvwxyz' + rval = [] + import random + for i in range(0, size): + rval.append(abets[random.randint(0, 25)]) + + return ''.join(rval) + +def udsInteract(): + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + s.connect("/tmp/socketname") + for i in range(0, ITER): + if MANUAL: + d = raw_input("Text to send (type x to exit) : ") + else: + d = random_word() + print "Packing and Sending data..." + d = msgpack.packb(d) + s.send(d+"\n") + data = s.recv(4096) + print 'Received', repr(data) + print 'Unpacking' + data = msgpack.unpackb(data) + print data + import time + time.sleep(SLEEP) + + print "Closing connection.." + s.close() + +tlist = [] +for i in range(0, THREADS): + t = threading.Thread(target=udsInteract) + t.start() + tlist.append(t) + +for t in tlist: + t.join() diff --git a/test/udsserver.py b/test/udsserver.py new file mode 100644 index 0000000..0f504e1 --- /dev/null +++ b/test/udsserver.py @@ -0,0 +1,132 @@ +''' +Created on Mar 25, 2013 + +@author: hvishwanath +''' + +import socket,os +import msgpack +import signal +import tornado +from tornado.ioloop import IOLoop +from tornado import stack_context +from tornado.options import options, parse_command_line, define +from tornado.netutil import * +from tornado.ioloop import IOLoop +from tornado.util import bytes_type +from tornado import tcpserver +from tornado.iostream import IOStream + + +def simple_uds_server(): + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + os.remove("/tmp/socketname") + except OSError: + pass + + s.bind("/tmp/socketname") + print "Listening..." + s.listen(1) + print "Accepting..." + conn, addr = s.accept() + + data = conn.recv(4096) + print ' Received : ', data + print "Len of received data : ", len(data) + print "Trimming the last delimiter (\n)" + data = data.rstrip('\n') + + print "Trying to unpack using msgpack" + x = msgpack.unpackb(data) + print x + + print "Packing a custom object and sending to client" + y = msgpack.packb(['String', True, 1, 3.1123]); + + + conn.send(y) + conn.close() + + + +def handle_signal(sig, frame): + IOLoop.instance().add_callback(IOLoop.instance().stop) + +class UDSConnection(object): + """UDS Connection handler""" + + def __init__(self, stream, address, server): + """Initialize base params and call stream reader for next line""" + self.stream = stream + self.address = address + self.server = server + self.stream.set_close_callback(self._on_disconnect) + self.wait() + + def _on_read(self, line): + """Called when new line received from connection""" + # Some game logic (or magic) + print "From stream (%s), received data : %s" % (str(self.stream), line) + print "Stripping delimiter from the data.." + line = line.rstrip('\n') + + print "Trying to unpack using msgpack" + x = msgpack.unpackb(line) + for i in x: + print ("%s : %s" % (i, type(i))) + + print "Packing a custom object and sending to client" + y = msgpack.packb(['String', True, 1, 3.1123, False, 8383]); + + + self.stream.write(y) + self.wait() + + def wait(self): + """Read from stream until the next signed end of line""" + + print "Will read until delimiter (\\n) " + +# chunk = self.stream._read_from_socket() +# print "Read chunk : ", chunk +# print "Trying to unpack using msgpack" +# x = msgpack.unpackb(chunk) +# print x + + + #self.stream.read_bytes(3240, self._on_read) + self.stream.read_until(b'\n', self._on_read) + + def _on_disconnect(self, *args, **kwargs): + """Called on client disconnected""" + print 'Client disconnected (stream %s, address %r)'% (id(self.stream), self.address) + + + def __str__(self): + """Build string representation, will be used for working with + server identity (not only name) in future""" + return "UDS Connection (stream %s, address %r)" % (id(self.stream), self.address) + +class UDSServer(tcpserver.TCPServer): + + def handle_stream(self, stream, address): + print "New incoming connection (stream %s, address %r)" % (id(stream), address) + UDSConnection(stream, address, self) + +def tornado_uds_server(): + + us = bind_unix_socket("/tmp/socketname") + tornado.process.fork_processes(0) + server = UDSServer() + server.add_socket(us) + IOLoop.instance().start() + IOLoop.instance().close() + + +if __name__ == '__main__': + signal.signal(signal.SIGINT, handle_signal) + signal.signal(signal.SIGTERM, handle_signal) + + #simple_uds_server() + tornado_uds_server() \ No newline at end of file