# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Common runtime ctypes."""
# pylint: disable=invalid-name
import ctypes
import json
import numpy as np
from .base import _LIB, check_call

# ctypes integer type used for shape and stride entries (see TVMArray).
tvm_shape_index_t = ctypes.c_int64


class ArgTypeCode(object):
    """Type code used in API calls"""

    INT = 0
    UINT = 1
    FLOAT = 2
    HANDLE = 3
    NULL = 4
    TVM_TYPE = 5
    TVM_CONTEXT = 6
    DLTENSOR_HANDLE = 7
    OBJECT_HANDLE = 8
    MODULE_HANDLE = 9
    PACKED_FUNC_HANDLE = 10
    STR = 11
    BYTES = 12
    NDARRAY_HANDLE = 13
    OBJECT_RVALUE_REF_ARG = 14
    EXT_BEGIN = 15


class TVMByteArray(ctypes.Structure):
    """Temp data structure for byte array."""

    _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)), ("size", ctypes.c_size_t)]


class DataTypeCode(object):
    """DataType code in DLTensor."""

    INT = 0
    UINT = 1
    FLOAT = 2
    HANDLE = 3
    BFLOAT = 4


class DataType(ctypes.Structure):
    """TVM datatype structure.

    Holds a type code, a bit width and a lane count, and converts to and
    from the string form ``<name><bits>[x<lanes>]`` (for example
    ``"float32"`` or ``"int8x4"``).

    Parameters
    ----------
    type_str : str or numpy.dtype
        The type description to parse.  ``"bool"`` is special-cased to a
        1-bit, 1-lane unsigned integer.

    Raises
    ------
    ValueError
        If the type name is not recognised or a ``custom[...]`` type
        string has missing or misordered brackets.

    Notes
    -----
    The class defines ``__eq__`` without ``__hash__``, so instances are
    not hashable under Python 3.
    """

    _fields_ = [("type_code", ctypes.c_uint8), ("bits", ctypes.c_uint8), ("lanes", ctypes.c_uint16)]
    # Built-in type codes and the name prefix used in their string form.
    CODE2STR = {
        DataTypeCode.INT: "int",
        DataTypeCode.UINT: "uint",
        DataTypeCode.FLOAT: "float",
        DataTypeCode.HANDLE: "handle",
        DataTypeCode.BFLOAT: "bfloat",
    }

    def __init__(self, type_str):
        super(DataType, self).__init__()
        if isinstance(type_str, np.dtype):
            type_str = str(type_str)

        if type_str == "bool":
            self.bits = 1
            self.type_code = DataTypeCode.UINT
            self.lanes = 1
            return

        # NOTE: the lane suffix is split off before anything else is parsed,
        # so a custom type name that itself contains "x" is not supported.
        arr = type_str.split("x")
        head = arr[0]
        self.lanes = int(arr[1]) if len(arr) > 1 else 1
        # Default width used when the string carries no explicit bit count.
        bits = 32

        if head.startswith("int"):
            self.type_code = DataTypeCode.INT
            head = head[3:]
        elif head.startswith("uint"):
            self.type_code = DataTypeCode.UINT
            head = head[4:]
        elif head.startswith("float"):
            self.type_code = DataTypeCode.FLOAT
            head = head[5:]
        elif head.startswith("handle"):
            # Handles are always 64 bits; any trailing digits are ignored.
            self.type_code = DataTypeCode.HANDLE
            bits = 64
            head = ""
        elif head.startswith("bfloat"):
            self.type_code = DataTypeCode.BFLOAT
            head = head[6:]
        elif head.startswith("custom"):
            # pylint: disable=import-outside-toplevel
            import tvm.runtime._ffi_api

            low, high = head.find("["), head.find("]")
            # str.find returns -1 (which is truthy) when the bracket is
            # missing, so compare against -1 explicitly.
            if low == -1 or high == -1 or low >= high:
                raise ValueError("Badly formatted custom type string %s" % type_str)
            type_name = head[low + 1 : high]
            self.type_code = tvm.runtime._ffi_api._datatype_get_type_code(type_name)
            head = head[high + 1 :]
        else:
            raise ValueError("Do not know how to handle type %s" % type_str)
        # Whatever remains of ``head`` is the explicit bit width, if any.
        bits = int(head) if head else bits
        self.bits = bits

    def __repr__(self):
        # pylint: disable=import-outside-toplevel
        if self.bits == 1 and self.lanes == 1:
            return "bool"
        if self.type_code in DataType.CODE2STR:
            type_name = DataType.CODE2STR[self.type_code]
        else:
            # Codes outside CODE2STR are resolved through the FFI registry.
            import tvm.runtime._ffi_api

            type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
        x = "%s%d" % (type_name, self.bits)
        if self.lanes != 1:
            x += "x%d" % self.lanes
        return x

    def __eq__(self, other):
        # Comparing against a non-DataType yields False instead of raising
        # AttributeError, matching TVMContext.__eq__.
        return (
            isinstance(other, DataType)
            and self.bits == other.bits
            and self.type_code == other.type_code
            and self.lanes == other.lanes
        )

    def __ne__(self, other):
        return not self.__eq__(other)


