# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Custom datatype functionality"""
from __future__ import absolute_import as _abs

from ._ffi.function import register_func as _register_func
from . import make as _make
from .api import convert
from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm
from ._ffi.runtime_ctypes import TVMType as _TVMType
from . import _api_internal


def register(type_name, type_code):
    """Register a custom datatype with the given type name and type code.

    Currently, the type code is manually allocated by the user, and the
    user must ensure that no two custom types share the same code.
    Generally, this should be straightforward, as the user will be
    manually registering all of their custom types.

    Parameters
    ----------
    type_name : str
        The name of the custom datatype

    type_code : int
        The type's code, which should be >= kCustomBegin
    """
    _api_internal._datatype_register(type_name, type_code)


def get_type_name(type_code):
    """Get the type name from the type code.

    Parameters
    ----------
    type_code : int
        The type code

    Returns
    -------
    type_name
        The name registered for ``type_code``, as reported by the
        ``_datatype_get_type_name`` internal API.
    """
    return _api_internal._datatype_get_type_name(type_code)


def get_type_code(type_name):
    """Get the type code from the type name.

    Parameters
    ----------
    type_name : str
        The type name

    Returns
    -------
    type_code
        The code registered for ``type_name``, as reported by the
        ``_datatype_get_type_code`` internal API.
    """
    return _api_internal._datatype_get_type_code(type_name)


def get_type_registered(type_code):
    """Get a boolean representing whether the type is registered.

    Parameters
    ----------
    type_code : int
        The type code

    Returns
    -------
    registered
        Whether a custom type has been registered under ``type_code``,
        as reported by the ``_datatype_get_type_registered`` internal API.
    """
    return _api_internal._datatype_get_type_registered(type_code)


def register_op(lower_func, op_name, target, type_name, src_type_name=None):
    """Register an external function which computes the given op.

    Currently, this will only work with Casts and binary expressions
    whose arguments are named `a` and `b`.
    TODO(gus) figure out what other special cases must be handled by
    looking through expr.py.

    The function is registered globally under the name
    ``tvm.datatype.lower.<target>.<op_name>.<type_name>``. For Casts,
    ``.<src_type_name>`` is appended to that name.

    Parameters
    ----------
    lower_func : function
        The lowering function to call. See create_lower_func.

    op_name : str
        The name of the operation which the function computes, given by its
        Halide::Internal class name (e.g. Add, LE, Cast).

    target : str
        The name of codegen target.

    type_name : str
        The name of the custom datatype, e.g. posit (but not custom[posit]8).

    src_type_name : str, optional
        If op_name is "Cast", then this should be set to the source datatype of
        the argument to the Cast. If op_name is not "Cast", this is unused.

    Raises
    ------
    ValueError
        If op_name is "Cast" and src_type_name is not given.
    """
    name_parts = ["tvm.datatype.lower", target, op_name, type_name]
    if op_name == "Cast":
        # Cast lowerings are additionally keyed on the source type. Use an
        # explicit check rather than `assert`, which is stripped under -O and
        # would otherwise surface as an opaque TypeError on string joining.
        if src_type_name is None:
            raise ValueError(
                "src_type_name must be provided when op_name is \"Cast\"")
        name_parts.append(src_type_name)
    _register_func(".".join(name_parts), lower_func)


def create_lower_func(extern_func_name):
    """Returns a function which lowers an operation to a function call.

    Parameters
    ----------
    extern_func_name : str
        The name of the extern "C" function to lower to

    Returns
    -------
    lower : function
        A lowering function suitable for passing to register_op.
    """

    def lower(op):
        """
        Takes an op---a Cast, a FloatImm, or a binary op (e.g. an Add)---and
        returns a call to the specified external function, passing the op's
        `value` (Cast / FloatImm) or its `a` and `b` arguments (binary op).
        The return type of the call depends on the type of the op: if it is
        a custom type, then a uint of the same width (and lane count) as the
        custom type is returned. Otherwise, the type is unchanged."""
        dtype = op.dtype
        t = _TVMType(dtype)
        if get_type_registered(t.type_code):
            # Custom types cross the extern-call boundary as raw unsigned
            # integers of the same bit width; keep vector lanes intact.
            dtype = "uint" + str(t.bits)
            if t.lanes > 1:
                dtype += "x" + str(t.lanes)
        if isinstance(op, (_Cast, _FloatImm)):
            return _make.Call(dtype, extern_func_name, convert([op.value]),
                              _Call.Extern, None, 0)
        return _make.Call(dtype, extern_func_name, convert([op.a, op.b]),
                          _Call.Extern, None, 0)

    return lower