1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Defines interfaces and default implementations for compiling and flashing code."""
19
20import abc
21import glob
22import os
23import re
24
25from tvm.contrib import binutil
26import tvm.target
27from . import build
28from . import class_factory
29from . import debugger
30from . import transport
31
32
class DetectTargetError(Exception):
    """Raised when exactly one TVM target could not be detected from the sources given.

    This covers both source sets with no target comment at all and source sets whose target
    comments disagree with each other.
    """
35
36
class NoDefaultToolchainMatchedError(Exception):
    """Raised when the target's CPU matches no default toolchain, or matches more than one."""
39
40
class Compiler(metaclass=abc.ABCMeta):
    """The compiler abstraction used with micro TVM."""

    # Matches a whole line of the form "// tvm target: <target string>"; group(1) captures the
    # target string (without the trailing newline).
    TVM_TARGET_RE = re.compile(r"^// tvm target: (.*)$")

    @classmethod
    def _target_from_sources(cls, sources):
        """Determine the target used to generate the given source files.

        Every line of every source file is matched against TVM_TARGET_RE. Across all files, exactly
        one distinct target string must be found.

        Parameters
        ----------
        sources : List[str]
            The paths to source files to analyze.

        Returns
        -------
        tvm.target.Target :
            A Target instance reconstructed from the target string listed in the source files.

        Raises
        ------
        DetectTargetError
            When no source contains a target comment, or when the sources name different targets.
        """
        target_strs = set()

        for src in sources:
            with open(src) as src_f:
                for line in src_f:
                    m = cls.TVM_TARGET_RE.match(line)
                    if m:
                        target_strs.add(m.group(1))

        if not target_strs:
            raise DetectTargetError(
                "autodetecting cross-compiler: could not extract TVM target from C source; regex "
                f"{cls.TVM_TARGET_RE.pattern} does not match any line in sources: "
                f'{", ".join(sources)}'
            )

        if len(target_strs) > 1:
            # Previously reported as "does not match any line", which hid the real problem.
            raise DetectTargetError(
                "autodetecting cross-compiler: sources specify more than one TVM target "
                f'({", ".join(sorted(target_strs))}) in sources: {", ".join(sources)}'
            )

        (target_str,) = target_strs
        return tvm.target.create(target_str)

    # Maps regexes identifying CPUs to the default toolchain prefix for that CPU.
    TOOLCHAIN_PREFIX_BY_CPU_REGEX = {
        r"cortex-[am].*": "arm-none-eabi-",
        "x86[_-]64": "",
        "native": "",
    }

    def _autodetect_toolchain_prefix(self, target):
        """Choose the default toolchain prefix for `target` from its ``mcpu`` attribute.

        Parameters
        ----------
        target : tvm.target.Target
            The target to inspect. A target with no (or an empty) ``mcpu`` attribute is treated as
            ``mcpu=native``, i.e. the host toolchain.

        Returns
        -------
        str :
            The prefix prepended to tool names such as ``gcc``, ``ar`` and ``ranlib``
            (for example ``"arm-none-eabi-"``, or ``""`` for the host toolchain).

        Raises
        ------
        NoDefaultToolchainMatchedError
            When ``mcpu`` matches none, or more than one, of TOOLCHAIN_PREFIX_BY_CPU_REGEX.
        """
        mcpu = target.attrs.get("mcpu")
        if not mcpu:
            # Mirrors _defaults_from_target, which also treats a missing mcpu as "no constraint";
            # indexing attrs["mcpu"] directly used to fail for targets such as plain "c".
            return self.TOOLCHAIN_PREFIX_BY_CPU_REGEX["native"]

        # re.match anchors at the start of mcpu only.
        matches = [
            prefix
            for regex, prefix in self.TOOLCHAIN_PREFIX_BY_CPU_REGEX.items()
            if re.match(regex, mcpu)
        ]

        if len(matches) > 1:
            raise NoDefaultToolchainMatchedError(
                f'mcpu {mcpu} matched more than 1 default toolchain prefix: {", ".join(matches)}. '
                "Specify cc.cross_compiler to create_micro_library()"
            )

        if matches:
            return matches[0]

        raise NoDefaultToolchainMatchedError(
            f"target {str(target)} did not match any default toolchains"
        )

    def _defaults_from_target(self, target):
        """Determine the default compiler options from the target specified.

        Parameters
        ----------
        target : tvm.target.Target
            The target whose ``mcpu`` and ``mfpu`` attributes are translated into flags.

        Returns
        -------
        List[str] :
            Default options used to configure the compiler for that target: ``-march=<mcpu>``
            and/or ``-mfpu=<mfpu>`` for each attribute that is set and non-empty.
        """
        opts = []
        # TODO use march for arm(https://gcc.gnu.org/onlinedocs/gcc/ARM-Options.html)?
        # NOTE(review): GCC's ARM backend expects CPU names such as "cortex-m4" via -mcpu, not
        # -march; confirm this flag is accepted before relying on it for ARM targets.
        if target.attrs.get("mcpu"):
            opts.append(f'-march={target.attrs["mcpu"]}')
        if target.attrs.get("mfpu"):
            opts.append(f'-mfpu={target.attrs["mfpu"]}')

        return opts

    @abc.abstractmethod
    def library(self, output, sources, options=None):
        """Build a library from the given source files.

        Parameters
        ----------
        output : str
            The path to the library that should be created. The containing directory
            is guaranteed to be empty and should be the base_dir for the returned
            Artifact.
        sources : List[str]
            A list of paths to source files that should be compiled.
        options : Optional[List[str]]
            If given, additional command-line flags to pass to the compiler.

        Returns
        -------
        MicroLibrary :
            The compiled library, as a MicroLibrary instance.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def binary(self, output, objects, options=None, link_main=True, main_options=None):
        """Link a binary from the given object and/or source files.

        Parameters
        ----------
        output : str
            The path to the binary that should be created. The containing directory
            is guaranteed to be empty and should be the base_dir for the returned
            Artifact.
        objects : List[MicroLibrary]
            A list of paths to source files or libraries that should be compiled. The final binary
            should be statically-linked.
        options: Optional[List[str]]
            If given, additional command-line flags to pass to the compiler.
        link_main: Optional[bool]
            True if the standard main entry point for this Compiler should be included in the
            binary. False if a main entry point is provided in one of `objects`.
        main_options: Optional[List[str]]
            If given, additional command-line flags to pass to the compiler when compiling the
            main() library. In some cases, the main() may be compiled directly into the final binary
            along with `objects` for logistical reasons. In those cases, specifying main_options is
            an error and ValueError will be raised.

        Returns
        -------
        MicroBinary :
            The compiled binary, as a MicroBinary instance.
        """
        raise NotImplementedError()

    @property
    def flasher_factory(self):
        """Produce a FlasherFactory for a Flasher instance suitable for this Compiler."""
        raise NotImplementedError("The Compiler base class doesn't define a flasher.")

    def flasher(self, **kw):
        """Return a Flasher that can be used to program a produced MicroBinary onto the target.

        Keyword arguments override those of `flasher_factory` before instantiation.
        """
        return self.flasher_factory.override_kw(**kw).instantiate()
187
188
class IncompatibleTargetError(Exception):
    """Signals that the target named in the sources disagrees with the compiler's own target."""
