# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, unused-argument
"""Schedule for pooling operators"""
import tvm
from .. import tag
from .. import generic
from ..util import traverse_inline


@generic.schedule_adaptive_pool.register(["cuda", "gpu"])
def schedule_adaptive_pool(outs):
    """Schedule for adaptive_pool.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of adaptive_pool
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for adaptive_pool.
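
    Example
    -------
    A minimal, illustrative sketch (the adaptive_pool compute, shapes, and
    attributes below are assumptions, not requirements of this schedule)::

        import topi
        data = tvm.placeholder((1, 3, 224, 224), name="data")
        pool = topi.nn.adaptive_pool(data, output_size=(1, 1),
                                     pool_type="avg", layout="NCHW")
        with tvm.target.create("cuda"):
            s = topi.generic.schedule_adaptive_pool(pool)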
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _schedule(Pool):
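        # Tile the first two output axes into 8x8 blocks of threads:
        # axis 0 -> (blockIdx.y, threadIdx.y), axis 1 -> (blockIdx.x, threadIdx.x).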
        num_thread = 8
        block_x = tvm.thread_axis("blockIdx.x")
        block_y = tvm.thread_axis("blockIdx.y")
        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
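        # If the pool is the final stage, accumulate it through a local cache and
        # write the result from the bound output stage; otherwise keep the pool at
        # local scope and compute it under the output stage that consumes it.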
        if Pool.op in s.outputs:
            Out = Pool
            OL = s.cache_write(Pool, "local")
        else:
            Out = outs[0].op.output(0)
            s[Pool].set_scope("local")
        by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread)
        bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread)
        s[Out].reorder(by, bx, ty, tx)
        s[Out].bind(ty, thread_y)
        s[Out].bind(tx, thread_x)
        s[Out].bind(by, block_y)
        s[Out].bind(bx, block_x)
        if Pool.op in s.outputs:
            s[OL].compute_at(s[Out], tx)
        else:
            s[Pool].compute_at(s[Out], tx)

    scheduled_ops = []

    def traverse(OP):
        """Internal traverse function"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(OP.tag):
            if OP not in s.outputs:
                s[OP].compute_inline()
            for tensor in OP.input_tensors:
                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                    traverse(tensor.op)
        # schedule adaptive_pool
        elif OP.tag.startswith('adaptive_pool'):
            Pool = OP.output(0)
            _schedule(Pool)
        else:
            raise RuntimeError("Unsupported operator: %s" % OP.tag)

        scheduled_ops.append(OP)

    traverse(outs[0].op)
    return s


@generic.schedule_pool.register(["cuda", "gpu"])
def schedule_pool(outs, layout):
    """Schedule for pool.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of pool
        in the format of an array of tensors.

    layout: str
        Data layout.

    Returns
    -------
    s: Schedule
        The computation schedule for pool.
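
    Example
    -------
    A minimal, illustrative sketch (the pool compute, shapes, and attributes
    below are assumptions, not requirements of this schedule)::

        import topi
        data = tvm.placeholder((1, 3, 224, 224), name="data")
        pool = topi.nn.pool(data, kernel=(2, 2), stride=(2, 2),
                            padding=(0, 0, 0, 0), pool_type="max",
                            layout="NCHW")
        with tvm.target.create("cuda"):
            s = topi.generic.schedule_pool(pool, layout="NCHW")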
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _schedule(PaddedInput, Pool):
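        # Inline the padding stage into the pool, then flatten the pool output
        # and bind it to a 1D grid of blocks and threads.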
        if isinstance(PaddedInput.op, tvm.tensor.ComputeOp):
            s[PaddedInput].compute_inline()
        num_thread = tvm.target.current_target(allow_none=False).max_num_threads
        if Pool.op in s.outputs:
            Out = Pool
            OL = s.cache_write(Pool, "local")
        else:
            Out = outs[0].op.output(0)
            s[Pool].set_scope("local")
        fused = s[Out].fuse(*s[Out].op.axis)
        bx, tx = s[Out].split(fused, factor=num_thread)
        s[Out].bind(bx, tvm.thread_axis("blockIdx.x"))
        s[Out].bind(tx, tvm.thread_axis("threadIdx.x"))
        if Pool.op in s.outputs:
            s[OL].compute_at(s[Out], tx)
        else:
            s[Pool].compute_at(s[Out], tx)

    scheduled_ops = []

    def traverse(OP):
        """Internal traverse function"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(OP.tag):
            if OP not in s.outputs:
                s[OP].compute_inline()
            for tensor in OP.input_tensors:
                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                    traverse(tensor.op)
        # schedule pool
        elif OP.tag.startswith('pool'):
            PaddedInput = OP.input_tensors[0]
            Pool = OP.output(0)
            _schedule(PaddedInput, Pool)
        else:
            raise RuntimeError("Unsupported operator: %s" % OP.tag)

        scheduled_ops.append(OP)

    traverse(outs[0].op)
    return s


@generic.schedule_pool_grad.register(['cuda', 'gpu'])
def schedule_pool_grad_cuda(outs):
    """Schedule for pool_grad on CUDA.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of pool_grad
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for pool_grad.
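
    Example
    -------
    A minimal, illustrative sketch (the pool_grad compute, shapes, and
    attributes below are assumptions, not requirements of this schedule)::

        import topi
        data = tvm.placeholder((1, 3, 224, 224), name="data")
        grads = tvm.placeholder((1, 3, 112, 112), name="grads")
        out_grad = topi.nn.pool_grad(grads, data, kernel=(2, 2), stride=(2, 2),
                                     padding=(0, 0, 0, 0), pool_type="max",
                                     layout="NCHW")
        with tvm.target.create("cuda"):
            s = topi.generic.schedule_pool_grad(out_grad)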
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _schedule_pool_grad(op):
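        # Flatten the gradient output and bind it to a 1D grid, one thread per
        # output element.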
        if op in s.outputs:
            out = op
        else:
            out = outs[0].op.output(0)
        fused = s[out].fuse(*s[out].op.axis)
        num_thread = tvm.target.current_target(allow_none=False).max_num_threads
        bx, tx = s[out].split(fused, factor=num_thread)
        s[out].bind(bx, tvm.thread_axis("blockIdx.x"))
        s[out].bind(tx, tvm.thread_axis("threadIdx.x"))

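        # For max pooling, the gradient depends on an argmax (comm_reduce_idx)
        # stage; compute that index tensor at the same thread and inline the
        # padded forward input that feeds it.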
        if tag.COMM_REDUCE_IDX in op.input_tensors[0].op.tag:
            max_pool_index = op.input_tensors[0]
            s[max_pool_index].compute_at(s[out], tx)

            pool_input = max_pool_index.op.input_tensors[0]
            if isinstance(pool_input.op, tvm.tensor.ComputeOp):
                # handle padding
                s[pool_input].compute_inline()
        if op not in s.outputs:
            s[op].compute_at(s[out], tx)

    def _callback(op):
        if op.tag.startswith('pool_grad'):
            _schedule_pool_grad(op)

    traverse_inline(s, outs[0].op, _callback)

    return s