# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""x86 declaration and schedules."""
import tvm
from tvm import te
from .injective import schedule_injective_from_existing
from .. import tag
from ..util import get_const_tuple


def _schedule_reduce(sch, op, is_idx_reduce=False):
    if is_idx_reduce:
        real_out = op.output(0)
        # For argmax/argmin-style ops, fuse the axes of the index output;
        # the tensor to parallelize is the underlying reduction it reads from.
        sch[real_out].fuse(*sch[real_out].op.axis)
        out = op.input_tensors[0]
    else:
        out = op.output(0)

    # The parallelization heuristic below needs concrete extents, so check
    # whether every output dimension is a compile-time constant.
    const_shape = True
    out_shape = get_const_tuple(out.shape)
    for d in out_shape:
        if not isinstance(d, int):
            const_shape = False
            break

    if const_shape:
        naxes = len(sch[out].op.axis)
        parallelism = 1
        fuse_axes = []
        # Fuse axes from the outside in until the fused extent reaches 128,
        # a heuristic cap on the maximum parallelism.
        while len(fuse_axes) < naxes and parallelism < 128:
            ivar = sch[out].op.axis[len(fuse_axes)]
            parallelism *= int(ivar.dom.extent)
            fuse_axes.append(ivar)
        fused = sch[out].fuse(*fuse_axes)
        sch[out].parallel(fused)
    else:
        if len(sch[out].op.axis) >= 5:
            # Fuse only the three outermost axes to avoid excessive parallelism.
            fused = sch[out].fuse(sch[out].op.axis[0], sch[out].op.axis[1], sch[out].op.axis[2])
            sch[out].parallel(fused)
        else:
            fused = sch[out].fuse(*sch[out].op.axis)
            sch[out].parallel(fused)


def schedule_reduce(outs):
    """X86 schedule for reduction op.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of the reduction
        in the format of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
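
    Examples
    --------
    A minimal usage sketch, not taken from this module: it assumes the
    package is importable as ``topi`` with this schedule exposed as
    ``topi.x86.schedule_reduce``, and the tensor names are illustrative::

        import tvm
        from tvm import te
        import topi

        # Sum a 2-D placeholder along its second axis ("comm_reduce" tag).
        A = te.placeholder((1024, 1024), name="A")
        B = topi.sum(A, axis=1)

        # Build the parallelized x86 schedule and compile it for LLVM.
        s = topi.x86.schedule_reduce([B])
        f = tvm.build(s, [A, B], target="llvm")

    Index reductions such as ``topi.argmax`` carry the ``comm_reduce_idx``
    tag and are handled by the same entry point.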
    """
    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
    sch = te.create_schedule([x.op for x in outs])
    scheduled_ops = []

    def traverse_before_reduce(operator):
        """Inline injective ops that produce the inputs of the reduction."""
        if isinstance(operator, tvm.te.PlaceholderOp):
            return
        if tag.is_injective(operator.tag):
            sch[operator].compute_inline()
            for tensor in operator.input_tensors:
                if tensor.op not in scheduled_ops:
                    traverse_before_reduce(tensor.op)
        else:
            raise RuntimeError("Unsupported operator: %s" % operator.tag)

        scheduled_ops.append(operator)

    def traverse_after_reduce(operator):
        """Schedule the reduction itself and any ops that consume its result."""
        if tag.is_broadcast(operator.tag):
            if operator not in scheduled_ops:
                schedule_injective_from_existing(sch, operator)
            for tensor in operator.input_tensors:
                traverse_after_reduce(tensor.op)
        elif operator.tag == "comm_reduce":
            _schedule_reduce(sch, operator, is_idx_reduce=False)
            for tensor in operator.input_tensors:
                if tensor.op not in scheduled_ops:
                    traverse_before_reduce(tensor.op)
        elif operator.tag == "comm_reduce_idx":
            _schedule_reduce(sch, operator, is_idx_reduce=True)
            # The index op reads from an underlying reduction; traverse the
            # inputs of that reduction rather than of the index op itself.
            input_tensors = operator.input_tensors[0].op.input_tensors
            for tensor in input_tensors:
                if tensor.op not in scheduled_ops:
                    traverse_before_reduce(tensor.op)
        elif isinstance(operator, tvm.te.PlaceholderOp):
            pass
        else:
            raise RuntimeError("Unsupported operator: %s (tag: %s)" % (operator, operator.tag))

        scheduled_ops.append(operator)

    traverse_after_reduce(outs[0].op)
    return sch