191
192
class DefaultCompiler(Compiler):
    """A Compiler implementation that attempts to use the system-installed GCC."""

    # Maps a source file extension to the GCC driver used to compile it. Any other extension
    # makes library() fail with a KeyError.
    _COMPILER_BY_EXTENSION = {".c": "gcc", ".cc": "g++", ".cpp": "g++"}

    def __init__(self, target=None):
        """Create a DefaultCompiler.

        Parameters
        ----------
        target : Optional[Union[str, tvm.target.Target]]
            The target to compile for. A string is converted with tvm.target.create(). When None,
            library() requires sources that name their target, and binary() cannot be used.
        """
        super().__init__()
        self.target = target
        if isinstance(target, str):
            self.target = tvm.target.create(target)

    def library(self, output, sources, options=None):
        """Compile `sources` to object files and archive them into ``<basename(output)>.a``.

        Parameters
        ----------
        output : str
            Existing directory that receives one object file per source plus the archive.
        sources : List[str]
            Paths to .c, .cc or .cpp files.
        options : Optional[Dict[str, List[str]]]
            Recognized keys: ``cflags``, ``ccflags``, ``cppflags`` (flags for sources with the
            matching extension) and ``include_dirs``.

        Returns
        -------
        MicroLibrary :
            A MicroLibrary rooted at `output` listing the archive.

        Raises
        ------
        IncompatibleTargetError
            When the target detected from `sources` differs from the configured target.
        """
        options = options if options is not None else {}
        try:
            target = self._target_from_sources(sources)
        except DetectTargetError:
            assert self.target is not None, (
                "Must specify target= to constructor when compiling sources which don't specify a "
                "target"
            )

            target = self.target

        if self.target is not None and str(self.target) != str(target):
            raise IncompatibleTargetError(
                f"auto-detected target {target} differs from configured {self.target}"
            )

        prefix = self._autodetect_toolchain_prefix(target)

        # Flags that are identical for every translation unit; computed once, not per source.
        target_args = list(self._defaults_from_target(target))
        include_args = []
        for include_dir in options.get("include_dirs", []):
            include_args.extend(["-I", include_dir])

        outputs = []
        used_filenames = set()
        for src in sources:
            src_base, src_ext = os.path.splitext(os.path.basename(src))

            compiler_name = self._COMPILER_BY_EXTENSION[src_ext]
            args = [prefix + compiler_name, "-g"]
            args.extend(target_args)
            args.extend(options.get(f"{src_ext[1:]}flags", []))
            args.extend(include_args)

            # Sources like a/foo.c and b/foo.cc would both become foo.o and overwrite each other
            # in `output`, silently dropping one from the archive; add a numeric suffix instead.
            output_filename = f"{src_base}.o"
            suffix = 1
            while output_filename in used_filenames:
                output_filename = f"{src_base}_{suffix}.o"
                suffix += 1
            used_filenames.add(output_filename)

            output_abspath = os.path.join(output, output_filename)
            binutil.run_cmd(args + ["-c", "-o", output_abspath, src])
            outputs.append(output_abspath)

        output_filename = f"{os.path.basename(output)}.a"
        output_abspath = os.path.join(output, output_filename)
        binutil.run_cmd([prefix + "ar", "-r", output_abspath] + outputs)
        binutil.run_cmd([prefix + "ranlib", output_abspath])

        return tvm.micro.MicroLibrary(output, [output_filename])

    def binary(self, output, objects, options=None, link_main=True, main_options=None):
        """Link `objects` (and, optionally, the CRT host main()) into ``output/<basename(output)>``.

        Parameters
        ----------
        output : str
            Existing directory that receives the linked executable.
        objects : List[MicroLibrary]
            Libraries whose ``library_files`` are passed to the linker.
        options : Optional[Dict[str, List[str]]]
            Recognized keys: ``ldflags`` and ``include_dirs``.
        link_main : bool
            When True, link in the sources found under ``build.CRT_ROOT_DIR/host``.
        main_options : Optional[Dict[str, List[str]]]
            When given, those host sources are first built into a library under ``output/host``
            with these options; otherwise they are compiled directly into the binary.

        Returns
        -------
        MicroBinary :
            A MicroBinary rooted at `output`.
        """
        # NOTE(review): library() does not record an auto-detected target in self.target, so the
        # "compile sources ... first" hint below does not currently enable binary().
        assert self.target is not None, (
            "must specify target= to constructor, or compile sources which specify the target "
            "first"
        )

        args = [self._autodetect_toolchain_prefix(self.target) + "g++"]
        args.extend(self._defaults_from_target(self.target))
        if options is not None:
            args.extend(options.get("ldflags", []))

            for include_dir in options.get("include_dirs", []):
                args.extend(["-I", include_dir])

        output_filename = os.path.basename(output)
        output_abspath = os.path.join(output, output_filename)
        args.extend(["-g", "-o", output_abspath])

        if link_main:
            # Sorted so the link command line does not depend on filesystem listing order.
            host_main_srcs = sorted(glob.glob(os.path.join(build.CRT_ROOT_DIR, "host", "*.cc")))
            if main_options:
                main_lib_dir = os.path.join(output, "host")
                # library() writes straight into its output directory and does not create it.
                os.makedirs(main_lib_dir, exist_ok=True)
                main_lib = self.library(main_lib_dir, host_main_srcs, main_options)
                args.extend(main_lib.abspath(lib_name) for lib_name in main_lib.library_files)
            else:
                # Compile the host main() sources directly into the final binary.
                args.extend(host_main_srcs)

        for obj in objects:
            args.extend(obj.abspath(lib_name) for lib_name in obj.library_files)

        binutil.run_cmd(args)
        return tvm.micro.MicroBinary(output, output_filename, [])

    @property
    def flasher_factory(self):
        """Return a factory producing a non-debug HostFlasher with default arguments."""
        return FlasherFactory(HostFlasher, [], {})
