# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Defines abstractions and implementations of the RPC transport used with micro TVM."""

import abc
import logging
import string
import subprocess
import typing

import tvm

_LOG = logging.getLogger(__name__)


@tvm.error.register_error
class SessionTerminatedError(Exception):
    """Raised when a transport read operation discovers that the remote session is terminated."""


class Transport(metaclass=abc.ABCMeta):
    """The abstract Transport class used for micro TVM.

    Instances are context managers: entering the ``with`` block calls ``open()`` and exiting
    it calls ``close()``.
    """

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.close()

    @abc.abstractmethod
    def open(self):
        """Open any resources needed to send and receive RPC protocol data for a single session."""
        raise NotImplementedError()

    @abc.abstractmethod
    def close(self):
        """Release resources associated with this transport."""
        raise NotImplementedError()

    @abc.abstractmethod
    def read(self, n):
        """Read up to n bytes from the transport.

        Parameters
        ----------
        n : int
            Maximum number of bytes to read from the transport.

        Returns
        -------
        bytes :
            Data read from the channel. Less than `n` bytes may be returned, but 0 bytes should
            never be returned except in error. Note that if a transport error occurs, an Exception
            should be raised rather than simply returning empty bytes.

        Raises
        ------
        SessionTerminatedError :
            When the transport layer determines that the active session was terminated by the
            remote side. Typically this indicates that the remote device has reset.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def write(self, data):
        """Write data to the transport channel.

        Parameters
        ----------
        data : bytes
            The data to write over the channel.

        Returns
        -------
        int :
            The number of bytes written to the underlying channel. This can be less than the length
            of `data`, but cannot be 0.
        """
        raise NotImplementedError()


class TransportLogger(Transport):
    """Wraps a Transport implementation and logs traffic to the Python logging infrastructure.

    Parameters
    ----------
    name : str
        Label prefixed to every read/write log record.
    child : Transport
        The wrapped transport; all operations are delegated to it.
    logger : logging.Logger, optional
        Logger to emit records on. Defaults to this module's logger.
    level : int, optional
        Logging level used for every record. Defaults to ``logging.INFO``.
    """

    def __init__(self, name, child, logger=None, level=logging.INFO):
        self.name = name
        self.child = child
        self.logger = logger or _LOG
        self.level = level

    # Construct PRINTABLE to exclude whitespace from string.printable.
    PRINTABLE = string.digits + string.ascii_letters + string.punctuation

    @classmethod
    def _to_hex(cls, data):
        """Format ``data`` as hexdump-style lines.

        Each line covers up to 16 bytes. It shows a 4-digit hex offset, the bytes in hex (padded
        to a fixed width), and an ASCII rendering in which bytes outside ``PRINTABLE`` appear
        as ``.``. When ``data`` fits on a single line, the offset column is omitted.

        Parameters
        ----------
        data : bytes
            The data to format.

        Returns
        -------
        list of str :
            The formatted lines; ``[""]`` when ``data`` is empty.
        """
        if not data:
            return [""]

        rows = []
        for offset in range(0, len(data), 16):
            chunk = data[offset : offset + 16]
            hex_chunk = " ".join(f"{c:02x}" for c in chunk)
            ascii_chunk = "".join(chr(c) if chr(c) in cls.PRINTABLE else "." for c in chunk)
            # 47 == width of 16 hex pairs joined by single spaces, so the ASCII column aligns.
            rows.append((offset, f"{hex_chunk:47}  {ascii_chunk}"))

        if len(rows) == 1:
            # A single row needs no offset column. Building the offset separately (rather than
            # slicing a fixed number of characters off) keeps this correct regardless of the
            # separator width.
            return [rows[0][1]]

        return [f"{offset:04x}  {row}" for offset, row in rows]

    def open(self):
        self.logger.log(self.level, "opening transport")
        self.child.open()

    def close(self):
        self.logger.log(self.level, "closing transport")
        return self.child.close()

    def read(self, n):
        """Read up to ``n`` bytes from the child transport and log what was received."""
        data = self.child.read(n)
        hex_lines = self._to_hex(data)
        if len(hex_lines) > 1:
            self.logger.log(
                self.level,
                "%s read %4d B -> [%d B]:\n%s",
                self.name,
                n,
                len(data),
                "\n".join(hex_lines),
            )
        else:
            self.logger.log(
                self.level, "%s read %4d B -> [%d B]: %s", self.name, n, len(data), hex_lines[0]
            )

        return data

    def write(self, data):
        """Write ``data`` to the child transport and log the portion actually written."""
        bytes_written = self.child.write(data)
        hex_lines = self._to_hex(data[:bytes_written])
        if len(hex_lines) > 1:
            self.logger.log(
                self.level,
                "%s write <- [%d B]:\n%s",
                self.name,
                bytes_written,
                "\n".join(hex_lines),
            )
        else:
            self.logger.log(
                self.level, "%s write <- [%d B]: %s", self.name, bytes_written, hex_lines[0]
            )

        return bytes_written


class SubprocessTransport(Transport):
    """A Transport implementation that uses a subprocess's stdin/stdout as the channel.

    Parameters
    ----------
    args :
        Command passed to ``subprocess.Popen``.
    **kwargs :
        Extra ``subprocess.Popen`` keyword arguments. ``stdin``, ``stdout`` and ``bufsize`` are
        overridden by ``open()``.
    """

    def __init__(self, args, **kwargs):
        self.args = args
        self.kwargs = kwargs
        self.popen = None
        # Populated by open(); initialized here so close() is safe before open().
        self.stdin = None
        self.stdout = None

    def open(self):
        self.kwargs["stdout"] = subprocess.PIPE
        self.kwargs["stdin"] = subprocess.PIPE
        # Unbuffered pipes: writes reach the child immediately and reads may return short.
        self.kwargs["bufsize"] = 0
        self.popen = subprocess.Popen(self.args, **self.kwargs)
        self.stdin = self.popen.stdin
        self.stdout = self.popen.stdout

    def write(self, data):
        to_return = self.stdin.write(data)
        self.stdin.flush()

        return to_return

    def read(self, n):
        # NOTE(review): on EOF this returns b"" rather than raising, which the Transport.read
        # contract discourages; left unchanged because callers' handling is not visible here.
        return self.stdout.read(n)

    def close(self):
        if self.popen is None:
            # Never opened: nothing to release.
            return

        self.stdin.close()
        self.stdout.close()
        self.popen.terminate()


class DebugWrapperTransport(Transport):
    """A Transport wrapper class that launches a debugger before opening the transport.

    This is primarily useful when debugging the other end of a SubprocessTransport. It allows you
    to pipe data through the GDB process to drive the subprocess with a debugger attached.

    Parameters
    ----------
    debugger :
        Object exposing ``Start()``, ``Stop()`` and an ``on_terminate_callbacks`` list
        (presumably a micro TVM debugger; its type is defined elsewhere).
    transport : Transport
        The wrapped transport. Its ``close`` is registered as a debugger terminate callback.
    """

    def __init__(self, debugger, transport):
        self.debugger = debugger
        self.transport = transport
        self.debugger.on_terminate_callbacks.append(self.transport.close)

    def open(self):
        self.debugger.Start()

        try:
            self.transport.open()
        except Exception:
            # Don't leave the debugger running if the transport could not be opened.
            self.debugger.Stop()
            raise

    def write(self, data):
        return self.transport.write(data)

    def read(self, n):
        return self.transport.read(n)

    def close(self):
        self.transport.close()
        self.debugger.Stop()


TransportContextManager = typing.ContextManager[Transport]