# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,too-many-locals,len-as-condition
"""Schedule for reduce operators"""
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic
from .injective import schedule_injective_from_existing


def _schedule_reduce(op, sch, is_idx_reduce=False):
    """Schedule a single reduce operator on the GPU.

    The reduce axes are fused, split by ``num_thread`` and rfactor-ed so the
    inner chunk is reduced cooperatively by ``threadIdx.x``.  If the output
    still has spatial axes they are fused, split, and bound to
    ``blockIdx.x`` / ``threadIdx.y``; otherwise the whole computation is a
    single all-reduce handled by one thread block.

    Parameters
    ----------
    op : tvm.tensor.Operation
        The reduce operation to schedule.
    sch : tvm.schedule.Schedule
        The schedule being built (mutated in place).
    is_idx_reduce : bool
        True for two-output index reductions (the code reads
        ``data_out.op.output(0)`` / ``output(1)`` — presumably an
        argmax/argmin-style temp index/value pair; confirm against the
        compute definition).

    Returns
    -------
    sch : tvm.schedule.Schedule
        The mutated schedule.
    """
    if is_idx_reduce:
        # For index reductions the op passed in is the final selection stage;
        # the actual reduction is the op producing its first input.
        data_out = op.input_tensors[0]
    else:
        data_in = op.input_tensors[0]
        data_out = op.output(0)

    if not sch[data_out].op.reduce_axis:
        # Nothing left to reduce (e.g. reduction over an axis of length 1
        # already simplified away) — schedule it as a plain injective op.
        return schedule_injective_from_existing(sch, op.output(0))

    if len(sch[data_out].op.axis) > 0:
        # Output keeps spatial axes: each output element is an independent
        # reduction, parallelized over blocks/threads below.
        all_reduce = False
        num_thread = 32
        target = tvm.target.current_target()
        if target and target.target_name == "opencl":
            # without it, CL_INVALID_WORK_GROUP_SIZE occurred when running test_topi_reduce.py
            # don't know why
            num_thread = 16
        block_x = tvm.thread_axis("blockIdx.x")
        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
    else:
        # Scalar output: a single all-reduce using the maximum number of
        # threads the target allows in one block.
        all_reduce = True
        num_thread = tvm.target.current_target(allow_none=False).max_num_threads
        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")

    # Fuse and refactor the reduce axis
    fused_reduce = sch[data_out].fuse(*[sch[data_out].op.reduce_axis[i]
                                        for i in range(len(sch[data_out].op.reduce_axis))])
    ko, ki = sch[data_out].split(fused_reduce, factor=num_thread)
    if is_idx_reduce:
        # rfactor on a multi-output reduce returns one tensor per output;
        # only the first is needed to anchor the compute_at below.
        data_out_rf, _ = sch.rfactor(data_out, ki)
    else:
        data_out_rf = sch.rfactor(data_out, ki)
    # After rfactor the remaining reduce axis of data_out is the factored
    # axis — bind it to threadIdx.x for the cross-thread reduction.
    tx = sch[data_out].op.reduce_axis[0]
    sch[data_out].bind(tx, thread_x)
    sch[data_out_rf].compute_at(sch[data_out], tx)
    if is_idx_reduce:
        real_output = op.output(0)
        # Temporary (index, value) pair produced by the reduction stage.
        temp_idx_input = data_out.op.output(0)
        temp_val_input = data_out.op.output(1)
    else:
        real_output = data_out
    if not all_reduce:
        # Fuse and split the axis
        fused_outer = sch[real_output].fuse(*[sch[real_output].op.axis[i]
                                              for i in range(len(sch[real_output].op.axis))])
        bx, outer_in = sch[real_output].split(fused_outer, factor=num_thread)

        # Bind the axes to threads and blocks
        sch[real_output].bind(outer_in, thread_y)
        sch[real_output].bind(bx, block_x)
        if is_idx_reduce:
            sch[temp_idx_input].compute_at(sch[real_output], outer_in)
            sch[temp_val_input].compute_at(sch[real_output], outer_in)
    else:
        if is_idx_reduce:
            # All-reduce with an index output: the final selection still has
            # spatial axes; fuse them and bind to blockIdx.x.
            spatial_axis = sch[real_output].fuse(*(sch[real_output].op.axis))
            sch[real_output].bind(spatial_axis, tvm.thread_axis("blockIdx.x"))
            sch[temp_idx_input].compute_at(sch[real_output],
                                           spatial_axis)
            sch[temp_val_input].compute_at(sch[real_output],
                                           spatial_axis)
    # Only one thread in the reduce dimension holds the final value;
    # restrict the store to thread 0 to avoid redundant/conflicting writes.
    sch[real_output].set_store_predicate(thread_x.equal(0))
    return sch


@generic.schedule_reduce.register(["cuda", "gpu"])
def schedule_reduce(outs):
    """Schedule for inject->reduce->bcast ops.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of reduce in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    sch = tvm.create_schedule([x.op for x in outs])
    # Ops already scheduled; prevents double-scheduling shared subgraphs.
    scheduled_ops = []

    def traverse_before_reduce(operator):
        """Internal travserse function"""
        # Ops feeding the reduction are inlined into it.
        if isinstance(operator, tvm.tensor.PlaceholderOp):
            return
        if tag.is_injective(operator.tag):
            sch[operator].compute_inline()
            for tensor in operator.input_tensors:
                if tensor.op not in scheduled_ops:
                    traverse_before_reduce(tensor.op)
        else:
            raise RuntimeError("Unsupported operator: %s" % operator.tag)

        scheduled_ops.append(operator)

    def traverse_after_reduce(operator):
        """Internal travserse function"""
        # Walk down from the output until the reduce op is found; ops after
        # the reduction are scheduled as injective, the reduction itself via
        # _schedule_reduce, and everything before it via
        # traverse_before_reduce.
        if tag.is_broadcast(operator.tag):
            if operator not in scheduled_ops:
                schedule_injective_from_existing(sch, operator.output(0))
            for tensor in operator.input_tensors:
                traverse_after_reduce(tensor.op)
        elif operator.tag == 'comm_reduce':
            _schedule_reduce(operator, sch, is_idx_reduce=False)
            for tensor in operator.input_tensors:
                if tensor.op not in scheduled_ops:
                    traverse_before_reduce(tensor.op)
        elif operator.tag == 'comm_reduce_idx':
            _schedule_reduce(operator, sch, is_idx_reduce=True)
            # For index reductions, traverse the inputs of the underlying
            # reduction stage (the op feeding the selection op).
            input_tensors = operator.input_tensors[0].op.input_tensors
            for tensor in input_tensors:
                if tensor.op not in scheduled_ops:
                    traverse_before_reduce(tensor.op)
        else:
            raise RuntimeError("Unsupported operator: %s" % operator.tag)

        scheduled_ops.append(operator)

    traverse_after_reduce(outs[0].op)
    return sch