# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
18"""x86 declaration and schedules."""
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic
from ..util import get_const_tuple

def _schedule_reduce(sch, op, is_idx_reduce=False):
    if is_idx_reduce:
        # For index reductions (argmax/argmin), op.output(0) is the final
        # index tensor. Flatten its axes; the real reduction work lives in
        # the temporary (value, index) tensor at op.input_tensors[0].
        real_out = op.output(0)
        sch[real_out].fuse(*sch[real_out].op.axis)
        out = op.input_tensors[0]
    else:
        out = op.output(0)

    # The parallelization heuristic below needs static extents; fall back
    # to a simpler strategy when any output dimension is symbolic.
    const_shape = True
    out_shape = get_const_tuple(out.shape)
    for d in out_shape:
        if not isinstance(d, int):
            const_shape = False
            break

    if const_shape:
        naxes = len(sch[out].op.axis)
        parallelism = 1
        fuse_axes = []
        # Heuristic: fuse outer axes until the accumulated extent reaches
        # 128; stopping there limits the maximum parallelism.
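        # For illustration: an output of shape (4, 8, 16) fuses all three
        # axes, since the running product 4*8*16 crosses 128 only on the
        # last axis, whereas shape (256, 8) fuses just the first axis,
        # whose extent alone already exceeds the threshold.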
        while len(fuse_axes) < naxes and parallelism < 128:
            ivar = sch[out].op.axis[len(fuse_axes)]
            parallelism *= int(ivar.dom.extent)
            fuse_axes.append(ivar)
        fused = sch[out].fuse(*fuse_axes)
        sch[out].parallel(fused)
    else:
        if len(sch[out].op.axis) >= 5:
            # Avoid excessive parallelism: fuse only the three outermost axes.
            fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2])
            sch[out].parallel(fused)
        else:
            fused = sch[out].fuse(*sch[out].op.axis)
            sch[out].parallel(fused)


@generic.schedule_reduce.register(["cpu"])
def schedule_reduce(outs):
    """X86 schedule for reduction op.

    Parameters
    ----------
    outs: Array of Tensor
          The computation graph description of reduce in the format
          of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
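
    Example
    -------
    A minimal usage sketch (assumes an LLVM-enabled TVM build; the tensor
    shapes are arbitrary)::

        import tvm
        import topi

        A = tvm.placeholder((64, 128), name="A")
        B = topi.sum(A, axis=1)
        with tvm.target.create("llvm"):
            sch = topi.generic.schedule_reduce(B)
        func = tvm.build(sch, [A, B], target="llvm")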
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    sch = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []

    def traverse_before_reduce(operator):
        """Internal traversal function for ops that feed the reduction."""
        if isinstance(operator, tvm.tensor.PlaceholderOp):
            return
        if tag.is_injective(operator.tag):
            sch[operator].compute_inline()
            for tensor in operator.input_tensors:
                if tensor.op not in scheduled_ops:
                    traverse_before_reduce(tensor.op)
        else:
            raise RuntimeError("Unsupported operator: %s" % operator.tag)

        scheduled_ops.append(operator)

    def traverse_after_reduce(operator):
        """Internal traversal function for ops that consume the reduction."""
        if tag.is_broadcast(operator.tag):
            if operator not in scheduled_ops:
                generic.schedule_injective_from_existing(sch, operator)
            for tensor in operator.input_tensors:
                traverse_after_reduce(tensor.op)
        elif operator.tag == 'comm_reduce':
            _schedule_reduce(sch, operator, is_idx_reduce=False)
            for tensor in operator.input_tensors:
                if tensor.op not in scheduled_ops:
                    traverse_before_reduce(tensor.op)
        elif operator.tag == 'comm_reduce_idx':
            _schedule_reduce(sch, operator, is_idx_reduce=True)
            # The index-reduce op consumes a temporary (value, index) tensor;
            # step through it to reach the tensors that feed the reduction.
            input_tensors = operator.input_tensors[0].op.input_tensors
            for tensor in input_tensors:
                if tensor.op not in scheduled_ops:
                    traverse_before_reduce(tensor.op)
        elif isinstance(operator, tvm.tensor.PlaceholderOp):
            pass
        else:
            raise RuntimeError("Unsupported operator: %s (tag: %s)" % (operator, operator.tag))

        scheduled_ops.append(operator)

    traverse_after_reduce(outs[0].op)
    return sch
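
# Usage sketch (illustrative comment, not executed on import): index
# reductions such as topi.argmax take the 'comm_reduce_idx' branch in
# traverse_after_reduce above, e.g.
#
#     A = tvm.placeholder((64, 128), name="A")
#     B = topi.argmax(A, axis=1)
#     with tvm.target.create("llvm"):
#         sch = topi.generic.schedule_reduce(B)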