# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Defines interfaces and default implementations for compiling and flashing code."""

import abc
import glob
import os
import re

from tvm.contrib import binutil
import tvm.target
from . import build
from . import class_factory
from . import debugger
from . import transport


class DetectTargetError(Exception):
    """Raised when no target comment was detected in the sources given."""


class NoDefaultToolchainMatchedError(Exception):
    """Raised when no default toolchain matches the target string."""


class Compiler(metaclass=abc.ABCMeta):
    """The compiler abstraction used with micro TVM."""

    # Matches the "// tvm target: <target string>" comment emitted into generated C sources.
    TVM_TARGET_RE = re.compile(r"^// tvm target: (.*)$")

    @classmethod
    def _target_from_sources(cls, sources):
        """Determine the target used to generate the given source files.

        Parameters
        ----------
        sources : List[str]
            The paths to source files to analyze.

        Returns
        -------
        tvm.target.Target :
            A Target instance reconstructed from the target string listed in the source files.

        Raises
        ------
        DetectTargetError :
            When no source contains a target comment, or when the sources name more than one
            distinct target.
        """
        target_strs = set()

        for obj in sources:
            with open(obj) as obj_f:
                for line in obj_f:
                    m = cls.TVM_TARGET_RE.match(line)
                    if m:
                        target_strs.add(m.group(1))

        if not target_strs:
            raise DetectTargetError(
                "autodetecting cross-compiler: could not extract TVM target from C source; regex "
                f"{cls.TVM_TARGET_RE.pattern} does not match any line in sources: "
                f'{", ".join(sources)}'
            )

        if len(target_strs) > 1:
            # Report conflicting targets explicitly rather than claiming nothing matched.
            raise DetectTargetError(
                "autodetecting cross-compiler: sources specify conflicting TVM targets: "
                f'{", ".join(sorted(target_strs))}; sources: {", ".join(sources)}'
            )

        target_str = next(iter(target_strs))
        return tvm.target.create(target_str)

    # Maps regexes identifying CPUs to the default toolchain prefix for that CPU.
    TOOLCHAIN_PREFIX_BY_CPU_REGEX = {
        r"cortex-[am].*": "arm-none-eabi-",
        "x86[_-]64": "",
        "native": "",
    }

    def _autodetect_toolchain_prefix(self, target):
        """Select the default toolchain prefix for `target` from its `mcpu` attribute.

        Parameters
        ----------
        target : tvm.target.Target
            The target being compiled for.

        Returns
        -------
        str :
            The toolchain prefix (e.g. "arm-none-eabi-", or "" for the host toolchain).

        Raises
        ------
        NoDefaultToolchainMatchedError :
            When the target has no `mcpu`, or its `mcpu` matches zero or several entries in
            TOOLCHAIN_PREFIX_BY_CPU_REGEX.
        """
        mcpu = target.attrs.get("mcpu")
        if not mcpu:
            raise NoDefaultToolchainMatchedError(
                f"target {str(target)} does not specify -mcpu; cannot select a default toolchain"
            )

        # re.match anchors only at the start of mcpu, mirroring the original lookup.
        matches = [
            prefix
            for regex, prefix in self.TOOLCHAIN_PREFIX_BY_CPU_REGEX.items()
            if re.match(regex, mcpu)
        ]

        if not matches:
            raise NoDefaultToolchainMatchedError(
                f"target {str(target)} did not match any default toolchains"
            )

        if len(matches) != 1:
            raise NoDefaultToolchainMatchedError(
                f'-mcpu={mcpu} matched more than 1 default toolchain prefix: '
                f'{", ".join(matches)}. '
                "Specify cc.cross_compiler to create_micro_library()"
            )

        return matches[0]

    def _defaults_from_target(self, target):
        """Determine the default compiler options from the target specified.

        Parameters
        ----------
        target : tvm.target.Target

        Returns
        -------
        List[str] :
            Default options used to configure the compiler for that target.
        """
        opts = []

        mcpu = target.attrs.get("mcpu")
        if mcpu:
            # GCC's ARM back end accepts CPU names such as "cortex-m4" only through -mcpu=;
            # -march= expects an architecture name (e.g. "armv7e-m") and rejects CPU names.
            # Other CPU names (e.g. "x86-64", "native") keep using -march= as before.
            cpu_flag = "-mcpu" if re.match(r"cortex-[am]", mcpu) else "-march"
            opts.append(f"{cpu_flag}={mcpu}")

        mfpu = target.attrs.get("mfpu")
        if mfpu:
            opts.append(f"-mfpu={mfpu}")

        return opts

    @abc.abstractmethod
    def library(self, output, sources, options=None):
        """Build a library from the given source files.

        Parameters
        ----------
        output : str
            The path to the library that should be created. The containing directory
            is guaranteed to be empty and should be the base_dir for the returned
            Artifact.
        sources : List[str]
            A list of paths to source files that should be compiled.
        options : Optional[Dict[str, List[str]]]
            If given, additional compiler configuration. The structure is
            implementation-defined; DefaultCompiler expects a dict mapping keys such as
            "cflags", "ccflags", "cppflags" and "include_dirs" to lists of strings.

        Returns
        -------
        MicroLibrary :
            The compiled library, as a MicroLibrary instance.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def binary(self, output, objects, options=None, link_main=True, main_options=None):
        """Link a binary from the given object and/or source files.

        Parameters
        ----------
        output : str
            The path to the binary that should be created. The containing directory
            is guaranteed to be empty and should be the base_dir for the returned
            Artifact.
        objects : List[MicroLibrary]
            The libraries whose files should be linked into the binary. The final binary
            should be statically-linked.
        options : Optional[Dict[str, List[str]]]
            If given, additional compiler/linker configuration. The structure is
            implementation-defined; DefaultCompiler expects a dict mapping keys such as
            "ldflags" and "include_dirs" to lists of strings.
        link_main : Optional[bool]
            True if the standard main entry point for this Compiler should be included in the
            binary. False if a main entry point is provided in one of `objects`.
        main_options : Optional[Dict[str, List[str]]]
            If given, additional options to pass to the compiler when compiling the
            main() library. In some cases, the main() may be compiled directly into the final
            binary along with `objects` for logistical reasons. In those cases, specifying
            main_options is an error and ValueError will be raised.

        Returns
        -------
        MicroBinary :
            The compiled binary, as a MicroBinary instance.
        """
        raise NotImplementedError()

    @property
    def flasher_factory(self):
        """Produce a FlasherFactory for a Flasher instance suitable for this Compiler."""
        raise NotImplementedError("The Compiler base class doesn't define a flasher.")

    def flasher(self, **kw):
        """Return a Flasher that can be used to program a produced MicroBinary onto the target."""
        return self.flasher_factory.override_kw(**kw).instantiate()


class IncompatibleTargetError(Exception):
    """Raised when source files specify a target that differs from the compiler target."""


class DefaultCompiler(Compiler):
    """A Compiler implementation that attempts to use the system-installed GCC."""

    def __init__(self, target=None):
        """Create a DefaultCompiler.

        Parameters
        ----------
        target : Optional[Union[str, tvm.target.Target]]
            The configured target. A string is converted with tvm.target.create().
        """
        super().__init__()
        self.target = target
        if isinstance(target, str):
            self.target = tvm.target.create(target)

    def library(self, output, sources, options=None):
        """Compile each source with the target's GCC toolchain and archive the objects.

        `options` keys read here: "<ext>flags" (e.g. "cflags" for .c, "ccflags" for .cc,
        "cppflags" for .cpp) and "include_dirs".

        Raises
        ------
        IncompatibleTargetError :
            When the target detected from `sources` differs from the configured target.
        """
        options = options if options is not None else {}
        try:
            target = self._target_from_sources(sources)
        except DetectTargetError:
            assert self.target is not None, (
                "Must specify target= to constructor when compiling sources which don't specify a "
                "target"
            )

            target = self.target

        if self.target is not None and str(self.target) != str(target):
            raise IncompatibleTargetError(
                f"auto-detected target {target} differs from configured {self.target}"
            )

        prefix = self._autodetect_toolchain_prefix(target)
        # Target-derived flags are the same for every source; compute them once.
        target_opts = self._defaults_from_target(target)
        outputs = []
        for src in sources:
            src_base, src_ext = os.path.splitext(os.path.basename(src))

            compiler_name = {".c": "gcc", ".cc": "g++", ".cpp": "g++"}[src_ext]
            args = [prefix + compiler_name, "-g"]
            args.extend(target_opts)

            args.extend(options.get(f"{src_ext[1:]}flags", []))

            for include_dir in options.get("include_dirs", []):
                args.extend(["-I", include_dir])

            output_filename = f"{src_base}.o"
            output_abspath = os.path.join(output, output_filename)
            binutil.run_cmd(args + ["-c", "-o", output_abspath, src])
            outputs.append(output_abspath)

        output_filename = f"{os.path.basename(output)}.a"
        output_abspath = os.path.join(output, output_filename)
        binutil.run_cmd([prefix + "ar", "-r", output_abspath] + outputs)
        binutil.run_cmd([prefix + "ranlib", output_abspath])

        return tvm.micro.MicroLibrary(output, [output_filename])

    def binary(self, output, objects, options=None, link_main=True, main_options=None):
        """Link `objects` (plus, optionally, the host main()) into a binary with g++.

        `options` keys read here: "ldflags" and "include_dirs". When `main_options` is falsy,
        the host main() sources are compiled directly into the final link.
        """
        assert self.target is not None, (
            "must specify target= to constructor, or compile sources which specify the target "
            "first"
        )
        # Previously only the "ldflags" lookup was guarded, so options=None crashed on
        # the "include_dirs" lookup below.
        options = options if options is not None else {}

        args = [self._autodetect_toolchain_prefix(self.target) + "g++"]
        args.extend(self._defaults_from_target(self.target))
        args.extend(options.get("ldflags", []))

        for include_dir in options.get("include_dirs", []):
            args.extend(["-I", include_dir])

        output_filename = os.path.basename(output)
        output_abspath = os.path.join(output, output_filename)
        args.extend(["-g", "-o", output_abspath])

        if link_main:
            host_main_srcs = glob.glob(os.path.join(build.CRT_ROOT_DIR, "host", "*.cc"))
            if main_options:
                main_lib = self.library(os.path.join(output, "host"), host_main_srcs, main_options)
                for lib_name in main_lib.library_files:
                    args.append(main_lib.abspath(lib_name))
            else:
                args.extend(host_main_srcs)

        for obj in objects:
            for lib_name in obj.library_files:
                args.append(obj.abspath(lib_name))

        binutil.run_cmd(args)
        return tvm.micro.MicroBinary(output, output_filename, [])

    @property
    def flasher_factory(self):
        """Return a FlasherFactory producing HostFlasher instances with no bound arguments."""
        return FlasherFactory(HostFlasher, [], {})


class Flasher(metaclass=abc.ABCMeta):
    """An interface for flashing binaries and returning a transport factory."""

    @abc.abstractmethod
    def flash(self, micro_binary):
        """Flash a binary onto the device.

        Parameters
        ----------
        micro_binary : MicroBinary
            A MicroBinary instance.

        Returns
        -------
        transport.TransportContextManager :
            A ContextManager that can be used to create and tear down an RPC transport layer between
            this TVM instance and the newly-flashed binary.
        """
        raise NotImplementedError()


class FlasherFactory(class_factory.ClassFactory):
    """A ClassFactory for Flasher instances."""

    SUPERCLASS = Flasher


class HostFlasher(Flasher):
    """A Flasher implementation that spawns a subprocess on the host."""

    def __init__(self, debug=False):
        """Create a HostFlasher; when `debug` is true, the binary is run under a GDB wrapper."""
        self.debug = debug

    def flash(self, micro_binary):
        """Return a transport that runs the binary as a host subprocess (optionally under GDB)."""
        if self.debug:
            gdb_wrapper = debugger.GdbTransportDebugger(
                [micro_binary.abspath(micro_binary.binary_file)]
            )
            return transport.DebugWrapperTransport(
                debugger=gdb_wrapper, transport=gdb_wrapper.Transport()
            )

        return transport.SubprocessTransport([micro_binary.abspath(micro_binary.binary_file)])