1# Copyright 2019 Lawrence Livermore National Security, LLC and other
2# Bridge Kernel Project Developers. See the top-level LICENSE file for details.
3#
4# SPDX-License-Identifier: BSD-3-Clause
5
6from __future__ import print_function
7from __future__ import unicode_literals
8
9from contextlib import contextmanager
10import getpass
11from io import BytesIO
12import json
13import os
14import re
15import rlcompleter
16import socket
17import struct
18import sys
19import time
20
21try:
22    from base64 import encodebytes
23except ImportError:
24    from base64 import encodestring as encodebytes
25
26
class MPIServer():
    """Serve a single client over a Unix domain socket, optionally under MPI.

    Rank 0 binds a Unix socket under ``config_dir``, writes a JSON connection
    file describing it and waits for one client.  Messages in both directions
    are JSON objects framed by a length header packed with ``frame_format``
    (unsigned 32-bit, big-endian) that counts the UTF-8 payload bytes.  When a
    communicator is given, every request read by rank 0 is broadcast so that
    all ranks act on it; only rank 0 ever talks to the client.
    """

    # Directory for the connection file and socket.  expanduser() uses $HOME
    # when set and otherwise falls back to the password database.
    config_dir = os.path.join(os.path.expanduser("~"), ".config", "bridge_kernel")
    localhost = socket.gethostname()
    frame_format = "!I"
    frame_length = struct.calcsize(frame_format)
    # largest payload length the frame header can describe
    max_message_length = (2**(8 * frame_length)) - 1
    chunk_size = 1024

    # Sockets are created by serve(); None means "not open / not connected".
    _sock = None
    _client_sock = None

    class ClientFile():
        """File-like object that forwards each write to the client."""

        def __init__(self, client, id_str):
            self.client = client
            # message "type" attached to every write, e.g. "stdout"/"stderr"
            self.id_str = id_str

        def write(self, s):
            self.client.writemsg(dict(type=self.id_str, code=s))

        def flush(self):
            pass

    class NullFile():
        """File-like sink that discards everything (used on ranks != 0)."""

        def write(self, s):
            pass

        def flush(self):
            pass

    def __init__(self, comm, ns, callback, prefix=None, inline_matplotlib=True):
        """Prepare server state; no socket is opened until serve() runs.

        Args:
            comm: MPI communicator (must provide Get_rank, Get_size, Barrier,
                bcast, send and recv), or None to run without MPI.
            ns: dict used as the globals for evaluated code; quit/exit/image
                (and mpi_print when MPI is enabled) are inserted into it.
            callback: invoked with the decoded message for "custom" requests.
            prefix: basename for the connection file and socket; defaults to
                "<executable>_<host>_<start time>".
            inline_matplotlib: if true, route pyplot.show() to the client
                (skipped when matplotlib cannot be imported).
        """
        if comm is None:
            self.MPI_ENABLED = False
            self.comm = None
            self.rank = 0
            self.nprocs = 1
        else:
            self.MPI_ENABLED = True
            self.comm = comm
            self.rank = comm.Get_rank()
            self.nprocs = comm.Get_size()

        self.callback = callback
        self.exec_name = os.path.splitext(os.path.basename(sys.argv[0]))[0]
        self.ns = ns
        # paths deleted by _remove_files() when serving ends
        self.files = []
        self.start_time = time.strftime("%Y-%m-%d %H:%M:%S")
        # Log to the real stdout: sys.stdout may be a ClientFile while user
        # code runs, and logging through it would loop back into writemsg.
        self.log = sys.__stdout__.write
        # cleared by quit()/exit() to stop the serve loop
        self.go = True

        if self.rank == 0:
            self.client_stdout = self.ClientFile(self, "stdout")
            self.client_stderr = self.ClientFile(self, "stderr")
        else:
            self.client_stdout = self.NullFile()
            self.client_stderr = self.NullFile()

        if prefix is not None:
            self.prefix = prefix
        else:
            self.prefix = "%s_%s_%s" % (self.exec_name, self.localhost,
                                        re.sub(r"\W", "_", self.start_time))

        self.config_file = os.path.join(self.config_dir, "{}.json".format(self.prefix))
        self.addr = os.path.join(self.config_dir, "{}-ipc".format(self.prefix))

        if self.rank == 0 and not os.path.exists(self.config_dir):
            try:
                os.makedirs(self.config_dir)
            except OSError:
                # tolerate a concurrent creator; re-raise real failures
                if not os.path.isdir(self.config_dir):
                    raise

        ns["quit"] = self._quitter
        ns["exit"] = self._quitter
        if self.MPI_ENABLED:
            ns["mpi_print"] = self._mpi_print
        ns["image"] = self._write_bytes_image

        if inline_matplotlib:
            self._setup_matplotlib()

    def _quitter(self):
        """quit()/exit() replacement: end the serve loop after this request."""
        self.go = False

    def _mpi_print(self, value, ranks="all"):
        """Print each listed rank's ``value`` on rank 0, in the given order.

        Must be called collectively: rank r sends its value to rank 0, which
        receives and prints it.  ``ranks`` defaults to every rank.
        """
        if ranks == "all":
            ranks = range(self.nprocs)
        for r in ranks:
            if r == 0:
                # only rank 0 prints its own value
                if self.rank == 0:
                    print(value)
            elif self.rank == r:
                self.comm.send(value, dest=0)
            elif self.rank == 0:
                print(self.comm.recv(source=r))

    @contextmanager
    def _redirect(self):
        """Temporarily route sys.stdout/sys.stderr to the client streams."""
        stdout = sys.stdout
        stderr = sys.stderr
        sys.stdout = self.client_stdout
        sys.stderr = self.client_stderr
        try:
            yield
        finally:
            sys.stdout = stdout
            sys.stderr = stderr

    def _remove_files(self):
        """Delete every path registered in self.files that still exists."""
        while self.files:
            path = self.files.pop()
            if os.path.exists(path):
                print("removing {}".format(path))
                os.remove(path)

    def _write_bytes_image(self, data, fmt):
        """Send image bytes to the client as a "display" message (rank 0 only).

        The message names IPython.display.Image with ``data`` base64-encoded
        and lists "data" under "decode_bytes" so the receiver decodes it.
        """
        if self.rank == 0:
            self.writemsg({
                "type": "display",
                "module": "IPython.display",
                "attr": "Image",
                "args": {"data": encodebytes(data).decode("utf-8"), "format": fmt},
                # request that the "data" key in "args" be decoded
                "decode_bytes": ["data"]
            })

    def _setup_matplotlib(self):
        """Make pyplot.show() send the current figure as PNG on rank 0.

        Other ranks get a no-op show().  Nothing is patched if matplotlib
        cannot be imported.  Under MPI all ranks meet at a barrier afterwards.
        """
        # allow use without matplotlib
        # TODO: figure out best strategy to warn on import failure
        try:
            import matplotlib
            matplotlib.use("agg")
            from matplotlib import pyplot
        except ImportError:
            pyplot = None

        if pyplot is not None:
            if self.rank == 0:
                def write_image(*args, **kwargs):
                    print(repr(pyplot.gcf()))
                    bio = BytesIO()
                    pyplot.savefig(bio, format="png")
                    self._write_bytes_image(bio.getvalue(), "png")
                    pyplot.close()
                # Older pyplot.show() delegated to pyplot._show; newer
                # releases do not, so the public function is replaced too.
                pyplot._show = write_image
                pyplot.show = write_image
                # TODO: pyplot.draw_if_interactive is called automatically by
                # plot commands and may be the better hook, but pointing it at
                # write_image produced blank images.
            else:
                pyplot.show = lambda *args, **kwargs: None
                pyplot.show_if_interactive = pyplot.show

        # Reached by every rank, even where the import failed, so a
        # partially-available matplotlib cannot desynchronize the barrier.
        if self.MPI_ENABLED:
            self.comm.Barrier()

    def _recv_exact(self, count):
        """Read exactly ``count`` bytes from the client socket.

        Returns the bytes, or None if the peer closed the connection or a
        socket error occurred first (recv() returns b"" at end of stream).
        """
        chunks = []
        remaining = count
        while remaining > 0:
            try:
                chunk = self._client_sock.recv(min(self.chunk_size, remaining))
            except socket.error:
                return None
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _drop_client(self):
        """Close and forget the client connection, if there is one."""
        sock = self._client_sock
        self._client_sock = None
        if sock is not None:
            try:
                sock.close()
            except socket.error:
                # best effort: the connection is being abandoned anyway
                pass

    def _close_listener(self):
        """Close the listening socket created by serve(), if any."""
        sock = self._sock
        self._sock = None
        if sock is not None:
            sock.close()

    def readmsg(self):
        """Receive one length-prefixed JSON message from the client.

        Returns the decoded object, or None when no client is connected, the
        client disconnected (the connection is then dropped), or the frame is
        empty or not valid UTF-8 JSON.
        """
        if self._client_sock is None:
            return None

        header = self._recv_exact(self.frame_length)
        if header is None:
            print("client disconnected")
            self._drop_client()
            return None

        length = struct.unpack(self.frame_format, header)[0]
        if length == 0:
            print("invalid message format")
            return None

        message = self._recv_exact(length)
        if message is None:
            print("client disconnected")
            self._drop_client()
            return None

        try:
            return json.loads(message.decode("utf-8"))
        except ValueError:
            # also covers UnicodeDecodeError, a ValueError subclass
            print("invalid message format")
            return None

    def writemsg(self, data):
        """Send ``data`` to the client as one length-prefixed JSON message.

        Does nothing when no client is connected.  Problems are reported via
        self.log, never print(): while user code runs, sys.stdout is a
        ClientFile whose write() calls back into writemsg.  A failed send
        drops the connection so later reads see a disconnect.
        """
        sock = self._client_sock
        if sock is None:
            return

        # Frame on the encoded byte count: with ensure_ascii=False the
        # character count differs from the byte count for non-ASCII text.
        payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
        if len(payload) > self.max_message_length:
            self.log("fatal: message too long\n")
            return
        message = struct.pack(self.frame_format, len(payload)) + payload

        try:
            sock.sendall(message)
        except socket.error as err:
            self.log("socket error: {}\n".format(err))
            self._drop_client()

    # TODO make this default behavior of writemsg?
    def root_writemsg(self, data):
        """writemsg() on rank 0 only; other ranks have no client socket."""
        if self.rank == 0:
            self.writemsg(data)

    def evaluate(self, code):
        """Run ``code`` in self.ns with stdout/stderr routed to the client.

        Source that compiles as an expression is eval'd and, on rank 0, its
        str() is logged and sent as an "eval_result" message; anything else
        is exec'd.  An exception's text is sent (rank 0 only) as a "stdout"
        message and then also as the "eval_result".
        """
        with self._redirect():
            if self.rank == 0:
                self.log("python> {}\n".format(code))
            try:
                # Choose eval vs exec from syntax alone, so code that fails at
                # run time is not executed a second time by an exec fallback.
                # CPython's eval() of a source string skips leading spaces and
                # tabs while compile() does not, hence the lstrip.
                try:
                    expr = compile(code.lstrip(" \t"), "<string>", "eval")
                except SyntaxError:
                    expr = None
                if expr is None:
                    exec(code, self.ns)
                    resp = None
                else:
                    resp = eval(expr, self.ns)
                    if resp is not None:
                        resp = str(resp)
            except Exception as e:
                resp = str(e)
                # only rank 0 owns the client connection
                self.root_writemsg({"type": "stdout", "code": resp})

            if self.rank == 0 and resp:
                try:
                    resp = resp + "\n"
                    self.log("{}".format(resp))
                except UnicodeDecodeError:
                    self.log("<cannot display, non-ascii>\n")
                self.writemsg({"type": "eval_result", "code": resp})

    def complete(self, text, cursor_pos=None):
        """Send tab-completion candidates for ``text`` at ``cursor_pos``.

        Rank 0 replies with the [cursor_start, cursor_end) span being
        completed and rlcompleter's candidates (trailing "(" stripped).  Under
        MPI all ranks then meet at a barrier.
        """
        if self.rank == 0:
            if cursor_pos is None:
                cursor_pos = len(text)
            # An out-of-range cursor would raise IndexError on rank 0 and
            # leave the other ranks waiting at the barrier.
            cursor_pos = max(0, min(cursor_pos, len(text)))

            # walk left over identifier characters and attribute dots
            cursor_start = 0
            for i in range(cursor_pos - 1, -1, -1):
                ch = text[i]
                if not (ch.isalpha() or ch.isnumeric() or ch == "_" or ch == "."):
                    cursor_start = i + 1
                    break

            token = text[cursor_start:cursor_pos]
            completer = rlcompleter.Completer(namespace=self.ns)
            if not token:
                matches = completer.global_matches("")
            else:
                # state 0 makes rlcompleter compute and store .matches
                completer.complete(token, 0)
                matches = getattr(completer, "matches", None) or []

            self.writemsg({
                "type": "complete",
                "cursor_start": cursor_start,
                "cursor_end": cursor_pos,
                "matches": [m[:-1] if m.endswith("(") else m for m in matches],
            })

        if self.MPI_ENABLED:
            self.comm.Barrier()

    def _bind_and_wait(self):
        """Rank 0: bind the socket, write the connection file, accept a client."""
        try:
            os.unlink(self.addr)
        except OSError:
            if os.path.exists(self.addr):
                raise
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        # stored immediately so serve() can close it even if bind fails
        self._sock = sock
        print("trying to bind to {}... ".format(self.addr))
        sock.bind(self.addr)
        self.files.append(self.addr)

        hosts = []
        if "LCSCHEDCLUSTER" in os.environ:
            hosts.append(os.environ["LCSCHEDCLUSTER"])
        hosts.append(self.localhost)

        config = {
            "protocol": "ipc",
            "uds": self.addr,
            "hosts": hosts,
            "code": self.exec_name,
            "argv": " ".join(sys.argv),
            "user": getpass.getuser(),
            "date": self.start_time
        }

        print("connect to: {}".format(self.config_file))
        with open(self.config_file, "w") as f:
            json.dump(config, f, indent=4, sort_keys=True)
        self.files.append(self.config_file)

        sock.listen(0)
        self.config = config

        print("bind successful. waiting for client...")
        self._client_sock, addr = sock.accept()
        print("got client - {}".format(addr))

        self.writemsg({"type": "idle"})

    def _serve_loop(self):
        """Dispatch requests read by rank 0 (broadcast under MPI) until done."""
        while self.go:
            data = self.readmsg() if self.rank == 0 else None
            if self.MPI_ENABLED:
                data = self.comm.bcast(data, root=0)

            if data is None:
                print("disconnected")
                break

            # tolerate valid JSON that is not an object or lacks keys
            msg_type = data.get("type") if isinstance(data, dict) else None
            if msg_type == "execute":
                self.evaluate(data.get("code", ""))
            elif msg_type == "complete":
                self.complete(data.get("code", ""), cursor_pos=data.get("cursor_pos"))
            elif msg_type == "custom":
                self.callback(data)
            elif msg_type == "ping":
                self.root_writemsg({"type": "pong"})
            elif msg_type == "disconnect":
                break

            self.root_writemsg({"type": "idle"})

    def serve(self):
        """Accept one client and handle its requests until it disconnects.

        Rank 0 binds, publishes the connection file and blocks in accept();
        then every rank runs the request loop.  On exit, normal or through an
        exception, rank 0 sends "disconnect", closes its sockets and removes
        the files it created.
        """
        try:
            if self.rank == 0:
                self._bind_and_wait()
            if self.MPI_ENABLED:
                self.comm.Barrier()
            self._serve_loop()
        finally:
            self.root_writemsg({"type": "disconnect"})
            if self.rank == 0:
                self._drop_client()
                self._close_listener()
                self._remove_files()
386