1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name,unused-variable,too-many-locals,len-as-condition
18"""Schedule for reduce operators"""
19from __future__ import absolute_import as _abs
20import tvm
21from .. import tag
22from .. import generic
23from .injective import schedule_injective_from_existing
24
25def _schedule_reduce(op, sch, is_idx_reduce=False):
26    if is_idx_reduce:
27        data_out = op.input_tensors[0]
28    else:
29        data_in = op.input_tensors[0]
30        data_out = op.output(0)
31
32    if not sch[data_out].op.reduce_axis:
33        return schedule_injective_from_existing(sch, op.output(0))
34
35    if len(sch[data_out].op.axis) > 0:
36        all_reduce = False
37        num_thread = 32
38        target = tvm.target.current_target()
39        if target and target.target_name == "opencl":
40            # without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py
41            # don't know why
42            num_thread = 16
43        block_x = tvm.thread_axis("blockIdx.x")
44        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
45        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
46    else:
47        all_reduce = True
48        num_thread = tvm.target.current_target(allow_none=False).max_num_threads
49        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
50
51    # Fuse and refactor the reduce axis
52    fused_reduce = sch[data_out].fuse(*[sch[data_out].op.reduce_axis[i]
53                                        for i in range(len(sch[data_out].op.reduce_axis))])
54    ko, ki = sch[data_out].split(fused_reduce, factor=num_thread)
55    if is_idx_reduce:
56        data_out_rf, _ = sch.rfactor(data_out, ki)
57    else:
58        data_out_rf = sch.rfactor(data_out, ki)
59    tx = sch[data_out].op.reduce_axis[0]
60    sch[data_out].bind(tx, thread_x)
61    sch[data_out_rf].compute_at(sch[data_out], tx)
62    if is_idx_reduce:
63        real_output = op.output(0)
64        temp_idx_input = data_out.op.output(0)
65        temp_val_input = data_out.op.output(1)
66    else:
67        real_output = data_out
68    if not all_reduce:
69        # Fuse and split the axis
70        fused_outer = sch[real_output].fuse(*[sch[real_output].op.axis[i]
71                                              for i in range(len(sch[real_output].op.axis))])
72        bx, outer_in = sch[real_output].split(fused_outer, factor=num_thread)
73
74        # Bind the axes to threads and blocks
75        sch[real_output].bind(outer_in, thread_y)
76        sch[real_output].bind(bx, block_x)
77        if is_idx_reduce:
78            sch[temp_idx_input].compute_at(sch[real_output], outer_in)
79            sch[temp_val_input].compute_at(sch[real_output], outer_in)
80    else:
81        if is_idx_reduce:
82            spatial_axis = sch[real_output].fuse(*(sch[real_output].op.axis))
83            sch[real_output].bind(spatial_axis, tvm.thread_axis("blockIdx.x"))
84            sch[temp_idx_input].compute_at(sch[real_output],
85                                           spatial_axis)
86            sch[temp_val_input].compute_at(sch[real_output],
87                                           spatial_axis)
88    sch[real_output].set_store_predicate(thread_x.equal(0))
89    return sch
90
91
92@generic.schedule_reduce.register(["cuda", "gpu"])
93def schedule_reduce(outs):
94    """Schedule for inject->reduce->bcast ops.
95
96    Parameters
97    ----------
98    outs: Array of Tensor
99          The computation graph description of reduce in the format
100          of an array of tensors.
101
102    Returns
103    -------
104    sch: Schedule
105        The computation schedule for the op.
106    """
107    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
108    sch = tvm.create_schedule([x.op for x in outs])
109    scheduled_ops = []
110
111    def traverse_before_reduce(operator):
112        """Internal travserse function"""
113        if isinstance(operator, tvm.tensor.PlaceholderOp):
114            return
115        if tag.is_injective(operator.tag):
116            sch[operator].compute_inline()
117            for tensor in operator.input_tensors:
118                if tensor.op not in scheduled_ops:
119                    traverse_before_reduce(tensor.op)
120        else:
121            raise RuntimeError("Unsupported operator: %s" % operator.tag)
122
123        scheduled_ops.append(operator)
124
125    def traverse_after_reduce(operator):
126        """Internal travserse function"""
127        if tag.is_broadcast(operator.tag):
128            if operator not in scheduled_ops:
129                schedule_injective_from_existing(sch, operator.output(0))
130            for tensor in operator.input_tensors:
131                traverse_after_reduce(tensor.op)
132        elif operator.tag == 'comm_reduce':
133            _schedule_reduce(operator, sch, is_idx_reduce=False)
134            for tensor in operator.input_tensors:
135                if tensor.op not in scheduled_ops:
136                    traverse_before_reduce(tensor.op)
137        elif operator.tag == 'comm_reduce_idx':
138            _schedule_reduce(operator, sch, is_idx_reduce=True)
139            input_tensors = operator.input_tensors[0].op.input_tensors
140            for tensor in input_tensors:
141                if tensor.op not in scheduled_ops:
142                    traverse_before_reduce(tensor.op)
143        else:
144            raise RuntimeError("Unsupported operator: %s" % operator.tag)
145
146        scheduled_ops.append(operator)
147
148    traverse_after_reduce(outs[0].op)
149    return sch
150