# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""RPC server implementation.

Note
----
Server is TCP based with the following protocol:
- Initial handshake to the peer
  - [RPC_MAGIC, keysize(int32), key-bytes]
- The key is in format
   - {server|client}:device-type[:random-key] [-timeout=timeout]
"""
# pylint: disable=invalid-name

from __future__ import absolute_import

import os
import ctypes
import socket
import select
import struct
import logging
import multiprocessing
import subprocess
import time
import sys
import signal
import platform

from .._ffi.function import register_func
from .._ffi.base import py_str
from .._ffi.libinfo import find_lib_path
from ..module import load as _load_module
from ..contrib import util
from . import base
from .base import TrackerCode

logger = logging.getLogger('RPCServer')

def _server_env(load_library, work_path=None):
    """Prepare the serving environment and return its working directory.

    Registers the ``tvm.rpc.server.workpath`` and
    ``tvm.rpc.server.load_module`` functions so that they resolve paths
    relative to the working directory, then loads any extra libraries.

    Parameters
    ----------
    load_library : str or None
        Colon separated list of additional libraries to load.

    work_path : optional
        Existing working directory object to use. When falsy, a new one is
        created with ``util.tempdir()``.

    Returns
    -------
    temp
        The working directory object. The loaded libraries are attached to
        it as ``temp.libs`` so they stay referenced as long as it does.
    """
    if work_path:
        temp = work_path
    else:
        temp = util.tempdir()

    # pylint: disable=unused-variable
    @register_func("tvm.rpc.server.workpath")
    def get_workpath(path):
        return temp.relpath(path)

    @register_func("tvm.rpc.server.load_module", override=True)
    def load_module(file_name):
        """Load module from remote side."""
        path = temp.relpath(file_name)
        m = _load_module(path)
        logger.info("load_module %s", path)
        return m

    libs = []
    load_library = load_library.split(":") if load_library else []
    for file_name in load_library:
        file_name = find_lib_path(file_name)[0]
        libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
        logger.info("Load additional library %s", file_name)
    temp.libs = libs
    return temp

def _serve_loop(sock, addr, load_library, work_path=None):
    """Serve a single connection until ``base._ServerLoop`` returns.

    Parameters
    ----------
    sock : socket.socket
        Connected socket; its file descriptor is handed to the server loop.

    addr
        Peer address, used for logging only.

    load_library : str or None
        Colon separated list of additional libraries to load.

    work_path : optional
        Working directory owned by the caller. When it is not given, a
        directory is created here and removed once serving finishes.
    """
    sockfd = sock.fileno()
    temp = _server_env(load_library, work_path)
    base._ServerLoop(sockfd)
    if not work_path:
        # Only clean up the directory this function created itself.
        temp.remove()
    logger.info("Finish serving %s", addr)

def _parse_server_opt(opts):
    """Parse the option tokens that follow the key in a handshake.

    Parameters
    ----------
    opts : list of str
        Option tokens, for example ``["-timeout=10"]``.

    Returns
    -------
    ret : dict
        Recognized options. Currently only ``"timeout"`` (float) is parsed;
        unknown tokens are ignored.
    """
    # parse client options
    ret = {}
    for kv in opts:
        if kv.startswith("-timeout="):
            ret["timeout"] = float(kv[9:])
    return ret

def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
    """Listening loop of the server master.

    Repeatedly (optionally) registers with the tracker, accepts one client
    connection, and serves it in a separate worker process.

    Parameters
    ----------
    sock : socket.socket
        The bound, listening socket.

    port : int
        The port ``sock`` is bound to; reported to the tracker.

    rpc_key : str
        The key identifying this server.

    tracker_addr : tuple or None
        Address of the RPC tracker. When falsy no tracker is used.

    load_library : str or None
        Colon separated list of additional libraries to load in the worker.

    custom_addr : str or None
        Custom address reported to the tracker.
    """
    def _accept_conn(listen_sock, tracker_conn, ping_period=2):
        """Accept connection from the other places.

        Parameters
        ----------
        listen_sock: Socket
            The socket used by listening process.

        tracker_conn : connection to tracker
            Tracker connection

        ping_period : float, optional
            ping tracker every k seconds if no connection is accepted.

        Returns
        -------
        conn, addr, opts
            The accepted connection, the peer address and the options parsed
            from the client's handshake key.
        """
        old_keyset = set()
        # Report resource to tracker
        if tracker_conn:
            matchkey = base.random_key(rpc_key + ":")
            base.sendjson(tracker_conn,
                          [TrackerCode.PUT, rpc_key, (port, matchkey), custom_addr])
            assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
        else:
            matchkey = rpc_key

        unmatch_period_count = 0
        unmatch_timeout = 4
        # Wait until we get a valid connection
        while True:
            if tracker_conn:
                trigger = select.select([listen_sock], [], [], ping_period)
                if listen_sock not in trigger[0]:
                    base.sendjson(tracker_conn, [TrackerCode.GET_PENDING_MATCHKEYS])
                    pending_keys = base.recvjson(tracker_conn)
                    old_keyset.add(matchkey)
                    # if match key not in pending key set
                    # it means the key is acquired by a client but not used.
                    if matchkey not in pending_keys:
                        unmatch_period_count += 1
                    else:
                        unmatch_period_count = 0
                    # regenerate match key if key is acquired but not used for a while
                    if unmatch_period_count * ping_period > unmatch_timeout + ping_period:
                        logger.info("no incoming connections, regenerate key ...")
                        matchkey = base.random_key(rpc_key + ":", old_keyset)
                        base.sendjson(tracker_conn,
                                      [TrackerCode.PUT, rpc_key, (port, matchkey),
                                       custom_addr])
                        assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS
                        unmatch_period_count = 0
                    continue
            conn, addr = listen_sock.accept()
            magic = struct.unpack("<i", base.recvall(conn, 4))[0]
            if magic != base.RPC_MAGIC:
                conn.close()
                continue
            keylen = struct.unpack("<i", base.recvall(conn, 4))[0]
            key = py_str(base.recvall(conn, keylen))
            arr = key.split()
            expect_header = "client:" + matchkey
            server_key = "server:" + rpc_key
            # An empty or whitespace-only key yields an empty list; treat it
            # as a mismatch instead of letting arr[0] raise IndexError and
            # take the whole listener down.
            if not arr or arr[0] != expect_header:
                conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH))
                conn.close()
                logger.warning("mismatch key from %s", addr)
                continue
            conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS))
            conn.sendall(struct.pack("<i", len(server_key)))
            conn.sendall(server_key.encode("utf-8"))
            return conn, addr, _parse_server_opt(arr[1:])

    # Server logic
    tracker_conn = None
    while True:
        try:
            # step 1: setup tracker and report to tracker
            if tracker_addr and tracker_conn is None:
                tracker_conn = base.connect_with_retry(tracker_addr)
                tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC))
                magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0]
                if magic != base.RPC_TRACKER_MAGIC:
                    raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr))
                # report status of current queue
                cinfo = {"key" : "server:" + rpc_key}
                base.sendjson(tracker_conn,
                              [TrackerCode.UPDATE_INFO, cinfo])
                assert base.recvjson(tracker_conn) == TrackerCode.SUCCESS

            # step 2: wait for in-coming connections
            conn, addr, opts = _accept_conn(sock, tracker_conn)
        except (socket.error, IOError):
            # retry when tracker is dropped
            if tracker_conn:
                tracker_conn.close()
                tracker_conn = None
            continue
        except RuntimeError:
            # Not recoverable (e.g. the peer is not a tracker): propagate.
            raise

        # step 3: serving
        work_path = util.tempdir()
        logger.info("connection from %s", addr)
        server_proc = multiprocessing.Process(target=_serve_loop,
                                              args=(conn, addr, load_library, work_path))
        # NOTE(review): ``deamon`` is a misspelling of ``daemon``, so this
        # assignment only creates an unused attribute and the process stays
        # non-daemonic. Kept as-is to preserve the existing process
        # semantics; confirm the intended behaviour before renaming.
        server_proc.deamon = True
        server_proc.start()
        # close from our side.
        conn.close()
        # wait until server process finish or timeout
        server_proc.join(opts.get("timeout", None))
        if server_proc.is_alive():
            logger.info("Timeout in RPC session, kill..")
            import psutil
            parent = psutil.Process(server_proc.pid)
            # terminate worker childs
            for child in parent.children(recursive=True):
                child.terminate()
            # terminate the worker
            server_proc.terminate()
        work_path.remove()


def _connect_proxy_loop(addr, key, load_library):
    """Connect to an RPC proxy and serve the sessions it hands over.

    Parameters
    ----------
    addr : tuple
        Address of the proxy.

    key : str
        The server key; sent to the proxy as ``"server:" + key``.

    load_library : str or None
        Colon separated list of additional libraries to load in the worker.

    Raises
    ------
    RuntimeError
        If the key is already in use on the proxy, the peer is not an RPC
        proxy, or more than ``max_retry`` consecutive socket errors occur.
    """
    key = "server:" + key
    retry_count = 0
    max_retry = 5
    retry_period = 5
    while True:
        try:
            sock = socket.socket(base.get_addr_family(addr), socket.SOCK_STREAM)
            sock.connect(addr)
            sock.sendall(struct.pack("<i", base.RPC_MAGIC))
            sock.sendall(struct.pack("<i", len(key)))
            sock.sendall(key.encode("utf-8"))
            magic = struct.unpack("<i", base.recvall(sock, 4))[0]
            if magic == base.RPC_CODE_DUPLICATE:
                raise RuntimeError("key: %s has already been used in proxy" % key)
            elif magic == base.RPC_CODE_MISMATCH:
                logger.warning("RPCProxy do not have matching client key %s", key)
            elif magic != base.RPC_CODE_SUCCESS:
                raise RuntimeError("%s is not RPC Proxy" % str(addr))
            keylen = struct.unpack("<i", base.recvall(sock, 4))[0]
            remote_key = py_str(base.recvall(sock, keylen))
            opts = _parse_server_opt(remote_key.split()[1:])
            logger.info("connected to %s", str(addr))
            process = multiprocessing.Process(
                target=_serve_loop, args=(sock, addr, load_library))
            # NOTE(review): ``deamon`` is a misspelling of ``daemon`` and has
            # no effect; kept to preserve existing behaviour.
            process.deamon = True
            process.start()
            sock.close()
            process.join(opts.get("timeout", None))
            if process.is_alive():
                logger.info("Timeout in RPC session, kill..")
                process.terminate()
            # A session was served, so start counting failures afresh.
            retry_count = 0
        except (socket.error, IOError) as err:
            retry_count += 1
            logger.warning("Error encountered %s, retry in %g sec", str(err), retry_period)
            if retry_count > max_retry:
                raise RuntimeError("Maximum retry error: last error: %s" % str(err))
            time.sleep(retry_period)

def _popen(cmd):
    """Run ``cmd`` to completion and raise if it exits with a non-zero status.

    Parameters
    ----------
    cmd : list of str
        The command line to execute.

    Raises
    ------
    RuntimeError
        If the command fails; the message carries the combined
        stdout/stderr output of the command.
    """
    proc = subprocess.Popen(cmd,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT,
                            env=os.environ)
    (out, _) = proc.communicate()
    if proc.returncode != 0:
        msg = "Server invoke error:\n"
        # communicate() returns bytes for pipes opened in binary mode, so the
        # output must be decoded before it can be appended to a str.
        msg += py_str(out)
        raise RuntimeError(msg)


class Server(object):
    """Start RPC server on a separate process.

    This is a simple python implementation based on multi-processing.
    It is also possible to implement a similar C based server with
    TVM runtime which does not depend on the python.

    Parameters
    ----------
    host : str
        The host url of the server.

    port : int
        The port to bind to

    port_end : int, optional
        The end port to search

    is_proxy : bool, optional
        Whether the address specified is a proxy.
        If this is true, the host and port actually corresponds to the
        address of the proxy server.

    use_popen : bool, optional
        Whether to use Popen to start a fresh new process instead of fork.
        This is recommended to switch on if we want to do local RPC demonstration
        for GPU devices to avoid fork safety issues.

    tracker_addr: Tuple (str, int) , optional
        The address of RPC Tracker in tuple(host, port) format.
        If is not None, the server will register itself to the tracker.

    key : str, optional
        The key used to identify the device type in tracker.

    load_library : str, optional
        List of additional libraries to be loaded during execution.

    custom_addr: str, optional
        Custom IP Address to Report to RPC Tracker

    silent: bool, optional
        Whether run this server in silent mode.
    """
    def __init__(self,
                 host,
                 port=9091,
                 port_end=9199,
                 is_proxy=False,
                 use_popen=False,
                 tracker_addr=None,
                 key="",
                 load_library=None,
                 custom_addr=None,
                 silent=False):
        try:
            if base._ServerLoop is None:
                raise RuntimeError("Please compile with USE_RPC=1")
        except NameError:
            raise RuntimeError("Please compile with USE_RPC=1")
        self.host = host
        self.port = port
        self.libs = []
        self.custom_addr = custom_addr
        self.use_popen = use_popen

        if silent:
            logger.setLevel(logging.ERROR)

        if use_popen:
            cmd = [sys.executable,
                   "-m", "tvm.exec.rpc_server",
                   "--host=%s" % host,
                   "--port=%s" % port]
            if tracker_addr:
                assert key
                cmd += ["--tracker=%s:%d" % tracker_addr,
                        "--key=%s" % key]
            if load_library:
                cmd += ["--load-library", load_library]
            if custom_addr:
                cmd += ["--custom-addr", custom_addr]
            if silent:
                cmd += ["--silent"]

            # prexec_fn is not thread safe and may result in deadlock.
            # python 3.2 introduced the start_new_session parameter as
            # an alternative to the common use case of
            # prexec_fn=os.setsid.  Once the minimum version of python
            # supported by TVM reaches python 3.2 this code can be
            # rewritten in favour of start_new_session.  In the
            # interim, stop the pylint diagnostic.
            #
            # pylint: disable=subprocess-popen-preexec-fn
            self.proc = subprocess.Popen(cmd, preexec_fn=os.setsid)
            time.sleep(0.5)
        elif not is_proxy:
            sock = socket.socket(base.get_addr_family((host, port)), socket.SOCK_STREAM)
            self.port = None
            for my_port in range(port, port_end):
                try:
                    sock.bind((host, my_port))
                    self.port = my_port
                    break
                except socket.error as sock_err:
                    # 98 / 48: "address already in use" on Linux / macOS;
                    # try the next port in the range.
                    if sock_err.errno in [98, 48]:
                        continue
                    # Any other failure is fatal: release the socket first.
                    sock.close()
                    raise
            if not self.port:
                sock.close()
                raise ValueError("cannot bind to any port in [%d, %d)" % (port, port_end))
            logger.info("bind to %s:%d", host, self.port)
            sock.listen(1)
            self.sock = sock
            self.proc = multiprocessing.Process(
                target=_listen_loop, args=(
                    self.sock, self.port, key, tracker_addr, load_library,
                    self.custom_addr))
            # NOTE(review): ``deamon`` is a misspelling of ``daemon`` and has
            # no effect. Do not "fix" it blindly: a daemonic process is not
            # allowed to create children, and _listen_loop spawns a worker
            # process for every connection.
            self.proc.deamon = True
            self.proc.start()
        else:
            self.proc = multiprocessing.Process(
                target=_connect_proxy_loop, args=((host, port), key, load_library))
            # NOTE(review): same misspelling as above; _connect_proxy_loop
            # also spawns worker processes.
            self.proc.deamon = True
            self.proc.start()

    def terminate(self):
        """Terminate the server process.

        Safe to call more than once, and safe to call on an instance whose
        ``__init__`` raised before the process was created.
        """
        # __del__ may run on a partially constructed object, so do not assume
        # that the attributes exist.
        proc = getattr(self, "proc", None)
        if not proc:
            return
        if getattr(self, "use_popen", False):
            if platform.system() == "Windows":
                os.kill(proc.pid, signal.CTRL_C_EVENT)
            else:
                # The child was started with os.setsid, so signal its group.
                os.killpg(proc.pid, signal.SIGTERM)
        else:
            proc.terminate()
        self.proc = None

    def __del__(self):
        self.terminate()