# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tensor intrinsics"""
from __future__ import absolute_import as _abs
from . import _api_internal
from . import api as _api
from . import expr as _expr
from . import stmt as _stmt
from . import make as _make
from . import tensor as _tensor
from . import schedule as _schedule
from .build_module import current_build_config
from ._ffi.node import NodeBase, register_node


def _get_region(tslice):
    """Convert the indices of a tensor slice into a list of Ranges.

    Parameters
    ----------
    tslice : TensorSlice
        The slice whose ``indices`` are converted.

    Returns
    -------
    region : list of Range
        One Range per index. A Python ``slice`` becomes
        ``Range(start, stop)``; any other index (an IterVar, whose ``var``
        is used, or a plain expression) becomes a Range of extent 1
        starting at that index.

    Raises
    ------
    ValueError
        If a slice index carries a step.
    """
    region = []
    for idx in tslice.indices:
        if isinstance(idx, slice):
            # Raise rather than assert so the check survives `python -O`;
            # a strided slice cannot be expressed as a single Range.
            if idx.step is not None:
                raise ValueError(
                    "Strided slices are not supported in tensor intrinsic "
                    "regions, got step=%s" % (idx.step,))
            region.append(_api.Range(idx.start, idx.stop))
        else:
            begin = idx.var if isinstance(idx, _schedule.IterVar) else idx
            region.append(_make.range_by_min_extent(begin, 1))
    return region


@register_node
class TensorIntrin(NodeBase):
    """Tensor intrinsic functions for certain computation.

    See Also
    --------
    decl_tensor_intrin: Construct a TensorIntrin
    """
    def __call__(self, *args, **kwargs):
        """Build a call of this intrinsic on the given arguments.

        Parameters
        ----------
        *args : TensorSlice or scalar
            Every TensorSlice contributes its tensor and its region (see
            ``_get_region``); every other positional argument is collected,
            in order, as a scalar input.

        reduce_axis : IterVar or list/tuple of IterVar, optional (keyword)
            Reduction axes of the call. A single value is wrapped in a list.

        Returns
        -------
        call
            The node returned by ``_api_internal._TensorIntrinCall``.
        """
        slices = [x for x in args if isinstance(x, _tensor.TensorSlice)]
        tensors = [x.tensor for x in slices]
        regions = [_get_region(x) for x in slices]
        scalar_inputs = [x for x in args
                         if not isinstance(x, _tensor.TensorSlice)]
        reduce_axis = []
        if "reduce_axis" in kwargs:
            reduce_axis = kwargs["reduce_axis"]
            if not isinstance(reduce_axis, (list, tuple)):
                reduce_axis = [reduce_axis]
            reduce_axis = _api.convert(reduce_axis)
        if scalar_inputs:
            scalar_inputs = _api.convert(scalar_inputs)
        return _api_internal._TensorIntrinCall(
            self, tensors, regions, reduce_axis, scalar_inputs)


def decl_tensor_intrin(op,
                       fcompute,
                       name="tensor_intrin",
                       binds=None, scalar_params=None):
    """Declare a tensor intrinsic function.

    Parameters
    ----------
    op: Operation
        The symbolic description of the intrinsic operation

    fcompute: lambda function of inputs, outputs-> stmt
        Specifies the IR statement to do the computation.
        See the following note for function signature of fcompute

        .. note::
             **Parameters**

             - **ins** (list of :any:`Buffer`) - Placeholder for each input
             - **outs** (list of :any:`Buffer`) - Placeholder for each output
             - **scalar_params** - Passed as a third argument only when
               ``scalar_params`` is non-empty

             **Returns**

             - **stmt** (:any:`Stmt`, or tuple of up to three stmts)
             - If a single stmt is returned, it represents the body
             - If a tuple of stmts is returned, its elements correspond to
               body, reduce_init, reduce_update; missing trailing entries
               are treated as None
             - Bare expressions are wrapped with ``Evaluate``

    name: str, optional
        The name of the intrinsic.

    binds: dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.

    scalar_params: list, optional
        A list of variables used by op, whose values will be passed
        as scalar_inputs when the tensor intrinsic is called.

    Returns
    -------
    intrin: TensorIntrin
        A TensorIntrin that can be used in tensorize schedule.

    Raises
    ------
    TypeError
        If ``op`` is not an Operation.
    ValueError
        If an input tensor of ``op`` is not produced by a placeholder, or
        if ``fcompute`` returns more than three statements.
    """
    if not isinstance(op, _tensor.Operation):
        raise TypeError("expect Operation")
    inputs = op.input_tensors
    # Fail fast before any buffers are declared.
    for t in inputs:
        if not isinstance(t.op, _tensor.PlaceholderOp):
            raise ValueError("Do not yet support composition op")

    binds = binds if binds else {}
    # Inputs first, then every output of op; buffers keep this order.
    tensors = list(inputs)
    tensors.extend(op.output(i) for i in range(op.num_outputs))

    cfg = current_build_config()
    binds_list = []
    for t in tensors:
        buf = (binds[t] if t in binds else
               _api.decl_buffer(t.shape, t.dtype, t.op.name,
                                data_alignment=cfg.data_alignment,
                                offset_factor=cfg.offset_factor))
        binds_list.append(buf)

    num_inputs = len(inputs)
    in_bufs, out_bufs = binds_list[:num_inputs], binds_list[num_inputs:]
    if scalar_params:
        body = fcompute(in_bufs, out_bufs, scalar_params)
    else:
        body = fcompute(in_bufs, out_bufs)
        scalar_params = []

    if isinstance(body, (_expr.Expr, _stmt.Stmt)):
        body = [body]
    body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body]
    # _TensorIntrin takes exactly (body, reduce_init, reduce_update); extra
    # entries would otherwise surface as an opaque FFI arity error.
    if len(body) > 3:
        raise ValueError(
            "fcompute must return a stmt or at most three stmts "
            "(body, reduce_init, reduce_update), got %d" % len(body))
    body += [None] * (3 - len(body))
    return _api_internal._TensorIntrin(
        name, op, inputs, binds_list, scalar_params, *body)