1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""RPC server implementation.
18
19Note
20----
21Server is TCP based with the following protocol:
22- Initial handshake to the peer
23  - [RPC_MAGIC, keysize(int32), key-bytes]
24- The key is in format
25   - {server|client}:device-type[:random-key] [-timeout=timeout]
26"""
27# pylint: disable=invalid-name
28
from __future__ import absolute_import

import ctypes
import errno
import logging
import multiprocessing
import os
import platform
import select
import signal
import socket
import struct
import subprocess
import sys
import time

from .._ffi.function import register_func
from .._ffi.base import py_str
from .._ffi.libinfo import find_lib_path
from ..module import load as _load_module
from ..contrib import util
from . import base
from . base import TrackerCode
51
# Module-level logger shared by every server code path in this file;
# Server(silent=True) sets its level to ERROR.
logger = logging.getLogger('RPCServer')
53
def _server_env(load_library, work_path=None):
    """Set up the per-session server environment and return its work directory.

    Registers the ``tvm.rpc.server.workpath`` and ``tvm.rpc.server.load_module``
    global functions so that they resolve file names relative to the returned
    directory, then loads any additional libraries into the process.

    Parameters
    ----------
    load_library : str or None
        Colon-separated list of additional library names/paths, each loaded
        with ``ctypes.RTLD_GLOBAL``. Empty entries are ignored.

    work_path : optional
        An existing work directory object (exposing ``relpath``) to use.
        When omitted, a fresh one is created with ``util.tempdir()``.

    Returns
    -------
    temp
        The work directory object. The loaded ``ctypes.CDLL`` handles are
        stored on its ``libs`` attribute (presumably to keep them alive for
        the session).
    """
    temp = work_path if work_path else util.tempdir()

    # pylint: disable=unused-variable
    # Both functions use override=True so the environment can be set up
    # again in a process where they are already registered.
    @register_func("tvm.rpc.server.workpath", override=True)
    def get_workpath(path):
        return temp.relpath(path)

    @register_func("tvm.rpc.server.load_module", override=True)
    def load_module(file_name):
        """Load module from remote side."""
        path = temp.relpath(file_name)
        m = _load_module(path)
        logger.info("load_module %s", path)
        return m

    libs = []
    lib_names = load_library.split(":") if load_library else []
    for lib_name in lib_names:
        if not lib_name:
            # Tolerate stray separators such as "a.so::b.so" or a trailing ":".
            continue
        lib_path = find_lib_path(lib_name)[0]
        libs.append(ctypes.CDLL(lib_path, ctypes.RTLD_GLOBAL))
        logger.info("Load additional library %s", lib_path)
    temp.libs = libs
    return temp
82
def _serve_loop(sock, addr, load_library, work_path=None):
    """Serve a single RPC session on an already-handshaken socket.

    Parameters
    ----------
    sock : socket.socket
        The client connection; its file descriptor is handed to the native
        server loop.

    addr :
        Peer address, used only for logging.

    load_library : str or None
        Colon-separated additional libraries to load (see ``_server_env``).

    work_path : optional
        Work directory owned by the caller. When omitted, a temporary
        directory is created here and removed when the session ends.
    """
    sockfd = sock.fileno()
    temp = _server_env(load_library, work_path)
    try:
        base._ServerLoop(sockfd)
    finally:
        # Only remove a directory we created ourselves; a caller-provided
        # work_path is cleaned up by the caller. Using finally ensures the
        # temp directory is not leaked if the native loop raises.
        if not work_path:
            temp.remove()
    logger.info("Finish serving %s", addr)
91
92def _parse_server_opt(opts):
93    # parse client options
94    ret = {}
95    for kv in opts:
96        if kv.startswith("-timeout="):
97            ret["timeout"] = float(kv[9:])
98    return ret
99
def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
    """Listening loop of the server master.

    Accepts client connections on ``sock`` one at a time and serves each in
    a child process, optionally registering this server with an RPC tracker.
    The loop only exits by raising.

    Parameters
    ----------
    sock : socket.socket
        The bound, listening socket.

    port : int
        The port ``sock`` is bound to; reported to the tracker.

    rpc_key : str
        The key identifying the device type.

    tracker_addr : tuple of (str, int), optional
        Tracker address. When falsy, no tracker is used and clients must
        present ``client:<rpc_key>``.

    load_library : str, optional
        Colon-separated additional libraries loaded by each session.

    custom_addr : str, optional
        Custom IP address reported to the tracker.

    Raises
    ------
    RuntimeError
        If ``tracker_addr`` does not point to an RPC tracker.
    """
    def _handshake(conn, addr, matchkey):
        """Run the server side of the RPC handshake on an accepted socket.

        Returns ``(conn, addr, opts)`` on success. On a wrong magic number or
        a mismatched key, the socket is closed and None is returned.
        """
        magic = struct.unpack("<i", base.recvall(conn, 4))[0]
        if magic != base.RPC_MAGIC:
            conn.close()
            return None
        keylen = struct.unpack("<i", base.recvall(conn, 4))[0]
        key = py_str(base.recvall(conn, keylen))
        arr = key.split()
        expect_header = "client:" + matchkey
        # An empty/whitespace-only key would make arr[0] raise IndexError,
        # which the outer loop does not catch; treat it as a mismatch.
        if not arr or arr[0] != expect_header:
            conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
            conn.close()
            logger.warning("mismatch key from %s", addr)
            return None
        # The length prefix must be the byte length of the encoded key,
        # not its character count.
        server_key = ("server:" + rpc_key).encode("utf-8")
        conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
        conn.sendall(struct.pack("<i", len(server_key)))
        conn.sendall(server_key)
        return conn, addr, _parse_server_opt(arr[1:])

    def _accept_conn(listen_sock, tracker_conn, ping_period=2):
        """Accept connection from the other places.

        Parameters
        ----------
        listen_sock: Socket
            The socket used by listening process.

        tracker_conn : connection to tracker
            Tracker connection

        ping_period : float, optional
            ping tracker every k seconds if no connection is accepted.

        Returns
        -------
        conn, addr, opts
            The accepted socket, the peer address and the parsed client
            options (see ``_parse_server_opt``).
        """
        old_keyset = set()
        # Report resource to tracker
        if tracker_conn:
            matchkey = base.random_key(rpc_key + ":")
            base.sendjson(tracker_conn,
                          [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr])
            assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
        else:
            matchkey = rpc_key

        unmatch_period_count = 0
        unmatch_timeout = 4
        # Wait until we get a valid connection
        while True:
            if tracker_conn:
                trigger = select.select([listen_sock], [], [], ping_period)
                if listen_sock not in trigger[0]:
                    base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
                    pending_keys = base.recvjson(tracker_conn)
                    old_keyset.add(matchkey)
                    # if match key not in pending key set
                    # it means the key is acquired by a client but not used.
                    if matchkey not in pending_keys:
                        unmatch_period_count += 1
                    else:
                        unmatch_period_count = 0
                    # regenerate match key if key is acquired but not used for a while
                    if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
                        logger.info("no incoming connections, regenerate key ...")
                        matchkey = base.random_key(rpc_key + ":", old_keyset)
                        base.sendjson(tracker_conn,
                                      [TrackerCode.PUT, rpc_key, (port, matchkey),
                                       custom_addr])
                        assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
                        unmatch_period_count = 0
                    continue
            conn, addr = listen_sock.accept()
            try:
                accepted = _handshake(conn, addr, matchkey)
            except (socket.error, IOError):
                # Do not leak the accepted socket if the peer drops
                # mid-handshake; the outer loop still handles the error.
                conn.close()
                raise
            if accepted is not None:
                return accepted

    # Server logic
    tracker_conn = None
    while True:
        try:
            # step 1: setup tracker and report to tracker
            if tracker_addr and tracker_conn is None:
                tracker_conn = base.connect_with_retry(tracker_addr)
                tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
                magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
                if magic != base.RPC_TRACKER_MAGIC:
                    raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
                # report status of current queue
                cinfo = {"key" : "server:" + rpc_key}
                base.sendjson(tracker_conn,
                              [TrackerCode.UPDATE_INFO, cinfo])
                assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS

            # step 2: wait for in-coming connections
            conn, addr, opts = _accept_conn(sock, tracker_conn)
        except (socket.error, IOError):
            # retry when tracker is dropped
            if tracker_conn:
                tracker_conn.close()
                tracker_conn = None
            continue
        # Any RuntimeError (e.g. not a tracker) propagates and ends the loop.

        # step 3: serving
        work_path = util.tempdir()
        logger.info("connection from %s", addr)
        server_proc = multiprocessing.Process(target=_serve_loop,
                                              args=(conn, addr, load_library, work_path))
        # The worker is intentionally not daemonic: daemonic processes may
        # not create child processes, and the timeout handling below expects
        # the worker to possibly have children.
        server_proc.start()
        # close from our side.
        conn.close()
        # wait until server process finish or timeout
        server_proc.join(opts.get("timeout", None))
        if server_proc.is_alive():
            logger.info("Timeout in RPC session, kill..")
            import psutil
            parent = psutil.Process(server_proc.pid)
            # terminate worker childs
            for child in parent.children(recursive=True):
                child.terminate()
            # terminate the worker
            server_proc.terminate()
        work_path.remove()
222
223
def _connect_proxy_loop(addr, key, load_library):
    """Register with an RPC proxy and serve the sessions it forwards.

    Repeatedly connects to the proxy at ``addr``, announces ``server:<key>``
    and serves the resulting session in a child process. Socket errors are
    retried every ``retry_period`` seconds; the retry counter is reset after
    each completed session.

    Parameters
    ----------
    addr : tuple of (str, int)
        Address of the RPC proxy.

    key : str
        Device key; sent to the proxy as ``server:<key>``.

    load_library : str, optional
        Colon-separated additional libraries loaded by each session.

    Raises
    ------
    RuntimeError
        If the key is already in use at the proxy, the peer is not an RPC
        proxy, or more than ``max_retry`` consecutive attempts fail.
    """
    key = "server:" + key
    # The length prefix must be the byte length of the encoded key.
    key_bytes = key.encode("utf-8")
    retry_count = 0
    max_retry = 5
    retry_period = 5
    while True:
        sock = None
        try:
            sock = socket.socket(base.get_addr_family(addr), socket.SOCK_STREAM)
            sock.connect(addr)
            sock.sendall(struct.pack("<i", base.RPC_MAGIC))
            sock.sendall(struct.pack("<i", len(key_bytes)))
            sock.sendall(key_bytes)
            magic = struct.unpack("<i", base.recvall(sock, 4))[0]
            if magic == base.RPC_CODE_DUPLICATE:
                raise RuntimeError("key: %s has already been used in proxy" % key)
            elif magic == base.RPC_CODE_MISMATCH:
                # NOTE(review): execution continues into the session after a
                # mismatch; confirm this matches the proxy's protocol.
                logger.warning("RPCProxy do not have matching client key %s", key)
            elif magic != base.RPC_CODE_SUCCESS:
                raise RuntimeError("%s is not RPC Proxy" % str(addr))
            keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
            remote_key = py_str(base.recvall(sock, keylen))
            opts = _parse_server_opt(remote_key.split()[1:])
            logger.info("connected to %s", str(addr))
            process = multiprocessing.Process(
                target=_serve_loop, args=(sock, addr, load_library))
            # Not daemonic on purpose: daemonic processes may not create
            # child processes, which a serving session might need.
            process.start()
            sock.close()
            process.join(opts.get("timeout", None))
            if process.is_alive():
                logger.info("Timeout in RPC session, kill..")
                process.terminate()
            retry_count = 0
        except RuntimeError:
            if sock is not None:
                sock.close()
            raise
        except (socket.error, IOError) as err:
            # Release the socket before retrying so failed attempts do not
            # leak file descriptors.
            if sock is not None:
                sock.close()
            retry_count += 1
            logger.warning("Error encountered %s, retry in %g sec", str(err), retry_period)
            if retry_count > max_retry:
                raise RuntimeError("Maximum retry error: last error: %s" % str(err))
            time.sleep(retry_period)
263
264def _popen(cmd):
265    proc = subprocess.Popen(cmd,
266                            stdout=subprocess.PIPE,
267                            stderr=subprocess.STDOUT,
268                            env=os.environ)
269    (out, _) = proc.communicate()
270    if proc.returncode != 0:
271        msg = "Server invoke error:\n"
272        msg += out
273        raise RuntimeError(msg)
274
275
class Server(object):
    """Start RPC server on a separate process.

    This is a simple python implementation based on multi-processing.
    It is also possible to implement a similar C based server with
    TVM runtime which does not depend on the python.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The first port to try to bind to.

    port_end : int, optional
        The end (exclusive) of the port range to search. Not forwarded to
        the child process when ``use_popen`` is True.

    is_proxy : bool, optional
        Whether the address specified is a proxy.
        If this is true, the host and port actually corresponds to the
        address of the proxy server. Ignored when ``use_popen`` is True.

    use_popen : bool, optional
        Whether to use Popen to start a fresh new process instead of fork.
        This is recommended to switch on if we want to do local RPC demonstration
        for GPU devices to avoid fork safety issues.

    tracker_addr: Tuple (str, int) , optional
        The address of RPC Tracker in (host, port) format.
        If is not None, the server will register itself to the tracker.

    key : str, optional
        The key used to identify the device type in tracker.
        Must be non-empty when ``tracker_addr`` is used with ``use_popen``.

    load_library : str, optional
        Colon-separated list of additional libraries to be loaded during execution.

    custom_addr: str, optional
        Custom IP Address to Report to RPC Tracker

    silent: bool, optional
        Whether to run this server in silent mode (the module logger level
        is set to ERROR).

    Raises
    ------
    RuntimeError
        If TVM was built without RPC support.

    ValueError
        If no port in ``[port, port_end)`` can be bound.
    """
    def __init__(self,
                 host,
                 port=9091,
                 port_end=9199,
                 is_proxy=False,
                 use_popen=False,
                 tracker_addr=None,
                 key="",
                 load_library=None,
                 custom_addr=None,
                 silent=False):
        # Set the attributes terminate() reads before anything can raise, so
        # that __del__ on a partially constructed instance is harmless.
        self.proc = None
        self.use_popen = use_popen
        try:
            if base._ServerLoop is None:
                raise RuntimeError("Please compile with USE_RPC=1")
        except NameError:
            raise RuntimeError("Please compile with USE_RPC=1")
        self.host = host
        self.port = port
        self.libs = []
        self.custom_addr = custom_addr

        if silent:
            logger.setLevel(logging.ERROR)

        if use_popen:
            cmd = [sys.executable,
                   "-m", "tvm.exec.rpc_server",
                   "--host=%s" % host,
                   "--port=%s" % port]
            if tracker_addr:
                assert key
                cmd += ["--tracker=%s:%d" % tracker_addr,
                        "--key=%s" % key]
            if load_library:
                cmd += ["--load-library", load_library]
            if custom_addr:
                cmd += ["--custom-addr", custom_addr]
            if silent:
                cmd += ["--silent"]

            # prexec_fn is not thread safe and may result in deadlock.
            # python 3.2 introduced the start_new_session parameter as
            # an alternative to the common use case of
            # prexec_fn=os.setsid.  Once the minimum version of python
            # supported by TVM reaches python 3.2 this code can be
            # rewritten in favour of start_new_session.  In the
            # interim, stop the pylint diagnostic.
            #
            # pylint: disable=subprocess-popen-preexec-fn
            self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid)
            time.sleep(0.5)
        elif not is_proxy:
            sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
            self.port = None
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                    self.port = my_port
                    break
                except socket.error as sock_err:
                    # 98 / 48 are EADDRINUSE on Linux / macOS; errno.EADDRINUSE
                    # additionally covers the running platform's value.
                    if sock_err.errno in (98, 48, errno.EADDRINUSE):
                        continue
                    sock.close()
                    raise
            # Compare against None: a successfully bound port 0 is falsy.
            if self.port is None:
                sock.close()
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logger.info("bind to %s:%d", host, self.port)
            sock.listen(1)
            self.sock = sock
            # Not daemonic on purpose: _listen_loop starts a child process per
            # session, and daemonic processes may not create children.
            self.proc = multiprocessing.Process(
                target=_listen_loop, args=(
                    self.sock, self.port, key, tracker_addr, load_library,
                    self.custom_addr))
            self.proc.start()
        else:
            # Not daemonic on purpose: _connect_proxy_loop starts a child
            # process per session.
            self.proc = multiprocessing.Process(
                target=_connect_proxy_loop, args=((host, port), key, load_library))
            self.proc.start()

    def terminate(self):
        """Terminate the server process.

        Safe to call repeatedly and on a partially constructed instance.
        In ``use_popen`` mode, a server whose process group has already
        exited is treated as terminated.
        """
        proc = getattr(self, "proc", None)
        if not proc:
            return
        if getattr(self, "use_popen", False):
            try:
                if platform.system() == "Windows":
                    os.kill(proc.pid, signal.CTRL_C_EVENT)
                else:
                    os.killpg(proc.pid, signal.SIGTERM)
            except OSError as err:
                # ESRCH: the server process (group) no longer exists.
                if err.errno != errno.ESRCH:
                    raise
        else:
            proc.terminate()
        self.proc = None

    def __del__(self):
        self.terminate()
418