1# Copyright 2017-2020 Palantir Technologies, Inc.
2# Copyright 2021- Python Language Server Contributors.
3
4import logging
5import threading
6
7try:
8    import ujson as json
9except Exception:  # pylint: disable=broad-except
10    import json
11
# Module-level logger named after this module; used to report stream read/write failures.
log = logging.getLogger(name=__name__)
13
14
class JsonRpcStreamReader:
    """Reads ``Content-Length`` framed JSON-RPC messages from a binary file-like object."""

    def __init__(self, rfile):
        self._rfile = rfile

    def close(self):
        self._rfile.close()

    def listen(self, message_consumer):
        """Blocking call to listen for messages on the rfile.

        Reads frames until EOF or until the rfile is closed. Frames that cannot be
        read or decoded are logged and skipped.

        Args:
            message_consumer (fn): function that is passed each message as it is read off the socket.
        """
        while not self._rfile.closed:
            try:
                request_str = self._read_message()
            except ValueError:
                if self._rfile.closed:
                    return
                log.exception("Failed to read from rfile")
                # Skip this frame: otherwise request_str is unbound (first iteration)
                # or still holds the previous frame, which would be dispatched twice.
                continue

            if request_str is None:
                break

            try:
                message_consumer(json.loads(request_str.decode('utf-8')))
            except ValueError:
                log.exception("Failed to parse JSON message %s", request_str)

    def _read_message(self):
        """Reads the contents of a message.

        Consumes the header section up to and including the blank separator line,
        then reads exactly ``Content-Length`` bytes of body.

        Returns:
            the raw body bytes, or None if EOF is reached before the body starts

        Raises:
            ValueError: if the headers carry no valid Content-Length
        """
        line = self._rfile.readline()
        if not line:
            return None

        # Collect every header before validating, so the stream is positioned at the
        # body even when a header is invalid. Headers may arrive in any order.
        headers = []
        while line and line.strip():
            headers.append(line)
            line = self._rfile.readline()

        if not line:
            return None

        content_length = None
        for header in headers:
            value = self._content_length(header)
            if value is not None:
                content_length = value

        if content_length is None:
            # read(None) would read until EOF, swallowing every subsequent frame.
            raise ValueError("Missing Content-Length header")

        return self._rfile.read(content_length)

    @staticmethod
    def _content_length(line):
        """Extract the content length from a header line.

        The header name is matched case-insensitively and surrounding whitespace
        is ignored.

        Returns:
            the integer length for a Content-Length header, otherwise None

        Raises:
            ValueError: if the Content-Length value is not a non-negative integer
        """
        name, sep, value = line.partition(b':')
        if not sep or name.strip().lower() != b'content-length':
            return None

        value = value.strip()
        try:
            length = int(value)
        except ValueError as e:
            raise ValueError("Invalid Content-Length header: {}".format(value)) from e

        # A negative size passed to read() would read to EOF.
        if length < 0:
            raise ValueError("Invalid Content-Length header: {}".format(value))
        return length
80
81
class JsonRpcStreamWriter:
    """Writes JSON-RPC messages to a binary file-like object with Content-Length framing.

    ``write`` and ``close`` are serialized by a lock, so frames written from
    multiple threads do not interleave.
    """

    def __init__(self, wfile, **json_dumps_args):
        self._wfile = wfile
        self._wfile_lock = threading.Lock()
        # Extra keyword arguments forwarded verbatim to json.dumps on every write.
        self._json_dumps_args = json_dumps_args

    def close(self):
        with self._wfile_lock:
            self._wfile.close()

    def write(self, message):
        """Serialize ``message`` and write it as one framed JSON-RPC message.

        Does nothing if the wfile is already closed. Serialization and I/O errors
        are logged rather than raised.
        """
        with self._wfile_lock:
            if self._wfile.closed:
                return
            try:
                body = json.dumps(message, **self._json_dumps_args)

                # Some JSON backends return bytes. Formatting bytes into the header
                # string would emit "b'...'", so normalize to encoded bytes once;
                # Content-Length must count bytes, not characters.
                if not isinstance(body, bytes):
                    body = body.encode('utf-8')

                header = (
                    "Content-Length: {}\r\n"
                    "Content-Type: application/vscode-jsonrpc; charset=utf8\r\n\r\n"
                ).format(len(body)).encode('utf-8')

                self._wfile.write(header + body)
                self._wfile.flush()
            except Exception:  # pylint: disable=broad-except
                log.exception("Failed to write message to output file %s", message)
112