1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Defines abstractions and implementations of the RPC transport used with micro TVM."""
19
20import abc
21import logging
22import string
23import subprocess
24import typing
25
26import tvm
27
# Module-level logger; TransportLogger falls back to it when no logger is passed in.
_LOG = logging.getLogger(__name__)
29
30
@tvm.error.register_error
class SessionTerminatedError(Exception):
    """Raised when a transport read operation discovers that the remote session is terminated."""
34
35
class Transport(metaclass=abc.ABCMeta):
    """Abstract interface for the byte channel that carries micro TVM RPC traffic.

    A Transport can be used as a context manager: entering the `with` block calls `open()` and
    leaving it calls `close()`.
    """

    @abc.abstractmethod
    def open(self):
        """Acquire whatever resources are required to exchange RPC data for one session."""
        raise NotImplementedError

    @abc.abstractmethod
    def close(self):
        """Free the resources held by this transport."""
        raise NotImplementedError

    @abc.abstractmethod
    def read(self, n):
        """Receive at most n bytes from the transport.

        Parameters
        ----------
        n : int
            Upper bound on the number of bytes to return.

        Returns
        -------
        bytes :
            The bytes received. Fewer than `n` bytes may come back, but an empty result should
            only ever happen in error; implementations should raise an Exception on a transport
            error instead of returning empty bytes.

        Raises
        ------
        SessionTerminatedError :
            If the transport layer detects that the remote side ended the active session, which
            typically means the remote device has reset.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def write(self, data):
        """Send data over the transport channel.

        Parameters
        ----------
        data : bytes
            The payload to send.

        Returns
        -------
        int :
            How many bytes were actually handed to the underlying channel. May be smaller than
            `len(data)`, but is never 0.
        """
        raise NotImplementedError

    def __enter__(self):
        """Open the transport and hand it back as the `with` target."""
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Close the transport; any in-flight exception is left to propagate."""
        self.close()
97
98
class TransportLogger(Transport):
    """Wraps a Transport implementation and logs traffic to the Python logging infrastructure.

    Parameters
    ----------
    name : str
        Label placed at the start of every read/write log record.
    child : Transport
        The wrapped transport; every operation is forwarded to it.
    logger : Optional[logging.Logger]
        Logger that receives the records. When omitted, this module's logger is used.
    level : int
        Logging level at which every record is emitted.
    """

    def __init__(self, name, child, logger=None, level=logging.INFO):
        self.name = name
        self.child = child
        self.logger = logger or _LOG
        self.level = level

    # Construct PRINTABLE to exclude whitespace from string.printable.
    PRINTABLE = string.digits + string.ascii_letters + string.punctuation

    @classmethod
    def _to_hex(cls, data):
        """Format data as hexdump-style text lines, 16 bytes per line.

        Parameters
        ----------
        data : bytes
            The bytes to format.

        Returns
        -------
        List[str] :
            One line per 16-byte chunk, each holding a 4-digit hex offset, the hex bytes and an
            ASCII rendering in which characters outside PRINTABLE are shown as ".". When there
            is only one line the offset column is dropped. Empty data yields a single empty
            string.
        """
        if not data:
            return [""]

        lines = []
        for offset in range(0, len(data), 16):
            chunk = data[offset : offset + 16]
            hex_chunk = " ".join(f"{c:02x}" for c in chunk)
            ascii_chunk = "".join((chr(c) if chr(c) in cls.PRINTABLE else ".") for c in chunk)
            # 47 == 16 two-digit bytes plus 15 separating spaces; pads short final chunks.
            lines.append(f"{offset:04x}  {hex_chunk:47}  {ascii_chunk}")

        if len(lines) == 1:
            # Strip the "0000  " offset prefix (4 digits + 2 spaces) from a one-line dump.
            lines[0] = lines[0][6:]

        return lines

    def open(self):
        """Log the open request, then open the wrapped transport."""
        self.logger.log(self.level, "opening transport")
        self.child.open()

    def close(self):
        """Log the close request, then close the wrapped transport and return its result."""
        self.logger.log(self.level, "closing transport")
        return self.child.close()

    def read(self, n):
        """Read up to n bytes from the wrapped transport and log what was received.

        Parameters
        ----------
        n : int
            Maximum number of bytes to read.

        Returns
        -------
        bytes :
            Exactly what the wrapped transport returned.
        """
        data = self.child.read(n)

        # Building the hex dump is the expensive part; skip it when the record would be dropped.
        if self.logger.isEnabledFor(self.level):
            hex_lines = self._to_hex(data)
            if len(hex_lines) > 1:
                self.logger.log(
                    self.level,
                    "%s read %4d B -> [%d B]:\n%s",
                    self.name,
                    n,
                    len(data),
                    "\n".join(hex_lines),
                )
            else:
                self.logger.log(
                    self.level, "%s read %4d B -> [%d B]: %s", self.name, n, len(data), hex_lines[0]
                )

        return data

    def write(self, data):
        """Write data to the wrapped transport and log the bytes that were accepted.

        Parameters
        ----------
        data : bytes
            The data to write.

        Returns
        -------
        int :
            The byte count reported by the wrapped transport.
        """
        bytes_written = self.child.write(data)

        # Building the hex dump is the expensive part; skip it when the record would be dropped.
        if self.logger.isEnabledFor(self.level):
            # Only the prefix the child actually accepted is dumped.
            hex_lines = self._to_hex(data[:bytes_written])
            if len(hex_lines) > 1:
                self.logger.log(
                    self.level,
                    "%s write      <- [%d B]:\n%s",
                    self.name,
                    bytes_written,
                    "\n".join(hex_lines),
                )
            else:
                self.logger.log(
                    self.level, "%s write      <- [%d B]: %s", self.name, bytes_written, hex_lines[0]
                )

        return bytes_written
173
174
class SubprocessTransport(Transport):
    """A Transport implementation that uses a subprocess's stdin/stdout as the channel.

    Parameters
    ----------
    args :
        Command to launch; passed unchanged as the first argument to `subprocess.Popen`.
    **kwargs :
        Extra keyword arguments for `subprocess.Popen`. `stdin`, `stdout` and `bufsize` are
        overridden when the transport is opened.
    """

    # Seconds to wait for the child to exit after terminate() before escalating to kill().
    _TERMINATE_TIMEOUT_SEC = 5

    def __init__(self, args, **kwargs):
        self.args = args
        self.kwargs = kwargs
        self.popen = None
        # Defined up front so close() is safe even if open() was never called.
        self.stdin = None
        self.stdout = None

    def open(self):
        """Launch the subprocess with unbuffered pipes attached to its stdin and stdout."""
        self.kwargs["stdout"] = subprocess.PIPE
        self.kwargs["stdin"] = subprocess.PIPE
        self.kwargs["bufsize"] = 0
        self.popen = subprocess.Popen(self.args, **self.kwargs)
        self.stdin = self.popen.stdin
        self.stdout = self.popen.stdout

    def write(self, data):
        """Write data to the subprocess's stdin.

        Parameters
        ----------
        data : bytes
            The data to write.

        Returns
        -------
        int :
            The value returned by the pipe's `write()` call.
        """
        to_return = self.stdin.write(data)
        self.stdin.flush()

        return to_return

    def read(self, n):
        """Read up to n bytes from the subprocess's stdout.

        Parameters
        ----------
        n : int
            Maximum number of bytes to read.

        Returns
        -------
        bytes :
            The value returned by the pipe's `read(n)` call.
        """
        return self.stdout.read(n)

    def close(self):
        """Close both pipes, terminate the subprocess and reap it.

        Does nothing if the transport was never opened.
        """
        if self.popen is None:
            return

        try:
            self.stdin.close()
            self.stdout.close()
        finally:
            # Always stop the child, even when closing a pipe raised.
            self.popen.terminate()
            try:
                # Reap the child so it is not left behind as a zombie process.
                self.popen.wait(timeout=self._TERMINATE_TIMEOUT_SEC)
            except subprocess.TimeoutExpired:
                # The child ignored the terminate request; force it down and reap it.
                self.popen.kill()
                self.popen.wait()
205
class DebugWrapperTransport(Transport):
    """A Transport wrapper class that launches a debugger before opening the transport.

    This is primarily useful when debugging the other end of a SubprocessTransport. It allows you
    to pipe data through the GDB process to drive the subprocess with a debugger attached.

    Parameters
    ----------
    debugger :
        Object providing `Start()`, `Stop()` and an `on_terminate_callbacks` list.
    transport : Transport
        The transport to open once the debugger has been started.
    """

    def __init__(self, debugger, transport):
        self.debugger = debugger
        self.transport = transport
        # Register the wrapped transport's close() with the debugger's terminate callbacks.
        self.debugger.on_terminate_callbacks.append(self.transport.close)

    def open(self):
        """Start the debugger, then open the wrapped transport.

        If opening the transport raises, the debugger is stopped and the exception re-raised.
        """
        self.debugger.Start()

        try:
            self.transport.open()
        except Exception:
            self.debugger.Stop()
            raise

    def write(self, data):
        """Forward the write to the wrapped transport and return its result."""
        return self.transport.write(data)

    def read(self, n):
        """Forward the read to the wrapped transport and return its result."""
        return self.transport.read(n)

    def close(self):
        """Close the wrapped transport and stop the debugger.

        The debugger is stopped even if closing the transport raises, so a failed close does not
        leave the debugger running.
        """
        try:
            self.transport.close()
        finally:
            self.debugger.Stop()
236
237
# Type alias: a context manager whose __enter__ yields a Transport.
TransportContextManager = typing.ContextManager[Transport]
239