1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Container of compiled functions of TVM."""
18from __future__ import absolute_import as _abs
19
20import struct
21from collections import namedtuple
22
23from ._ffi.function import ModuleBase, _set_class_module
24from ._ffi.function import _init_api
25from ._ffi.libinfo import find_include_path
26from .contrib import cc as _cc, tar as _tar, util as _util
27
# Result record returned by the evaluator built in Module.time_evaluator:
#   mean    - average of the per-repeat costs
#   results - tuple holding one cost per repeat
ProfileResult = namedtuple("ProfileResult", ("mean", "results"))
29
30
class Module(ModuleBase):
    """Module container of all TVM generated functions."""

    def __repr__(self):
        """Return a debug string with the type key and the handle value in hex."""
        return "Module(%s, %x)" % (self.type_key, self.handle.value)

    @property
    def type_key(self):
        """Get type key of the module."""
        return _GetTypeKey(self)

    def get_source(self, fmt=""):
        """Get source code from module, if available.

        Parameters
        ----------
        fmt : str, optional
            The specified format.

        Returns
        -------
        source : str
            The result source code.
        """
        return _GetSource(self, fmt)

    @property
    def imported_modules(self):
        """Get imported modules.

        Returns
        -------
        modules : list of Module
            The modules imported by this module, in import order.
        """
        nmod = _ImportsSize(self)
        return [_GetImport(self, i) for i in range(nmod)]

    def save(self, file_name, fmt=""):
        """Save the module to file.

        This does not save the dependent device modules.

        Parameters
        ----------
        file_name : str
            The name of the file.
        fmt : str
            The format of the file.

        See Also
        --------
        Module.export_library : export the module to shared library.
        """
        _SaveToFile(self, file_name, fmt)

    def export_library(self,
                       file_name,
                       fcompile=None,
                       **kwargs):
        """Export the module and its imported device code into one library.

        This function only works on host llvm and c modules (a stackvm
        module is saved directly when ``file_name`` ends with ".stackvm").
        It will pack all the imported modules.

        Parameters
        ----------
        file_name : str or pathlib.Path
            The name of the shared library.

        fcompile : function(target, file_list, kwargs), optional
            Compilation function to use create dynamic library.
            If fcompile has attribute object_format, will compile host library
            to that format. Otherwise, will use default format "o" for llvm
            modules and "cc" for c modules.
            When omitted, ``tar.tar`` is used if ``file_name`` ends with
            ".tar" and ``cc.create_shared`` otherwise.

        kwargs : dict, optional
            Additional arguments passed to fcompile.
            For c modules the ``options`` entry (a single option, a list or a
            tuple) is extended with ``-I`` flags for every path returned by
            ``find_include_path``.

        Raises
        ------
        ValueError
            If a stackvm module is exported to a file that does not end with
            ".stackvm", or if the module type is neither llvm nor c.
        """
        from pathlib import Path
        if isinstance(file_name, Path):
            file_name = str(file_name)

        # type_key is a property backed by _GetTypeKey; read it once.
        type_key = self.type_key

        if type_key == "stackvm":
            if not file_name.endswith(".stackvm"):
                raise ValueError("Module[%s]: can only be saved as stackvm format. "
                                 "did you build with LLVM enabled?" % type_key)
            self.save(file_name)
            return

        if type_key not in ("llvm", "c"):
            raise ValueError("Module[%s]: Only llvm and c support export shared" % type_key)

        temp = _util.tempdir()
        if fcompile is not None and hasattr(fcompile, "object_format"):
            object_format = fcompile.object_format
        elif type_key == "llvm":
            object_format = "o"
        else:
            # Only "c" can reach this branch because of the check above.
            object_format = "cc"

        path_obj = temp.relpath("lib." + object_format)
        self.save(path_obj)
        files = [path_obj]

        is_system_lib = type_key == "llvm" and self.get_function("__tvm_is_system_module")()
        if self.imported_modules:
            # Imported modules are packed into a generated C source file that
            # is compiled together with the host object.
            path_cc = temp.relpath("devc.cc")
            with open(path_cc, "w") as f:
                f.write(_PackImportsToC(self, is_system_lib))
            files.append(path_cc)

        if not fcompile:
            if file_name.endswith(".tar"):
                fcompile = _tar.tar
            else:
                fcompile = _cc.create_shared

        if type_key == "c":
            user_opts = kwargs.get("options")
            if user_opts is None:
                # No options given (or an explicit None): start from scratch.
                options = []
            elif isinstance(user_opts, (list, tuple)):
                # Copy into a list so a tuple can be concatenated below and
                # the caller's sequence is never mutated.
                options = list(user_opts)
            else:
                options = [user_opts]
            kwargs["options"] = options + ["-I" + path for path in find_include_path()]

        fcompile(file_name, files, **kwargs)

    def time_evaluator(self, func_name, ctx, number=10, repeat=1, min_repeat_ms=0):
        """Get an evaluator that measures time cost of running function.

        Parameters
        ----------
        func_name: str
            The name of the function in the module.

        ctx: TVMContext
            The context we should run this function on.

        number: int
            The number of times to run this function for taking average.
            We call these runs as one `repeat` of measurement.

        repeat: int, optional
            The number of times to repeat the measurement.
            In total, the function will be invoked (1 + number x repeat) times,
            where the first one is warm up and will be discarded.
            The returned result contains `repeat` costs,
            each of which is an average of `number` costs.

        min_repeat_ms: int, optional
            The minimum duration of one `repeat` in milliseconds.
            By default, one `repeat` contains `number` runs. If this parameter is set,
            the parameters `number` will be dynamically adjusted to meet the
            minimum duration requirement of one `repeat`.
            i.e., When the run time of one `repeat` falls below this time, the `number` parameter
            will be automatically increased.

        Note
        ----
        The function will be invoked  (1 + number x repeat) times,
        with the first call discarded in case there is lazy initialization.

        Returns
        -------
        ftimer : Function
            The function that takes same argument as func and returns a ProfileResult.
            The ProfileResult reports `repeat` time costs in seconds.

        Raises
        ------
        NameError
            If the RPC time evaluator is not available in this module's
            namespace.
        """
        # Only the name lookup is guarded, so a NameError raised for any other
        # reason is not misreported as "RPC not enabled".
        try:
            make_evaluator = _RPCTimeEvaluator
        except NameError:
            raise NameError("time_evaluator is only supported when RPC is enabled")

        feval = make_evaluator(
            self, func_name, ctx.device_type, ctx.device_id, number, repeat, min_repeat_ms)

        def evaluator(*args):
            """Internal wrapped evaluator."""
            # Wrap feval so we can add more stats in future.
            blob = feval(*args)
            # The blob is unpacked as `repeat` native doubles.
            fmt = "@" + ("d" * repeat)
            results = struct.unpack(fmt, blob)
            mean = sum(results) / float(repeat)
            return ProfileResult(mean=mean, results=results)

        return evaluator
212
213
def system_lib():
    """Return the system-wide library module singleton.

    The system lib is a global module whose functions register themselves
    at startup, in contrast to normal dso modules, which have to be loaded
    explicitly. This makes it useful in environments where dynamic loading
    APIs such as dlopen are banned.

    To build a system lib function, specify the target option
    ```llvm --system-lib```. The system lib is available as long as the
    resulting code is linked into the program.

    The system lib is meant to stay linked and loaded for the entire
    life-cycle of the program. Use dso modules instead if dynamic loading
    features are needed.

    Returns
    -------
    module : Module
        The system-wide library module.
    """
    module = _GetSystemLib()
    return module
233
234
def load(path, fmt=""):
    """Load module from file.

    Parameters
    ----------
    path : str
        The path to the module file.

    fmt : str, optional
        The format of the file, if not specified
        it will be inferred from suffix of the file.

    Returns
    -------
    module : Module
        The loaded module

    Note
    ----
    This function will automatically call
    cc.create_shared if the path is in format .o or .tar.
    The shared library is written next to the input as ``path + ".so"``;
    a .tar archive is first unpacked into the directory named like the
    archive without its ".tar" suffix.
    """
    # High level handling for .o and .tar file.
    # We support this to be consistent with RPC module load.
    if path.endswith(".o"):
        _cc.create_shared(path + ".so", path)
        path += ".so"
    elif path.endswith(".tar"):
        # Strip only the trailing ".tar": str.replace would also remove any
        # earlier ".tar" inside the path (e.g. "models.tar.d/net.tar") and
        # unpack into an unrelated directory.
        unpack_dir = path[:-len(".tar")]
        tar_temp = _util.tempdir(custom_path=unpack_dir)
        _tar.untar(path, tar_temp.temp_dir)
        files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
        _cc.create_shared(path + ".so", files)
        path += ".so"
    # Redirect to the load API
    return _LoadFromFile(path, fmt)
270
271
def enabled(target):
    """Check whether the module runtime is enabled for a target.

    Parameters
    ----------
    target : str
        The target device type.

    Returns
    -------
    enabled : bool
        Whether runtime is enabled.

    Examples
    --------
    The following code checks if gpu is enabled.

    >>> tvm.module.enabled("gpu")
    """
    is_enabled = _Enabled(target)
    return is_enabled
292
293
# Module-level registration; keep these two calls in this order.
# NOTE(review): the underscore-prefixed functions used in this file
# (_GetTypeKey, _SaveToFile, _LoadFromFile, ...) are neither defined nor
# imported here, so presumably _init_api exposes them from the "tvm.module"
# API namespace - confirm in _ffi.function.
_init_api("tvm.module")
# Hand the Module class to the FFI layer (presumably so module handles coming
# back from the FFI are wrapped as Module - confirm in _ffi.function).
_set_class_module(Module)
296