# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, unused-argument
"""Schedule for pooling operators"""
import tvm
from .. import tag
from .. import generic
from ..util import traverse_inline


@generic.schedule_adaptive_pool.register(["cuda", "gpu"])
def schedule_adaptive_pool(outs):
    """Schedule for adaptive_pool.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of adaptive_pool
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for adaptive_pool.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _schedule(Pool):
        num_thread = 8
        block_x = tvm.thread_axis("blockIdx.x")
        block_y = tvm.thread_axis("blockIdx.y")
        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
        if Pool.op in s.outputs:
            # pool is the final output: stage its computation in a local cache
            Out = Pool
            OL = s.cache_write(Pool, "local")
        else:
            # pool feeds a fused elementwise output: keep the pool stage local
            Out = outs[0].op.output(0)
            s[Pool].set_scope("local")
        # map the first two axes onto a 2D grid of 8x8 thread blocks
        by, ty = s[Out].split(s[Out].op.axis[0], factor=num_thread)
        bx, tx = s[Out].split(s[Out].op.axis[1], factor=num_thread)
        s[Out].reorder(by, bx, ty, tx)
        s[Out].bind(ty, thread_y)
        s[Out].bind(tx, thread_x)
        s[Out].bind(by, block_y)
        s[Out].bind(bx, block_x)
        if Pool.op in s.outputs:
            s[OL].compute_at(s[Out], tx)
        else:
            s[Pool].compute_at(s[Out], tx)

    scheduled_ops = []

    def traverse(OP):
        """Internal traverse function"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(OP.tag):
            if OP not in s.outputs:
                s[OP].compute_inline()
            for tensor in OP.input_tensors:
                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                    traverse(tensor.op)
        # schedule adaptive_pool
        elif OP.tag.startswith('adaptive_pool'):
            Pool = OP.output(0)
            _schedule(Pool)
        else:
            raise RuntimeError("Unsupported operator: %s" % OP.tag)

        scheduled_ops.append(OP)

    traverse(outs[0].op)
    return s
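

# Example usage (an illustrative sketch, not part of this module): how the
# adaptive-pool schedule above is typically reached through the generic
# dispatcher under a CUDA target. The input shape, output_size, and pool_type
# below are assumptions chosen for illustration.
#
#     import tvm
#     import topi
#
#     data = tvm.placeholder((1, 64, 32, 32), name="data")
#     pool = topi.nn.adaptive_pool(data, output_size=(1, 1),
#                                  pool_type="avg", layout="NCHW")
#     with tvm.target.cuda():
#         s = topi.generic.schedule_adaptive_pool([pool])
#         func = tvm.build(s, [data, pool], "cuda")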
109 """ 110 outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs 111 s = tvm.create_schedule([x.op for x in outs]) 112 def _schedule(PaddedInput, Pool): 113 if isinstance(PaddedInput.op, tvm.tensor.ComputeOp): 114 s[PaddedInput].compute_inline() 115 num_thread = tvm.target.current_target(allow_none=False).max_num_threads 116 if Pool.op in s.outputs: 117 Out = Pool 118 OL = s.cache_write(Pool, "local") 119 else: 120 Out = outs[0].op.output(0) 121 s[Pool].set_scope("local") 122 fused = s[Out].fuse(*s[Out].op.axis) 123 bx, tx = s[Out].split(fused, factor=num_thread) 124 s[Out].bind(bx, tvm.thread_axis("blockIdx.x")) 125 s[Out].bind(tx, tvm.thread_axis("threadIdx.x")) 126 if Pool.op in s.outputs: 127 s[OL].compute_at(s[Out], tx) 128 else: 129 s[Pool].compute_at(s[Out], tx) 130 131 scheduled_ops = [] 132 133 def traverse(OP): 134 """Internal travserse function""" 135 # inline all one-to-one-mapping operators except the last stage (output) 136 if tag.is_broadcast(OP.tag): 137 if OP not in s.outputs: 138 s[OP].compute_inline() 139 for tensor in OP.input_tensors: 140 if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops: 141 traverse(tensor.op) 142 # schedule pool 143 elif OP.tag.startswith('pool'): 144 PaddedInput = OP.input_tensors[0] 145 Pool = OP.output(0) 146 _schedule(PaddedInput, Pool) 147 else: 148 raise RuntimeError("Unsupported operator: %s" % OP.tag) 149 150 scheduled_ops.append(OP) 151 152 traverse(outs[0].op) 153 return s 154 155 156@generic.schedule_pool_grad.register(['cuda', 'gpu']) 157def schedule_pool_grad_cuda(outs): 158 """Schedule for pool_grad on CUDA 159 160 Parameters 161 ---------- 162 outs: Array of Tensor 163 The computation graph description of pool_grad 164 in the format of an array of tensors. 165 166 Returns 167 ------- 168 s: Schedule 169 The computation schedule for pool_grad. 170 """ 171 outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs 172 s = tvm.create_schedule([x.op for x in outs]) 173 174 def _schedule_pool_grad(op): 175 if op in s.outputs: 176 out = op 177 else: 178 out = outs[0].op.output(0) 179 fused = s[out].fuse(*s[out].op.axis) 180 num_thread = tvm.target.current_target(allow_none=False).max_num_threads 181 bx, tx = s[out].split(fused, factor=num_thread) 182 s[out].bind(bx, tvm.thread_axis("blockIdx.x")) 183 s[out].bind(tx, tvm.thread_axis("threadIdx.x")) 184 185 if tag.COMM_REDUCE_IDX in op.input_tensors[0].op.tag: 186 max_pool_index = op.input_tensors[0] 187 s[max_pool_index].compute_at(s[out], tx) 188 189 pool_input = max_pool_index.op.input_tensors[0] 190 if isinstance(pool_input.op, tvm.tensor.ComputeOp): 191 # handle padding 192 s[pool_input].compute_inline() 193 if op not in s.outputs: 194 s[op].compute_at(s[out], tx) 195 196 def _callback(op): 197 if op.tag.startswith('pool_grad'): 198 _schedule_pool_grad(op) 199 200 traverse_inline(s, outs[0].op, _callback) 201 202 return s 203