1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""The build utils in python.
18
19This module provides the functions to transform schedule to
20LoweredFunc and compiled Module.
21"""
22from __future__ import absolute_import as _abs
23import warnings
24
25from ._ffi.function import Function
26from ._ffi.node import NodeBase, register_node
27from . import api
28from . import _api_internal
29from . import tensor
30from . import schedule
31from . import expr
32from . import ir_pass
33from . import stmt as _stmt
34from . import container
35from . import module
36from . import codegen
37from . import ndarray
38from . import target as _target
39from . import make
40
class DumpIR(object):
    """
    Dump IR for each pass.
    With it, you can dump ir just like gcc/llvm.

    While a dump scope is open, ``schedule.ScheduleOps`` and every
    ``Function`` found in ``ir_pass`` are wrapped so that each result of type
    ``Stmt``, ``LoweredFunc`` or ``Array`` is appended to a file named
    ``<pass_id>_<pass_name>_ir.cc`` (a relative path).

    How to use:
    -----------
    .. code-block:: python

        with tvm.build_config(dump_pass_ir=True):
            run()
    """
    # Number of dump scopes currently open; shared by every instance.
    scope_level = 0

    def __init__(self):
        self._pass_id = 0
        self._recover_list = []
        self._old_sgpass = None

    def decorate(self, func):
        """Wrap a pass function so that its result is dumped to a file.

        Parameters
        ----------
        func : callable
            The pass function to wrap.

        Returns
        -------
        dump : callable
            A function that forwards its arguments to ``func`` and returns
            the result of ``func`` unchanged.
        """
        def dump(*args, **kwargs):
            """dump function"""
            retv = func(*args, **kwargs)
            # Only IR-like results are dumped; anything else passes through.
            if not isinstance(retv, (_stmt.Stmt, container.LoweredFunc, container.Array)):
                return retv
            # func_name only exists on python 2 functions.
            fname = func.func_name if hasattr(func, 'func_name') else func.__name__
            pname = str(self._pass_id) + "_" + fname + "_ir.cc"
            with open(pname, "a") as f:
                out = retv.body if isinstance(retv, container.LoweredFunc) else retv
                f.write(str(out))
                if isinstance(retv, container.Array):
                    for x in retv:
                        out = x.body if isinstance(x, container.LoweredFunc) else x
                        f.write("---------%s\n%s\n-----------\n"%(x.name, str(out)))
                self._pass_id += 1
            return retv
        return dump

    def decorate_irpass(self):
        """Decorate ScheduleOps and every Function found in ir_pass.

        One restore callback per replaced ir_pass entry is appended to
        ``self._recover_list`` so that ``exit`` can undo the decoration.
        """
        self._old_sgpass = schedule.ScheduleOps
        schedule.ScheduleOps = self.decorate(schedule.ScheduleOps)
        vset = vars(ir_pass)

        def make_recover(key, original):
            """Bind key/original now: a closure over the loop variables
            would only ever observe their final values."""
            def recover():
                vset[key] = original
            return recover

        # Snapshot the items because the mapping is written to in the loop.
        for key, value in list(vset.items()):
            if isinstance(value, Function):
                self._recover_list.append(make_recover(key, value))
                vset[key] = self.decorate(value)

    def decorate_custompass(self, custom_pass):
        """Decorate the given custom passes and return the decorated list.

        Parameters
        ----------
        custom_pass : list of tuple (phase, function) or None
            The custom passes. Each function is renamed in place to
            ``custom<idx>_phase<phase>``; that name is used for the dump
            file of the pass.

        Returns
        -------
        pass_list : list of tuple (phase, function)
            The same phases paired with the decorated functions. Empty when
            ``custom_pass`` is None or empty.
        """
        custom_pass = custom_pass if custom_pass else []
        pass_list = []
        for idx, x in enumerate(custom_pass):
            x[1].__name__ = "custom{}_phase{}".format(idx, x[0])
            pass_list.append((x[0], self.decorate(x[1])))
        return pass_list

    def enter(self):
        """Open a dump scope; only the outermost scope decorates the passes."""
        if DumpIR.scope_level == 0:
            self.decorate_irpass()
            self._pass_id = 0
        # Count every scope so that a nested exit does not undo the
        # decoration while an outer scope is still open.
        DumpIR.scope_level += 1

    def exit(self):
        """Close a dump scope; only the outermost scope restores the passes."""
        if DumpIR.scope_level <= 0:
            # Unbalanced exit: nothing is decorated, nothing to restore.
            return
        DumpIR.scope_level -= 1
        if DumpIR.scope_level > 0:
            return
        # recover decorated functions
        for f in self._recover_list:
            f()
        # Drop the callbacks so the next outermost scope starts clean.
        del self._recover_list[:]
        schedule.ScheduleOps = self._old_sgpass
116
117
@register_node
class BuildConfig(NodeBase):
    """Scoped container for build configuration options.

    Note
    ----
    Instances are backed by the node system in C++, so their values can be
    exchanged between python and C++.

    Use build_config to create an instance; do not construct one directly.

    Every field listed in _node_defaults is backed by the C++ node and
    cannot be reassigned once the instance has been constructed.
    """

    _node_defaults = {
        "auto_unroll_max_step": 0,
        "auto_unroll_max_depth": 8,
        "auto_unroll_max_extent": 0,
        "unroll_explicit": True,
        "detect_global_barrier": False,
        "partition_const_loop": False,
        "offset_factor": 0,
        "data_alignment": -1,
        "restricted_func": True,
        "double_buffer_split_loop": 1,
        "dump_pass_ir": False,
        "instrument_bound_checkers": False,
        "disable_select_rewriting": False,
        "disable_vectorize": False,
        "disable_assert": False
    }
    # Single IR dumper shared by every configuration.
    _dump_ir = DumpIR()

    # pylint: disable=no-member
    def __init__(self, handle):
        """Wrap an existing node handle.

        Parameters
        ----------
        handle : SymbolHandle
            The handle passed on to NodeBase and stored on the instance.
        """
        super(BuildConfig, self).__init__(handle)
        self.handle = handle

    @property
    def add_lower_pass(self):
        """The custom lowering passes as a list of (phase, function) tuples."""
        query = _api_internal._BuildConfigGetAddLowerPassInfo
        # Called with only the config the query returns the number of
        # passes; with an index and a flag it returns the phase (True) or
        # the function (False) of that pass.
        return [(query(self, idx, True), query(self, idx, False))
                for idx in range(query(self))]

    @add_lower_pass.setter
    def add_lower_pass(self, value):
        # Flatten [(phase, func), ...] into phase, func, phase, func, ...
        flattened = []
        for entry in value:
            flattened.extend((entry[0], entry[1]))
        _api_internal._BuildConfigSetAddLowerPass(self, *flattened)

    def __enter__(self):
        """Enter the configuration scope, opening an IR dump scope if asked."""
        # pylint: disable=protected-access
        _api_internal._EnterBuildConfigScope(self)
        if self.dump_pass_ir:
            BuildConfig._dump_ir.enter()
        return self

    def __exit__(self, ptype, value, trace):
        """Leave the scope, closing the IR dump scope before the config scope."""
        if self.dump_pass_ir:
            BuildConfig._dump_ir.exit()
        _api_internal._ExitBuildConfigScope(self)

    def __setattr__(self, name, value):
        """Set an attribute, refusing the fields listed in _node_defaults."""
        if name not in BuildConfig._node_defaults:
            return super(BuildConfig, self).__setattr__(name, value)
        raise AttributeError(
            "'%s' object cannot set attribute '%s'" % (str(type(self)), name))
198
199
def current_build_config():
    """Return the build configuration reported by _GetCurrentBuildConfig."""
    active_config = _api_internal._GetCurrentBuildConfig()
    return active_config
203
204
def build_config(**kwargs):
    """Configure the build behavior by setting config variables.

    Options that are not given keep the value from
    ``BuildConfig._node_defaults``. Option names that are not recognized are
    ignored, and a warning listing them is issued.

    Parameters
    ----------
    auto_unroll_max_step: int, default=0
        Threshold of number of steps in the loop to be automatically unrolled.
        This takes inner loop count into consideration.

    auto_unroll_max_depth: int, default=8
        The maximum nested level of loops that can be automatically unrolled.

    auto_unroll_max_extent: int, default=0
        Forwarded to the UnrollLoop pass together with the other auto unroll
        limits.

    unroll_explicit: bool, default=True
        Whether explicitly unroll the loop, if set false, the unroll hint will
        be passed to the CodeGen phase, which may generate pragma unroll hint.
        Set this to be true if CodeGen support unroll pragma and
        when we want to be more readable.

    detect_global_barrier: bool, default=False
        Whether detect global barrier.

    partition_const_loop: bool, default=False
        Whether partition const loop

    data_alignment: int, default=-1
        The alignment of data pointer in bytes.
        If -1 is passed, the alignment will be set to TVM's internal default.

    offset_factor: int, default=0
        The factor used in default buffer declaration.
        If specified as 0, offset field is not used.

    restricted_func: bool, default=True
        Whether build restricted function.
        That is each buffer argument to the function are guaranteed
        not to overlap. This enables more optimization.
        Corresponds to restricted keyword in C99

    double_buffer_split_loop: int, default=1
        Whether split the loop with factor. If it is zero, no splitting will happen.
        If it is bigger than one, the logic will do a split with factor equals the integer
        and unroll the inner loop. This allows the buffer fetching won't contain condition.

    instrument_bound_checkers: bool, default=False
        When set, lowering forwards the flag to StorageFlatten and runs the
        InstrumentBoundCheckers pass.

    disable_select_rewriting: bool, default=False
        When set, lowering skips the RewriteUnsafeSelect pass.

    disable_vectorize: bool, default=False
        When set, lowering runs SkipVectorize instead of VectorizeLoop.

    disable_assert: bool, default=False
        Stored on the config node; it is not read by the python lowering code
        in this module (presumably consumed elsewhere -- confirm before
        relying on it).

    add_lower_pass: list of tuple (phase, function(Stmt->Stmt)), default=None
        phase contains an integer on which optimization pass we apply the pass.
        Additional lowering passes to be applied before make_api.

    dump_pass_ir: dump ir of each pass into file idx_passname_ir.cc, default=False

    Returns
    -------
    config: BuildConfig
        The build configuration
    """
    defaults = BuildConfig._node_defaults
    # add_lower_pass is not a node field; it is applied through its setter.
    unknown = sorted(key for key in kwargs
                     if key not in defaults and key != "add_lower_pass")
    if unknown:
        # A misspelled option used to be dropped without any feedback.
        warnings.warn(
            "build_config ignores unrecognized option(s): %s" % ", ".join(unknown))

    node_args = {key: kwargs.get(key, default) for key, default in defaults.items()}
    config = make.node("BuildConfig", **node_args)

    if "add_lower_pass" in kwargs:
        config.add_lower_pass = kwargs["add_lower_pass"]

    return config
267
def get_binds(args, compact=False, binds=None):
    """Internal helper that derives the binds and the buffer argument list.

    Parameters
    ----------
    args : list of Buffer or Tensor or Var
        The arguments of the function.

    compact : bool
        If the statement has already bound to a compact buffer.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. The dictionary is copied, not modified.
        A new buffer is declared for every tensor argument that has no
        entry in it.

    Returns
    -------
    binds: dict
        The bind specification, including the newly declared buffers.

    arg_list: list
        One entry per argument: the bound buffer for a Tensor, the argument
        itself for a Buffer or Var.

    Raises
    ------
    ValueError
        If an argument is neither a Tensor, a Buffer nor a Var.
    """
    bind_map = binds.copy() if binds is not None else {}
    config = current_build_config()
    buffers = []
    for arg in args:
        if isinstance(arg, tensor.Tensor):
            symbolic = any(isinstance(dim, expr.Var) for dim in arg.shape)
            # auto_broadcast is only requested for a symbolic shape that is
            # not already compact.
            kind = "" if compact or not symbolic else "auto_broadcast"
            if arg not in bind_map:
                bind_map[arg] = api.decl_buffer(arg.shape,
                                                dtype=arg.dtype,
                                                name=arg.name,
                                                data_alignment=config.data_alignment,
                                                offset_factor=config.offset_factor,
                                                buffer_type=kind)
            buffers.append(bind_map[arg])
        elif isinstance(arg, (schedule.Buffer, expr.Var)):
            # Buffers and scalar variables are passed through untouched.
            buffers.append(arg)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    return bind_map, buffers
317
318
def form_body(sch):
    """Form the raw statement body described by a schedule.

    Parameters
    ----------
    sch : tvm.schedule.Schedule
        The schedule from which the raw body is formed.

    Returns
    -------
    stmt : Stmt
        The result of ScheduleOps on the normalized schedule, after
        InjectPrefetch has been applied.
    """
    # The schedule has to be normalized before the bounds are inferred.
    normalized = sch.normalize()
    inferred_bounds = schedule.InferBound(normalized)
    raw_body = schedule.ScheduleOps(normalized, inferred_bounds)
    return ir_pass.InjectPrefetch(raw_body)
336
337
def lower(sch,
          args,
          name="default_function",
          binds=None,
          simple_mode=False):
    """Lowering step before build into target.

    Custom passes from the current build config's ``add_lower_pass`` are run
    at the end of their phase: phase 0, 1 and 2 after the built-in passes of
    that phase, and every pass with a phase greater than 2 in phase 3.

    Parameters
    ----------
    sch : tvm.schedule.Schedule
        The schedule to be built

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str, optional
        The name of result function.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.

    simple_mode : bool, optional
        Whether only output simple and compact statement, this will skip
        LoopPartition and the api wrapper generation (MakeAPI).

    Returns
    -------
    f : LoweredFunc or Stmt
       The result function. If simple_mode is True, the Stmt before
       make api is returned instead.

    Raises
    ------
    ValueError
        If ``sch`` is not a Schedule.
    """
    # Without this check a non-Schedule input left `stmt` unbound and failed
    # further down with an UnboundLocalError.
    if not isinstance(sch, schedule.Schedule):
        raise ValueError(
            "sch must be a tvm.schedule.Schedule, but got %s" % type(sch))

    cfg = current_build_config()
    # Read the property once; every access queries the config node again.
    add_lower_pass = cfg.add_lower_pass or []
    if cfg.dump_pass_ir:
        add_lower_pass = BuildConfig._dump_ir.decorate_custompass(add_lower_pass)
    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]

    # Phase 0
    stmt = form_body(sch)

    for f in lower_phase0:
        stmt = f(stmt)

    compact = ir_pass.VerifyCompactBuffer(stmt)
    binds, arg_list = get_binds(args, compact, binds)

    # Phase 1
    stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
    stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
    stmt = ir_pass.CanonicalSimplify(stmt)
    for f in lower_phase1:
        stmt = f(stmt)

    # Phase 2
    if not simple_mode:
        stmt = ir_pass.LoopPartition(stmt, cfg.partition_const_loop)
    if cfg.disable_vectorize:
        stmt = ir_pass.SkipVectorize(stmt)
    else:
        stmt = ir_pass.VectorizeLoop(stmt)
    stmt = ir_pass.InjectVirtualThread(stmt)
    stmt = ir_pass.InjectDoubleBuffer(stmt, cfg.double_buffer_split_loop)
    stmt = ir_pass.StorageRewrite(stmt)
    stmt = ir_pass.UnrollLoop(
        stmt,
        cfg.auto_unroll_max_step,
        cfg.auto_unroll_max_depth,
        cfg.auto_unroll_max_extent,
        cfg.unroll_explicit)
    for f in lower_phase2:
        stmt = f(stmt)

    # Phase 3
    stmt = ir_pass.Simplify(stmt)
    stmt = ir_pass.RemoveNoOp(stmt)
    if not cfg.disable_select_rewriting:
        stmt = ir_pass.RewriteUnsafeSelect(stmt)
    for f in lower_phase3:
        stmt = f(stmt)
    # Instrument BoundCheckers
    if cfg.instrument_bound_checkers:
        stmt = ir_pass.InstrumentBoundCheckers(stmt)
    if simple_mode:
        return stmt

    return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
430
431
def _build_for_device(flist, target, target_host):
    """Build the lowered functions for one device with the given compilation
    target.

    Parameters
    ----------
    flist : list of LoweredFunc
        The lowered functions to be built.

    target : str or :any:`tvm.target.Target`
        The target and option of the compilation.

    target_host : str or :any:`tvm.target.Target`
        The host compilation target.

    Returns
    -------
    fhost : list of LoweredFunc
        A list of lowered functions for the host.

    mdev : tvm.module
        A module that contains device code, or None when there is no
        device function.

    Raises
    ------
    ValueError
        If VerifyMemory rejects a function or a function has an unknown
        function type.
    """
    target = _target.create(target)
    device_type = ndarray.context(target.target_name, 0).device_type
    host_funcs = []
    device_funcs = []
    for func in flist:
        if not ir_pass.VerifyMemory(func, device_type):
            raise ValueError(
                "Direct host side access to device memory is detected in %s. "
                "Did you forget to bind?" % func.name)
        kind = func.func_type
        if kind == container.LoweredFunc.MixedFunc:
            # A mixed function is synchronized and lowered first, then split
            # into its host part (first entry) and its device parts (rest).
            if current_build_config().detect_global_barrier:
                func = ir_pass.ThreadSync(func, "global")
            for scope in ("shared", "warp"):
                func = ir_pass.ThreadSync(func, scope)
            func = ir_pass.InferFragment(func)
            func = ir_pass.LowerThreadAllreduce(func, target.thread_warp_size)
            pieces = list(ir_pass.SplitHostDevice(func))
            host_funcs.append(pieces[0])
            device_funcs.extend(pieces[1:])
        elif kind == container.LoweredFunc.HostFunc:
            host_funcs.append(func)
        elif kind == container.LoweredFunc.DeviceFunc:
            device_funcs.append(func)
        else:
            raise ValueError("unknown function type %d" % func.func_type)

    device_funcs = [ir_pass.LowerWarpMemory(dev_func, target.thread_warp_size)
                    for dev_func in device_funcs]

    if "gpu" in target.keys and not device_funcs:
        warnings.warn(
            "Specified target %s, but cannot find device code, did you do "
            "bind?" % target)

    host_funcs = [ir_pass.BindDeviceType(host_func, device_type)
                  for host_func in host_funcs]
    host_funcs = [ir_pass.LowerTVMBuiltin(host_func) for host_func in host_funcs]

    # target_host is compared before it is converted by _target.create.
    on_cpu = device_type == ndarray.cpu(0).device_type
    if on_cpu and target_host == target:
        assert not device_funcs

    target_host = _target.create(target_host)
    device_funcs = [ir_pass.LowerDeviceStorageAccessInfo(dev_func)
                    for dev_func in device_funcs]
    host_funcs = [ir_pass.LowerDeviceStorageAccessInfo(host_func)
                  for host_func in host_funcs]
    device_funcs = [ir_pass.LowerIntrin(dev_func, target.target_name)
                    for dev_func in device_funcs]
    host_funcs = [ir_pass.LowerIntrin(host_func, target_host.target_name)
                  for host_func in host_funcs]
    host_funcs = [ir_pass.CombineContextCall(host_func) for host_func in host_funcs]

    device_module = None
    if device_funcs:
        device_module = codegen.build_module(device_funcs, str(target))

    return host_funcs, device_module
507
508
def build(inputs,
          args=None,
          target=None,
          target_host=None,
          name="default_function",
          binds=None):
    """Build a function with arguments as signature. Code will be generated
    for devices coupled with target information.

    Parameters
    ----------
    inputs : tvm.Schedule, LoweredFunc, list of LoweredFunc, or dict of target to LoweredFunc list
        The schedule or the lowered functions to be built

    args : list of Buffer or Tensor or Var, optional
        The argument lists to the function. Required when building from a
        schedule; must not be given when building from a LoweredFunc.

    target : str or :any:`tvm.target.Target`, optional
        The target and option of the compilation. When it is not given the
        current target is used, falling back to "llvm". It is only used
        when inputs is not a dict.

    target_host : str or :any:`tvm.target.Target` optional
        Host compilation target, if target is device.
        When TVM compiles device specific program such as CUDA,
        we also need host(CPU) side code to interact with the driver
        setup the dimensions and parameters correctly.
        target_host is used to specify the host side codegen target.
        By default, the first target that maps to the cpu device type is
        used; if there is none, llvm is used if it is enabled,
        otherwise a stackvm interpreter is used.

    name : str, optional
        The name of result function.

    binds : dict, optional
        Dictionary that maps the binding of symbolic buffer to Tensor.
        By default, a new buffer is created for each tensor in the argument.

    Returns
    -------
    ret : tvm.module
        A module that combines both host and device code.

    Raises
    ------
    ValueError
        If inputs or its keys have an unsupported type, if args is missing
        for a schedule or given for a LoweredFunc, or if two functions for
        the same target share a name.

    Examples
    --------
    There are two typical example uses of this function depending on the type
    of the argument `inputs`:
    1. it is a list of lowered functions:

    .. code-block:: python

        n = 2
        A = tvm.placeholder((n,), name='A')
        B = tvm.placeholder((n,), name='B')
        C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
        s = tvm.create_schedule(C.op)
        f = tvm.lower(s, [A, B, C], name="test_add")
        m = tvm.build(f, target="llvm")

    2. it is a dict of compilation target to list of lowered functions:

    .. code-block:: python

        n = 2
        A = tvm.placeholder((n,), name='A')
        B = tvm.placeholder((n,), name='B')
        C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
        s1 = tvm.create_schedule(C.op)
        with tvm.target.cuda() as cuda_tgt:
          s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
          f1 = tvm.lower(s1, [A, B, C], name="test_add1")
          f2 = tvm.lower(s2, [A, B, C], name="test_add2")
          m = tvm.build({"llvm": [f1], "cuda": [f2]}, target_host="llvm")

    Note
    ----
    See the note on :any:`tvm.target` on target string format.
    """
    if isinstance(inputs, schedule.Schedule):
        if args is None:
            raise ValueError("args must be given for build from schedule")
        flist = lower(inputs, args,
                      name=name,
                      binds=binds)
        if isinstance(flist, container.LoweredFunc):
            flist = [flist]
    elif isinstance(inputs, container.LoweredFunc):
        if args:
            # The signature is already fixed by the lowered function.
            raise ValueError("args must be None when build from LoweredFunc.")
        flist = [inputs]
    elif isinstance(inputs, (list, tuple, container.Array)):
        flist = inputs
    elif not isinstance(inputs, (dict, container.Map)):
        raise ValueError("inputs must be Schedule, LoweredFunc, list of "
                         "LoweredFunc, or dict of target to list of "
                         "LoweredFunc.")

    if not isinstance(inputs, (dict, container.Map)):
        target = _target.current_target() if target is None else target
        target = target if target else "llvm"
        target_flist = {target: flist}
    else:
        target_flist = inputs

    # Validate every target key and reject duplicate names within a target.
    for tar, flist in target_flist.items():
        if not isinstance(tar, (str, _target.Target)):
            raise ValueError("The key of inputs must be str or "
                             "_target.Target when inputs is dict.")
        fname_set = set()
        for x in flist:
            if not isinstance(x, container.LoweredFunc):
                raise ValueError("inputs must be Schedule, LoweredFunc, list "
                                 "of LoweredFunc, or dict of str to list of "
                                 "LoweredFunc.")
            if x.name in fname_set:
                raise ValueError("Duplicate function name %s" % x.name)
            fname_set.add(x.name)

    # Without an explicit host, use the first target on the cpu device type.
    if not target_host:
        for tar, _ in target_flist.items():
            tar = _target.create(tar)
            device_type = ndarray.context(tar.target_name, 0).device_type
            if device_type == ndarray.cpu(0).device_type:
                target_host = tar
                break
    if not target_host:
        target_host = "llvm" if module.enabled("llvm") else "stackvm"

    fhost_all = []
    device_modules = []
    for tar, flist in target_flist.items():
        fhost, mdev = _build_for_device(flist, tar, target_host)
        # Save the current lowered functions of the host and the device module.
        fhost_all += fhost
        device_modules.append(mdev)

    # Generate a unified host module.
    mhost = codegen.build_module(fhost_all, str(target_host))

    # Import all modules; targets without device code contributed None.
    for mdev in device_modules:
        if mdev:
            mhost.import_module(mdev)
    return mhost
651