-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrpconnect.py
154 lines (127 loc) · 5.08 KB
/
rpconnect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import builtins
import logging
import socket
import json
import threading
def read_socket(source_socket, num_bytes):
    """Read exactly ``num_bytes`` bytes from ``source_socket``.

    ``recv`` may return fewer bytes than requested, so this keeps reading
    until the full amount has arrived.

    :param source_socket: connected socket to read from
    :param num_bytes: number of bytes to read; 0 returns ``b""`` without reading
    :return: the bytes read, exactly ``num_bytes`` long
    :raises ConnectionError: if the peer closes the connection before
        ``num_bytes`` bytes have been received (a short read would otherwise
        be mistaken for a complete frame)
    """
    buffer = bytearray()
    while len(buffer) < num_bytes:
        chunk = source_socket.recv(num_bytes - len(buffer))
        if not chunk:
            # An empty recv() means the peer closed the connection.
            raise ConnectionError(
                'connection closed after %d of %d bytes' % (len(buffer), num_bytes))
        buffer += chunk
    return bytes(buffer)
def read_message(source_socket):
    """Receive one length-prefixed frame and return its payload as text.

    A frame is a 16-byte little-endian byte count followed by that many bytes
    of payload, which is the layout ``send_message`` writes.

    :param source_socket: connected socket to read from
    :return: the payload decoded with the default (UTF-8) codec
    """
    header = read_socket(source_socket, 16)
    payload_size = int.from_bytes(header, 'little')
    payload = read_socket(source_socket, payload_size)
    return payload.decode()
def send_message(target_socket, message: str):
    """Send ``message`` as one length-prefixed frame.

    The frame is a 16-byte little-endian byte count followed by the UTF-8
    encoded payload.

    :param target_socket: connected socket to write to
    :param message: text to send
    """
    payload = message.encode()
    # Header and payload go out in a single sendall(). Two separate small
    # writes followed by a read for the reply invite a Nagle/delayed-ACK stall.
    target_socket.sendall(len(payload).to_bytes(16, 'little') + payload)
def _remote_error(exc_type: str, message: str) -> Exception:
    """Build the local exception for an error reply sent by :py:class:`RpcServer`.

    :param exc_type: class name of the exception raised on the server
    :param message: string form of that exception
    :return: an instance of the builtin exception class named ``exc_type`` when
        such an ``Exception`` subclass exists and accepts a single message
        argument; otherwise ``ValueError('<exc_type>: <message>')``
    """
    exc_class = getattr(builtins, exc_type, None)
    # exc_class is a class object, so it needs issubclass(); isinstance() of a
    # class against Exception is never true. Restricting to Exception
    # subclasses also keeps a reply from raising SystemExit/KeyboardInterrupt.
    if isinstance(exc_class, type) and issubclass(exc_class, Exception):
        try:
            return exc_class(message)
        except TypeError:
            # Some builtins (e.g. UnicodeDecodeError) need several constructor args.
            pass
    return ValueError('%s: %s' % (exc_type, message))


def remote_call(_host: str, _port: int, _name: str, *args, **kwargs):
    """
    Dispatch a call to a :py:class:`RpcServer` at ``_host``

    Arguments and result travel as JSON, so they must be JSON-serialisable
    (tuples, for example, arrive as lists).

    :param _host: host to connect to
    :param _port: port on host the server is listening on
    :param _name: name a payload is registered with the server
    :param args: positional arguments for the payload
    :param kwargs: keyword arguments for the payload
    :return: the payload's return value, as decoded from JSON
    :raises Exception: if the payload raised a builtin exception type on the
        server, that same type is raised here with the server's message; any
        other remote error is raised as ``ValueError('<type name>: <message>')``
    :raises ValueError: if the reply is neither a result nor an error
    """
    message = json.dumps({'_name': _name, 'args': args, 'kwargs': kwargs})
    # The with-block closes the connection even if sending or receiving fails.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as conn:
        conn.connect((_host, _port))
        send_message(conn, message)
        raw_result = read_message(conn)
    reply = json.loads(raw_result)
    if reply['type'] == 'error':
        raise _remote_error(reply['exc_type'], reply['message'])
    if reply['type'] == 'result':
        return reply['content']
    raise ValueError('malformed reply: %s' % reply)
