# Copyright 2019 Lawrence Livermore National Security, LLC and other
# Bridge Kernel Project Developers. See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Socket server that lets a single client drive a (possibly MPI) Python process.

Rank 0 listens on a Unix domain socket, reads length-prefixed JSON requests
from one client and, when MPI is enabled, broadcasts each request to every
rank so all ranks execute the same code.
"""

from __future__ import print_function
from __future__ import unicode_literals

from contextlib import contextmanager
import getpass
from io import BytesIO
import json
import os
import re
import rlcompleter
import socket
import struct
import sys
import time

try:
    from base64 import encodebytes
except ImportError:
    from base64 import encodestring as encodebytes


class MPIServer():
    """Serve code execution and completion requests from one client.

    Rank 0 owns the client connection (a Unix domain socket advertised by a
    JSON file in ``config_dir``); when an MPI communicator is supplied, every
    request is broadcast so all ranks evaluate it in lockstep.

    Wire format (both directions): a 4-byte unsigned big-endian payload
    length (``frame_format``) followed by that many bytes of UTF-8 JSON.
    """

    # expanduser("~") is $HOME when it is set and falls back to the password
    # database instead of raising KeyError when it is not.
    config_dir = os.path.join(os.path.expanduser("~"), ".config", "bridge_kernel")
    localhost = socket.gethostname()
    frame_format = "!I"  # network byte order, unsigned 32-bit payload length
    frame_length = struct.calcsize(frame_format)
    max_message_length = (2**(8 * frame_length)) - 1  # largest length the header can encode
    chunk_size = 1024  # bytes per send()/recv() call

    class ClientFile():
        """File-like object forwarding each write to the client as a message of type ``id_str``."""

        def __init__(self, client, id_str):
            self.client = client
            self.id_str = id_str

        def write(self, s):
            self.client.writemsg(dict(type=self.id_str, code=s))

        def flush(self):
            pass

    class NullFile():
        """File-like sink that discards everything (used on ranks other than 0)."""

        def write(self, s):
            pass

        def flush(self):
            pass

    def __init__(self, comm, ns, callback, prefix=None, inline_matplotlib=True):
        """Prepare the server; no socket is opened until :meth:`serve`.

        Args:
            comm: MPI communicator (used via Get_rank, Get_size, send, recv,
                bcast and Barrier), or None to run without MPI.
            ns: dict used as the globals namespace for evaluated code. It is
                modified: ``quit``, ``exit`` and ``image`` (plus
                ``mpi_print`` under MPI) are installed into it.
            callback: called with the decoded message dict for ``"custom"``
                requests, on every rank.
            prefix: base name for the connection file and socket path;
                defaults to ``<program>_<host>_<start time>``.
            inline_matplotlib: if True, route ``pyplot.show()`` to the client
                as an inline PNG (skipped if matplotlib cannot be imported).
        """
        if comm is None:
            self.MPI_ENABLED = False
            self.rank = 0
            self.nprocs = 1
        else:
            self.MPI_ENABLED = True
            self.comm = comm
            self.rank = comm.Get_rank()
            self.nprocs = comm.Get_size()

        self.callback = callback
        self.exec_name = os.path.splitext(os.path.basename(sys.argv[0]))[0]
        self.ns = ns
        self.files = []  # paths deleted by _remove_files() at shutdown
        self.start_time = time.strftime("%Y-%m-%d %H:%M:%S")
        # Log to the real stdout so it bypasses the client redirection in _redirect().
        self.log = sys.__stdout__.write
        self.go = True  # cleared by quit()/exit() to end the serve loop
        self._sock = None  # listening socket (rank 0, created by serve)
        self._client_sock = None  # connected client (rank 0); None means "no client"

        if self.rank == 0:
            self.client_stdout = self.ClientFile(self, "stdout")
            self.client_stderr = self.ClientFile(self, "stderr")
        else:
            # Only rank 0 talks to the client; output of other ranks is dropped.
            self.client_stdout = self.NullFile()
            self.client_stderr = self.NullFile()

        if prefix is not None:
            self.prefix = prefix
        else:
            self.prefix = "%s_%s_%s" % (self.exec_name, self.localhost,
                                        re.sub(r"\W", "_", self.start_time))

        self.config_file = os.path.join(self.config_dir, "{}.json".format(self.prefix))
        self.addr = os.path.join(self.config_dir, "{}-ipc".format(self.prefix))

        if self.rank == 0:
            if not os.path.exists(self.config_dir):
                os.makedirs(self.config_dir)

        ns["quit"] = self._quitter
        ns["exit"] = self._quitter
        if self.MPI_ENABLED:
            ns["mpi_print"] = self._mpi_print
        ns["image"] = self._write_bytes_image

        if inline_matplotlib:
            self._setup_matplotlib()

    def _quitter(self):
        """Replacement for quit()/exit(): stop serving after the current request."""
        self.go = False

    def _mpi_print(self, value, ranks="all"):
        """Print ``value`` from each rank in ``ranks`` through rank 0, in order.

        Collective: every rank must call it with the same ``ranks``, because
        rank r sends its value to rank 0, which receives and prints it. For
        r == 0 every rank prints its own value; while evaluating client code
        only rank 0's stdout reaches the client.
        """
        if ranks == "all":
            ranks = range(self.nprocs)
        for r in ranks:
            if r == 0:
                print(value)
            else:
                if self.rank == r:
                    self.comm.send(value, dest=0)
                elif self.rank == 0:
                    print(self.comm.recv(source=r))

    @contextmanager
    def _redirect(self):
        """Temporarily point sys.stdout/sys.stderr at the client files, restoring them after."""
        stdout = sys.stdout
        stderr = sys.stderr
        sys.stdout = self.client_stdout
        sys.stderr = self.client_stderr
        try:
            yield
        finally:
            sys.stdout = stdout
            sys.stderr = stderr

    def _remove_files(self):
        """Delete (most recent first) every existing path registered in ``self.files``."""
        while self.files:
            path = self.files.pop()
            if os.path.exists(path):
                print("removing {}".format(path))
                os.remove(path)

    def _write_bytes_image(self, data, fmt):
        """Send raw image bytes to the client as a display message (rank 0 only).

        The message names ``IPython.display.Image`` with arguments
        ``data``/``format``; ``data`` travels base64-encoded and
        ``decode_bytes`` asks the client to decode it back to bytes.
        """
        if self.rank == 0:
            self.writemsg({
                "type": "display",
                "module": "IPython.display",
                "attr": "Image",
                "args": {"data": encodebytes(data).decode("utf-8"), "format": fmt},
                # request that the "data" key in "args" be decoded
                "decode_bytes": ["data"]
            })

    def _setup_matplotlib(self):
        """Make ``pyplot.show()`` send the current figure to the client as a PNG.

        Selects the non-interactive "agg" backend. On rank 0, show() saves
        the current figure as PNG, sends it via _write_bytes_image and closes
        it; on other ranks show() does nothing. Returns quietly when
        matplotlib cannot be imported, otherwise ends with an MPI barrier
        when MPI is enabled.
        """
        # TODO: figure out best strategy to warn on import failure
        try:
            import matplotlib
            matplotlib.use("agg")
            from matplotlib import pyplot
        except ImportError:
            return

        if self.rank == 0:
            def show(*args, **kwargs):
                print(repr(pyplot.gcf()))
                bio = BytesIO()
                pyplot.savefig(bio, format="png")
                self._write_bytes_image(bio.getvalue(), "png")
                pyplot.close()
        else:
            def show(*args, **kwargs):
                return None

        # Older pyplot versions dispatched show() through the module-level
        # _show; current ones do not, so patch pyplot.show itself as well.
        pyplot._show = show
        pyplot.show = show
        pyplot.show_if_interactive = show
        # TODO: draw_if_interactive is called automatically by plot commands
        # and is probably the better hook, but assigning show to it produced
        # blank images.

        if self.MPI_ENABLED:
            self.comm.Barrier()

    def _recv_exact(self, n):
        """Read exactly ``n`` bytes from the client; return None if it closes first.

        recv() may return fewer bytes than requested and signals an orderly
        shutdown by returning b"" (never None), so loop until done.
        """
        chunks = []
        remaining = n
        while remaining > 0:
            chunk = self._client_sock.recv(min(self.chunk_size, remaining))
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _drop_client(self):
        """Report a lost client, close its socket and forget it; always returns None."""
        print("client disconnected")
        sock, self._client_sock = self._client_sock, None
        if sock is not None:
            try:
                sock.close()
            except socket.error:
                pass
        return None

    def readmsg(self):
        """Receive and decode one request frame from the client.

        Returns:
            The decoded JSON object, or None when the client disconnected
            (its socket is then closed and ``_client_sock`` reset to None) or
            when the frame is empty or not valid UTF-8 JSON.
        """
        try:
            header = self._recv_exact(self.frame_length)
            if header is None:
                return self._drop_client()
            length = struct.unpack(self.frame_format, header)[0]
            message = self._recv_exact(length)
        except socket.error:
            return self._drop_client()
        if message is None:
            return self._drop_client()

        try:
            return json.loads(message.decode("utf-8"))
        except ValueError:  # also covers UnicodeDecodeError and empty frames
            print("invalid message format")
            return None

    def writemsg(self, data):
        """Send ``data`` to the client as one length-prefixed UTF-8 JSON frame.

        A no-op when there is no client socket (ranks other than 0, before a
        client connects, after it disconnected). Failures are reported with
        ``self.log`` rather than print(): print() may be redirected to the
        client (see _redirect), which would recurse back into writemsg.
        """
        sock = getattr(self, "_client_sock", None)
        if sock is None:
            return

        # The header must count encoded bytes, not characters: with
        # ensure_ascii=False a non-ASCII character takes several bytes.
        payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
        if len(payload) > self.max_message_length:
            self.log("fatal: message too long\n")
            return
        message = struct.pack(self.frame_format, len(payload)) + payload

        bytes_sent = 0
        while bytes_sent < len(message):
            chunk = message[bytes_sent:bytes_sent + self.chunk_size]
            try:
                sent = sock.send(chunk)
            except socket.error as sockerr:
                self.log("socket error: {}\n".format(sockerr))
                return
            if sent == 0:
                self.log("error: socket connection broken\n")
                return
            bytes_sent += sent

    def root_writemsg(self, data):
        """Send ``data`` to the client from rank 0 only."""
        if self.rank == 0:
            self.writemsg(data)

    def evaluate(self, code):
        """Run ``code`` in ``self.ns`` with stdout/stderr routed to the client.

        ``code`` is evaluated when it parses as a single expression (its
        ``str()`` becomes the result) and executed as statements otherwise.
        On an exception its text is sent as a ``stdout`` message and also
        becomes the result. Rank 0 logs the code and any non-empty result and
        sends that result as an ``eval_result`` message.
        """
        with self._redirect():
            if self.rank == 0:
                self.log("python> {}\n".format(code))
            resp = None
            try:
                try:
                    # eval() skips leading spaces/tabs of a source string; match it.
                    expr = compile(code.lstrip(" \t"), "<string>", "eval")
                except SyntaxError:
                    # Not a single expression: run it as statements.
                    exec(code, self.ns)
                else:
                    # Only a parse failure may fall back to exec; re-running
                    # code that raised at runtime would repeat its side effects.
                    result = eval(expr, self.ns)
                    if result is not None:
                        resp = str(result)
            except Exception as e:
                resp = str(e)
                self.writemsg({"type": "stdout", "code": resp})

        if self.rank == 0 and resp:
            try:
                resp = resp + "\n"
                self.log("{}".format(resp))
            except UnicodeDecodeError:
                self.log("<cannot display, non-ascii>\n")
            self.writemsg({"type": "eval_result", "code": resp})

    def complete(self, text, cursor_pos=None):
        """Send completion candidates for the dotted name ending at ``cursor_pos``.

        Rank 0 computes the candidates with rlcompleter over ``self.ns`` and
        sends a ``complete`` message; all ranks then meet at an MPI barrier
        when MPI is enabled.

        Args:
            text: the client's current input line.
            cursor_pos: cursor offset into ``text``; defaults to the end and
                is clamped to ``[0, len(text)]``.
        """
        if self.rank == 0:
            if cursor_pos is None:
                cursor_pos = len(text)
            cursor_pos = max(0, min(cursor_pos, len(text)))

            # Walk left from the cursor to the start of the dotted name.
            cursor_start = 0
            for i in range(cursor_pos - 1, -1, -1):
                ch = text[i]
                # str.isidentifier() would be simpler but is Python 3 only.
                if not (ch.isalpha() or ch.isnumeric() or ch in ("_", ".")):
                    cursor_start = i + 1
                    break

            prefix = text[cursor_start:cursor_pos]
            completer = rlcompleter.Completer(namespace=self.ns)
            if prefix:
                completer.complete(prefix, 0)
                matches = getattr(completer, "matches", None) or []
            else:
                matches = completer.global_matches("")

            self.writemsg({
                "type": "complete",
                "cursor_start": cursor_start,
                "cursor_end": cursor_pos,
                # rlcompleter appends "(" to callables; send bare names
                "matches": [m[:-1] if m.endswith("(") else m for m in matches]
            })

        if self.MPI_ENABLED:
            self.comm.Barrier()

    def _listen_and_accept(self):
        """(Rank 0) Bind the Unix socket, write the connection file, accept one client."""
        try:
            os.unlink(self.addr)
        except OSError:
            if os.path.exists(self.addr):
                raise
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self._sock = sock  # stored early so serve() can close it on failure
        print("trying to bind to {}... ".format(self.addr))
        sock.bind(self.addr)
        self.files.append(self.addr)

        # Advertise the cluster name from LCSCHEDCLUSTER (when set), then this host.
        hosts = []
        if "LCSCHEDCLUSTER" in os.environ:
            hosts.append(os.environ["LCSCHEDCLUSTER"])
        hosts.append(self.localhost)

        config = {
            "protocol": "ipc",
            "uds": self.addr,
            "hosts": hosts,
            "code": self.exec_name,
            "argv": " ".join(sys.argv),
            "user": getpass.getuser(),
            "date": self.start_time
        }

        print("connect to: {}".format(self.config_file))
        self.files.append(self.config_file)  # registered first so a partial file is cleaned up
        with open(self.config_file, "w") as config_fp:
            json.dump(config, config_fp, indent=4, sort_keys=True)

        sock.listen(0)
        self.config = config

        print("bind successful. waiting for client...")
        self._client_sock, addr = sock.accept()
        print("got client - {}".format(addr))
        self.writemsg({"type": "idle"})

    def _dispatch_loop(self):
        """Handle requests until disconnect, a read failure, or quit()/exit().

        Rank 0 reads each request and (under MPI) broadcasts it so every rank
        handles it; rank 0 answers each handled request with ``idle``.
        """
        while self.go:
            data = self.readmsg() if self.rank == 0 else None
            if self.MPI_ENABLED:
                data = self.comm.bcast(data, root=0)

            if data is None:
                print("disconnected")
                break

            msg_type = data.get("type") if isinstance(data, dict) else None
            if msg_type == "execute":
                self.evaluate(data["code"])
            elif msg_type == "complete":
                self.complete(data["code"], cursor_pos=data.get("cursor_pos"))
            elif msg_type == "custom":
                self.callback(data)
            elif msg_type == "ping":
                self.root_writemsg({"type": "pong"})
            elif msg_type == "disconnect":
                break

            self.root_writemsg({"type": "idle"})

    def _close_sockets(self):
        """Close the client and listening sockets, if open."""
        for sock in (getattr(self, "_client_sock", None), getattr(self, "_sock", None)):
            if sock is not None:
                try:
                    sock.close()
                except socket.error:
                    pass
        self._client_sock = None

    def serve(self):
        """Accept one client and process its requests until it leaves or quit() is called.

        Collective under MPI: every rank must call it. Rank 0 creates the
        socket and connection file and relays requests; all ranks handle
        them. Rank 0 closes its sockets and removes the files it created on
        the way out, also when an exception escapes.
        """
        try:
            if self.rank == 0:
                self._listen_and_accept()
            if self.MPI_ENABLED:
                self.comm.Barrier()
            self._dispatch_loop()
            self.root_writemsg({"type": "disconnect"})
        finally:
            if self.rank == 0:
                self._close_sockets()
                self._remove_files()