1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name
18"""x86 declaration and schedules."""
19import tvm
20from tvm import te
21from .injective import schedule_injective_from_existing
22from .. import tag
23from ..util import get_const_tuple
24
25
def _schedule_reduce(sch, op, is_idx_reduce=False):
    """Fuse the leading data axes of a reduction stage and run them in parallel.

    Parameters
    ----------
    sch: Schedule
        The schedule being built; its stages are modified in place.

    op: Operation
        The reduction operation to schedule.

    is_idx_reduce: bool, optional
        When True, all axes of ``op.output(0)`` are fused first, and the
        stage that gets parallelized is the one producing
        ``op.input_tensors[0]`` instead of the op's own output.
    """
    if is_idx_reduce:
        final_out = op.output(0)
        # Only the fusion itself is wanted here; the fused axis of the final
        # output is not used any further in this function.
        sch[final_out].fuse(*sch[final_out].op.axis)
        out = op.input_tensors[0]
    else:
        out = op.output(0)

    # The shape is static only when every dimension is a plain Python int.
    static_shape = all(isinstance(dim, int) for dim in get_const_tuple(out.shape))

    stage = sch[out]
    axes = stage.op.axis

    if static_shape:
        # We choose a heuristic number 128 to limit the maximum parallelism:
        # keep taking leading axes until their combined extent reaches it.
        picked = []
        parallelism = 1
        for axis in axes:
            if parallelism >= 128:
                break
            parallelism *= int(axis.dom.extent)
            picked.append(axis)
        outer = stage.fuse(*picked)
    elif len(axes) >= 5:
        # avoid too many parallelism: only fuse the three outermost axes
        outer = stage.fuse(axes[0], axes[1], axes[2])
    else:
        outer = stage.fuse(*axes)

    stage.parallel(outer)
60
61
def schedule_reduce(outs):
    """X86 schedule for reduction op.

    The graph is walked backwards starting from ``outs[0].op`` only.
    Broadcast/elementwise operators found after the reduction are scheduled
    with the injective schedule, the reduction itself is fused and
    parallelized, and injective operators feeding the reduction are inlined.

    Parameters
    ----------
    outs: Array of Tensor
          The computation graph description of reduce in the format
          of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.

    Raises
    ------
    RuntimeError
        If an operator is encountered that is neither a placeholder, an
        injective/broadcast operator, nor a tagged reduction.
    """
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    sch = te.create_schedule([x.op for x in outs])
    # Operators that have already been handled by either traversal.
    scheduled_ops = []

    def traverse_before_reduce(operator):
        """Inline every injective operator that feeds the reduction."""
        if isinstance(operator, tvm.te.PlaceholderOp):
            return
        if tag.is_injective(operator.tag):
            sch[operator].compute_inline()
            for tensor in operator.input_tensors:
                if tensor.op not in scheduled_ops:
                    traverse_before_reduce(tensor.op)
        else:
            raise RuntimeError("Unsupported operator: %s" % operator.tag)

        scheduled_ops.append(operator)

    def traverse_after_reduce(operator):
        """Schedule the reduction and the operators that consume its result."""
        # An operator can be reached through several consumers (diamond-shaped
        # graphs). Scheduling a stage a second time would try to fuse axes
        # that have already been fused, so every operator is handled once.
        if operator in scheduled_ops:
            return

        if tag.is_broadcast(operator.tag):
            schedule_injective_from_existing(sch, operator)
            for tensor in operator.input_tensors:
                traverse_after_reduce(tensor.op)
        elif operator.tag == "comm_reduce":
            _schedule_reduce(sch, operator, is_idx_reduce=False)
            for tensor in operator.input_tensors:
                if tensor.op not in scheduled_ops:
                    traverse_before_reduce(tensor.op)
        elif operator.tag == "comm_reduce_idx":
            _schedule_reduce(sch, operator, is_idx_reduce=True)
            # The producers to inline are the inputs of the operator that
            # produces this op's first input tensor.
            input_tensors = operator.input_tensors[0].op.input_tensors
            for tensor in input_tensors:
                if tensor.op not in scheduled_ops:
                    traverse_before_reduce(tensor.op)
        elif isinstance(operator, tvm.te.PlaceholderOp):
            pass
        else:
            raise RuntimeError("Unsupported operator: %s (tag: %s)" % (operator, operator.tag))

        scheduled_ops.append(operator)

    traverse_after_reduce(outs[0].op)
    return sch
121