# Device types at or above this value are rendered as remote devices by
# TVMContext.__repr__: the quotient selects the session table index and the
# remainder is the underlying device type.
RPC_SESS_MASK = 128


class TVMContext(ctypes.Structure):
    """TVM context structure.

    Parameters
    ----------
    device_type : int
        Integer device type code (see ``MASK2STR`` / ``STR2MASK``).
    device_id : int
        Integer index of the device.
    """

    _fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)]
    # Device type code -> canonical device name.
    MASK2STR = {
        1: "cpu",
        2: "gpu",
        4: "opencl",
        5: "aocl",
        6: "sdaccel",
        7: "vulkan",
        8: "metal",
        9: "vpi",
        10: "rocm",
        12: "ext_dev",
        13: "micro_dev",
        14: "hexagon",
        15: "webgpu",
    }
    # Device or target name (including aliases) -> device type code.
    STR2MASK = {
        "llvm": 1,
        "stackvm": 1,
        "cpu": 1,
        "c": 1,
        "gpu": 2,
        "cuda": 2,
        "nvptx": 2,
        "cl": 4,
        "opencl": 4,
        "aocl": 5,
        "aocl_sw_emu": 5,
        "sdaccel": 6,
        "vulkan": 7,
        "metal": 8,
        "vpi": 9,
        "rocm": 10,
        "ext_dev": 12,
        "micro_dev": 13,
        "hexagon": 14,
        "webgpu": 15,
    }

    def __init__(self, device_type, device_id):
        super(TVMContext, self).__init__()
        self.device_type = device_type
        self.device_id = device_id

    def _GetDeviceAttr(self, device_type, device_id, attr_id):
        """Internal helper function to invoke runtime.GetDeviceAttr"""
        # pylint: disable=import-outside-toplevel
        import tvm.runtime._ffi_api

        return tvm.runtime._ffi_api.GetDeviceAttr(device_type, device_id, attr_id)

    @property
    def exist(self):
        """Whether this device exists."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 0) != 0

    @property
    def max_threads_per_block(self):
        """Maximum number of threads on each block."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 1)

    @property
    def warp_size(self):
        """Number of threads that execute concurrently."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 2)

    @property
    def max_shared_memory_per_block(self):
        """Total amount of shared memory per block in bytes."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 3)

    @property
    def compute_version(self):
        """Get compute version number as a string.

        Currently used to get compute capability of CUDA device.

        Returns
        -------
        version : str
            The version string in `major.minor` format.
        """
        return self._GetDeviceAttr(self.device_type, self.device_id, 4)

    @property
    def device_name(self):
        """Return the string name of device."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 5)

    @property
    def max_clock_rate(self):
        """Return the max clock frequency of device."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 6)

    @property
    def multi_processor_count(self):
        """Return the number of compute units of device."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 7)

    @property
    def max_thread_dimensions(self):
        """Return the maximum size of each thread axis

        Returns
        -------
        dims: List of int
            The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
        """
        # The attribute comes back as a JSON-encoded string and is decoded here.
        return json.loads(self._GetDeviceAttr(self.device_type, self.device_id, 8))

    def sync(self):
        """Synchronize until jobs finished at the context."""
        check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))

    def __eq__(self, other):
        return (
            isinstance(other, TVMContext)
            and self.device_id == other.device_id
            and self.device_type == other.device_type
        )

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Hash follows the string form, so equal contexts hash equally.
        return hash(str(self))

    def __repr__(self):
        if self.device_type >= RPC_SESS_MASK:
            # Floor division keeps the table index an integer; true division
            # produced a float that only worked because "%d" truncates it.
            tbl_id = self.device_type // RPC_SESS_MASK - 1
            dev_type = self.device_type % RPC_SESS_MASK
            return "remote[%d]:%s(%d)" % (tbl_id, TVMContext.MASK2STR[dev_type], self.device_id)
        return "%s(%d)" % (TVMContext.MASK2STR[self.device_type], self.device_id)


class TVMArray(ctypes.Structure):
    """TVMValue in C API

    ctypes layout with a data pointer, the owning context, the number of
    dimensions, the element data type, shape and strides pointers, and a
    byte offset.
    NOTE(review): the field set resembles a DLTensor - confirm against
    the C headers.
    """

    _fields_ = [
        ("data", ctypes.c_void_p),
        ("ctx", TVMContext),
        ("ndim", ctypes.c_int),
        ("dtype", DataType),
        ("shape", ctypes.POINTER(tvm_shape_index_t)),
        ("strides", ctypes.POINTER(tvm_shape_index_t)),
        ("byte_offset", ctypes.c_uint64),
    ]


class ObjectRValueRef:
    """Wrapper marking an object as an rvalue reference that can be moved.

    Parameters
    ----------
    obj : tvm.runtime.Object
        The object that this value refers to
    """

    # Only the wrapped object is stored; no per-instance __dict__.
    __slots__ = ["obj"]

    def __init__(self, obj):
        self.obj = obj


# Pointer type for passing TVMArray structures by reference.
TVMArrayHandle = ctypes.POINTER(TVMArray)