class RpcServer(object):
    """
    Server to expose callables over TCP for remote procedure call

    Each accepted connection carries one call: a length-prefixed JSON request
    naming a registered payload plus its ``args``/``kwargs``. It is answered
    with one JSON reply holding either the result, or the exception's type
    name and message. Each connection is handled on its own thread.

    :param port: port number to listen on
    :param interface: the interface to bind to, e.g. 'localhost'

    .. code::

        def pingpong(*args, **kwargs):
            return args, kwargs

        with RpcServer(23000) as server:
            server.register(pingpong)
            server.run()
    """

    def __init__(self, port: int, interface: str = ''):
        self.port = port
        self._payloads = {}  # name => callable
        # Treated as closed until the socket is listening, so __del__ has
        # nothing to release if construction fails part-way.
        self._closed = True
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self._socket.bind((interface, port))
            self._socket.listen(5)
        except Exception:
            # Don't leak the socket when bind/listen fails.
            self._socket.close()
            raise
        self._closed = False

    @property
    def closed(self):
        """Whether the server has been shut down"""
        return self._closed

    @closed.setter
    def closed(self, value):
        """Setting True shuts the listening socket down.

        Setting False on an open server is a no-op; on a closed server it
        raises ``ValueError``.
        """
        if not value:
            if self._closed:
                raise ValueError('cannot re-open a server')
            return
        if self._closed:
            return
        # Mark closed first so run() treats the resulting accept() error as shutdown.
        self._closed = True
        try:
            self._socket.shutdown(socket.SHUT_RDWR)
        except OSError:
            # Some platforms reject shutdown() on an unconnected (listening) socket.
            pass
        self._socket.close()

    def register(self, payload: callable, name: str = None):
        """
        Register a callable under a given name

        Registering under an existing name replaces the previous callable.

        :param name: name to register the payload with; defaults to ``payload.__name__``
        :param payload: callable to register
        """
        if name is None:
            name = payload.__name__
        self._payloads[name] = payload

    def run(self):
        """Run the server, blocking until interrupted or closed

        Each accepted connection is served on its own thread. If ``accept()``
        fails after the server has been closed, this returns normally. Any
        other socket error propagates.
        """
        while not self._closed:
            try:
                clientsocket, address = self._socket.accept()
            except OSError:
                if self._closed:
                    return
                raise
            logging.info('new connection: %s', address)
            # TODO: register these?
            connection_thread = threading.Thread(target=self._handle_connection, args=(clientsocket, address))
            connection_thread.start()

    def _handle_connection(self, clientsocket, address):
        """Serve a single request on ``clientsocket``, then close it."""
        logging.info('[%s] handling connection', address)
        with clientsocket:
            try:
                message = read_message(clientsocket)
                data = json.loads(message)
                payload_name = data['_name']
                payload_args = data['args']
                payload_kwargs = data['kwargs']
                logging.debug('[%s]: %s(*%s, **%s)', address, payload_name, payload_args, payload_kwargs)
                result = self._payloads[payload_name](*payload_args, **payload_kwargs)
                # _format_result runs before any bytes are sent, so a result that
                # cannot be JSON-encoded is still reported as an error reply.
                send_message(clientsocket, self._format_result(result))
                logging.debug('[%s]: %s', address, result)
            except Exception as err:
                logging.exception('[%s]: Exception', address)
                try:
                    send_message(clientsocket, self._format_exception(err))
                except OSError:
                    # The client may already have gone away; nothing more to do.
                    logging.warning('[%s]: could not report exception to client', address)

    @staticmethod
    def _format_exception(err, message=''):
        """Serialise ``err`` as an error reply; a non-empty ``message`` overrides ``str(err)``."""
        data = {'type': 'error', 'exc_type': err.__class__.__name__, 'message': message or str(err)}
        return json.dumps(data)

    @staticmethod
    def _format_result(result):
        """Serialise ``result`` as a result reply (raises TypeError if not JSON-serialisable)."""
        data = {'type': 'result', 'content': result}
        return json.dumps(data)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the server; exceptions from the with-body are logged and propagated."""
        if exc_type is not None:
            logging.error('__exit__: %s %s', exc_type, exc_val)
        else:
            logging.debug('__exit__: %s %s', exc_type, exc_val)
        self.closed = True
        return False

    def __del__(self):
        # __init__ may have failed before the socket existed; nothing to release then.
        if not getattr(self, '_closed', True):
            self.closed = True
# Public API of the module: the client-side call helper and the server class.
__all__ = ['remote_call', 'RpcServer']