282
283
class Flasher(metaclass=abc.ABCMeta):
    """Abstract interface: put a MicroBinary onto a device and hand back a transport factory."""

    @abc.abstractmethod
    def flash(self, micro_binary):
        """Program `micro_binary` onto the device.

        Parameters
        ----------
        micro_binary : MicroBinary
            The binary to program.

        Returns
        -------
        transport.TransportContextManager :
            Context manager that sets up, and later tears down, an RPC transport between this
            TVM instance and the freshly programmed binary.
        """
        raise NotImplementedError
303
304
class FlasherFactory(class_factory.ClassFactory):
    """ClassFactory specialization whose SUPERCLASS is Flasher."""

    # NOTE(review): presumably consulted by ClassFactory to check the class it is asked to
    # build; confirm in class_factory.
    SUPERCLASS = Flasher
309
310
class HostFlasher(Flasher):
    """Flasher that "flashes" by launching the binary as a subprocess on this machine."""

    def __init__(self, debug=False):
        # When True, flash() starts the binary under debugger.GdbTransportDebugger.
        self.debug = debug

    def flash(self, micro_binary):
        """Launch `micro_binary` locally and return the transport connected to it."""
        binary_path = micro_binary.abspath(micro_binary.binary_file)

        if not self.debug:
            return transport.SubprocessTransport([binary_path])

        gdb_wrapper = debugger.GdbTransportDebugger([binary_path])
        return transport.DebugWrapperTransport(
            debugger=gdb_wrapper, transport=gdb_wrapper.Transport()
        )
327