1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17# pylint: disable=invalid-name, unused-variable, 18"""Schedule for composition of injective operator""" 19import tvm 20from .. import generic, util 21 22@generic.schedule_injective_from_existing.register(["cuda", "gpu"]) 23def schedule_injective_from_existing(sch, out): 24 """Schedule for injective op from existing schedule. 25 26 Parameters 27 ---------- 28 sch: Schedule 29 The schedule to update. 30 out: Tensor 31 The tensor representing the injective op. 32 33 Returns 34 ------- 35 sch: Schedule 36 The updated schedule. 37 """ 38 fused = sch[out].fuse(*sch[out].op.axis) 39 num_thread = tvm.target.current_target(allow_none=False).max_num_threads 40 max_block = 256 41 42 try: 43 const_size = util.get_const_int(util.prod(out.shape)) 44 max_block = 256 45 need_block_split = const_size > max_block * num_thread 46 except ValueError: 47 need_block_split = False 48 49 if need_block_split: 50 xo, xi = sch[out].split(fused, factor=num_thread * max_block) 51 bx, tx = sch[out].split(xi, factor=num_thread) 52 sch[out].reorder(bx, tx, xo) 53 sch[out].bind(bx, tvm.thread_axis("blockIdx.x")) 54 sch[out].bind(tx, tvm.thread_axis("threadIdx.x")) 55 else: 56 bx, tx = sch[out].split(fused, factor=num_thread) 57 sch[out].bind(tx, tvm.thread_axis("threadIdx.x")) 58 sch[out].bind(bx, tvm.thread_axis("blockIdx.x")) 59 60 return sch 61 62@generic.schedule_injective.register(["cuda", "gpu"]) 63def schedule_injective(outs): 64 """Schedule for injective op. 65 66 Parameters 67 ---------- 68 outs: Array of Tensor 69 The computation graph description of reduce in the format 70 of an array of tensors. 71 72 Returns 73 ------- 74 sch: Schedule 75 The computation schedule for the op. 76 """ 77 outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs 78 s = tvm.create_schedule([x.op for x in outs]) 79 80 tvm.schedule.AutoInlineInjective(s) 81 for out in outs: 82 schedule_injective_from_existing(s, out) 83 return s 84 85schedule_elemwise = schedule_injective 86schedule_broadcast = schedule_injective 87