# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable
"""Schedule for composition of injective operator"""
import tvm
from .. import generic, util

@generic.schedule_injective_from_existing.register(["cuda", "gpu"])
def schedule_injective_from_existing(sch, out):
    """Schedule for injective op from existing schedule.

    Parameters
    ----------
    sch : Schedule
        The schedule to update.
    out : Tensor
        The tensor representing the injective op.

    Returns
    -------
    sch : Schedule
        The updated schedule.
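
    Examples
    --------
    A minimal usage sketch; the tensors below are illustrative, not part
    of this API, and a CUDA target must be in scope so that
    ``tvm.target.current_target`` resolves:

    >>> A = tvm.placeholder((1024,), name="A")
    >>> B = tvm.compute(A.shape, lambda i: A[i] + 1, name="B")
    >>> with tvm.target.create("cuda"):
    ...     s = tvm.create_schedule(B.op)
    ...     s = schedule_injective_from_existing(s, B)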
37    """
    fused = sch[out].fuse(*sch[out].op.axis)
    num_thread = tvm.target.current_target(allow_none=False).max_num_threads
    max_block = 256

    try:
        # Split across blocks only when the statically known element count
        # exceeds what one launch of max_block blocks can cover.
        const_size = util.get_const_int(util.prod(out.shape))
        need_block_split = const_size > max_block * num_thread
    except ValueError:
        # The shape is not a compile-time constant; use the simple split.
        need_block_split = False
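
    # Conceptual loop nest produced by the block-split path below (a sketch;
    # N is the flattened element count):
    #   for bx in range(max_block):                      # blockIdx.x
    #       for tx in range(num_thread):                 # threadIdx.x
    #           for xo in range(ceil(N / (max_block * num_thread))):
    #               out[xo * max_block * num_thread + bx * num_thread + tx] = ...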
    if need_block_split:
        # Large static workload: serial outer loop xo plus grid/thread split.
        xo, xi = sch[out].split(fused, factor=num_thread * max_block)
        bx, tx = sch[out].split(xi, factor=num_thread)
        sch[out].reorder(bx, tx, xo)
        sch[out].bind(bx, tvm.thread_axis("blockIdx.x"))
        sch[out].bind(tx, tvm.thread_axis("threadIdx.x"))
    else:
        # Small or dynamic workload: one split, one thread per element.
        bx, tx = sch[out].split(fused, factor=num_thread)
        sch[out].bind(bx, tvm.thread_axis("blockIdx.x"))
        sch[out].bind(tx, tvm.thread_axis("threadIdx.x"))
    return sch

@generic.schedule_injective.register(["cuda", "gpu"])
def schedule_injective(outs):
    """Schedule for injective op.

    Parameters
    ----------
    outs : Array of Tensor
        The computation graph description of injective ops in the format
        of an array of tensors.

    Returns
    -------
    sch : Schedule
        The computation schedule for the op.
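
    Examples
    --------
    A minimal sketch; the tensors are illustrative and a CUDA toolchain
    must be available for the final build step:

    >>> A = tvm.placeholder((256, 256), name="A")
    >>> B = tvm.compute(A.shape, lambda i, j: A[i, j] * 2, name="B")
    >>> with tvm.target.create("cuda"):
    ...     s = schedule_injective(B)
    ...     f = tvm.build(s, [A, B], "cuda")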
76    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    # Inline every injective producer into its consumer, then schedule each
    # output with the shared injective schedule above.
    tvm.schedule.AutoInlineInjective(s)
    for out in outs:
        schedule_injective_from_existing(s, out)
    return s

schedule_elemwise = schedule_injective
schedule_broadcast = schedule_injective
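# Elementwise and broadcast ops are injective, so the aliases above let them
# share this schedule through the same generic dispatch, e.g. (illustrative
# tensor C, not defined here):
#
#   with tvm.target.create("cuda"):
#       s = schedule_broadcast(C)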
87