1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Common runtime ctypes."""
18# pylint: disable=invalid-name
19import ctypes
20import json
21import numpy as np
22from .base import _LIB, check_call
23
# ctypes integer type used for shape and stride entries (signed 64-bit).
tvm_shape_index_t = ctypes.c_int64
25
26
class ArgTypeCode:
    """Integer type codes that tag values passed through C API calls."""

    # Plain scalar / pointer kinds.
    INT = 0
    UINT = 1
    FLOAT = 2
    HANDLE = 3
    NULL = 4

    # TVM runtime value kinds.
    TVM_TYPE = 5
    TVM_CONTEXT = 6
    DLTENSOR_HANDLE = 7
    OBJECT_HANDLE = 8
    MODULE_HANDLE = 9
    PACKED_FUNC_HANDLE = 10

    # Byte/string payloads.
    STR = 11
    BYTES = 12

    NDARRAY_HANDLE = 13
    OBJECT_RVALUE_REF_ARG = 14

    # Presumably the first code reserved for extension types -- confirm
    # against the C-side enum before relying on it.
    EXT_BEGIN = 15
46
47
class TVMByteArray(ctypes.Structure):
    """ctypes struct describing a byte buffer as a (pointer, length) pair.

    Fields
    ------
    data : POINTER(c_byte)
        Pointer to the first byte.
    size : c_size_t
        Number of bytes.
    """

    _fields_ = [
        ("data", ctypes.POINTER(ctypes.c_byte)),
        ("size", ctypes.c_size_t),
    ]
52
53
class DataTypeCode(object):
    """DataType code in DLTensor (the value stored in ``DataType.type_code``)."""

    INT = 0
    UINT = 1
    FLOAT = 2
    HANDLE = 3
    BFLOAT = 4


class DataType(ctypes.Structure):
    """TVM datatype structure: a ctypes ``(type_code, bits, lanes)`` triple.

    Instances are built from a type string such as ``"int8"``,
    ``"float32x4"``, ``"bool"``, ``"handle"``, ``"bfloat16"`` or
    ``"custom[<name>]<bits>"``, or from a ``numpy.dtype``.
    """

    _fields_ = [("type_code", ctypes.c_uint8), ("bits", ctypes.c_uint8), ("lanes", ctypes.c_uint16)]
    CODE2STR = {
        DataTypeCode.INT: "int",
        DataTypeCode.UINT: "uint",
        DataTypeCode.FLOAT: "float",
        DataTypeCode.HANDLE: "handle",
        DataTypeCode.BFLOAT: "bfloat",
    }

    def __init__(self, type_str):
        """Parse ``type_str`` into the ``type_code``/``bits``/``lanes`` fields.

        Parameters
        ----------
        type_str : str or numpy.dtype
            ``<kind><bits>[x<lanes>]``; ``bits`` defaults to 32 (64 for
            ``handle``) and ``lanes`` defaults to 1. ``"bool"`` maps to uint1.

        Raises
        ------
        ValueError
            If the kind prefix is unknown, a ``custom[...]`` string is
            malformed, or the bits/lanes parts are not integers.
        """
        super(DataType, self).__init__()
        if isinstance(type_str, np.dtype):
            type_str = str(type_str)

        if type_str == "bool":
            self.bits = 1
            self.type_code = DataTypeCode.UINT
            self.lanes = 1
            return

        # Look for the lane separator only after a custom type's closing "]",
        # so custom type names containing an "x" are not split apart.
        lanes_sep = type_str.find("x", type_str.rfind("]") + 1)
        if lanes_sep == -1:
            head = type_str
            self.lanes = 1
        else:
            head = type_str[:lanes_sep]
            self.lanes = int(type_str[lanes_sep + 1 :])
        bits = 32

        if head.startswith("int"):
            self.type_code = DataTypeCode.INT
            head = head[3:]
        elif head.startswith("uint"):
            self.type_code = DataTypeCode.UINT
            head = head[4:]
        elif head.startswith("float"):
            self.type_code = DataTypeCode.FLOAT
            head = head[5:]
        elif head.startswith("handle"):
            self.type_code = DataTypeCode.HANDLE
            bits = 64
            head = ""
        elif head.startswith("bfloat"):
            self.type_code = DataTypeCode.BFLOAT
            head = head[6:]
        elif head.startswith("custom"):
            low, high = head.find("["), head.find("]")
            # str.find returns -1 (which is truthy) when the bracket is absent,
            # so the missing-bracket cases must be checked explicitly.
            if low == -1 or high == -1 or low >= high:
                raise ValueError("Badly formatted custom type string %s" % type_str)
            # pylint: disable=import-outside-toplevel
            import tvm.runtime._ffi_api

            type_name = head[low + 1 : high]
            self.type_code = tvm.runtime._ffi_api._datatype_get_type_code(type_name)
            head = head[high + 1 :]
        else:
            raise ValueError("Do not know how to handle type %s" % type_str)
        bits = int(head) if head else bits
        self.bits = bits

    def __repr__(self):
        """Return the string form, e.g. ``"float32x4"`` or ``"bool"``."""
        # pylint: disable=import-outside-toplevel
        # Only uint1 is spelled "bool" (the inverse of the "bool" branch in
        # __init__), so that e.g. "int1" round-trips through repr().
        if self.type_code == DataTypeCode.UINT and self.bits == 1 and self.lanes == 1:
            return "bool"
        if self.type_code in DataType.CODE2STR:
            type_name = DataType.CODE2STR[self.type_code]
        else:
            # Codes outside CODE2STR are resolved as registered custom types.
            import tvm.runtime._ffi_api

            type_name = "custom[%s]" % tvm.runtime._ffi_api._datatype_get_type_name(self.type_code)
        x = "%s%d" % (type_name, self.bits)
        if self.lanes != 1:
            x += "x%d" % self.lanes
        return x

    def __eq__(self, other):
        """Field-wise equality; any non-DataType operand compares unequal."""
        if not isinstance(other, DataType):
            return False
        return (
            self.bits == other.bits
            and self.type_code == other.type_code
            and self.lanes == other.lanes
        )

    def __ne__(self, other):
        return not self.__eq__(other)
147
148
# device_type values at or above this mask denote a device on a remote RPC
# session (decoded in TVMContext.__repr__).
RPC_SESS_MASK = 128


class TVMContext(ctypes.Structure):
    """TVM context structure: a ``(device_type, device_id)`` pair."""

    _fields_ = [("device_type", ctypes.c_int), ("device_id", ctypes.c_int)]
    MASK2STR = {
        1: "cpu",
        2: "gpu",
        4: "opencl",
        5: "aocl",
        6: "sdaccel",
        7: "vulkan",
        8: "metal",
        9: "vpi",
        10: "rocm",
        12: "ext_dev",
        13: "micro_dev",
        14: "hexagon",
        15: "webgpu",
    }
    STR2MASK = {
        "llvm": 1,
        "stackvm": 1,
        "cpu": 1,
        "c": 1,
        "gpu": 2,
        "cuda": 2,
        "nvptx": 2,
        "cl": 4,
        "opencl": 4,
        "aocl": 5,
        "aocl_sw_emu": 5,
        "sdaccel": 6,
        "vulkan": 7,
        "metal": 8,
        "vpi": 9,
        "rocm": 10,
        "ext_dev": 12,
        "micro_dev": 13,
        "hexagon": 14,
        "webgpu": 15,
    }

    def __init__(self, device_type, device_id):
        super(TVMContext, self).__init__()
        self.device_type = device_type
        self.device_id = device_id

    def _GetDeviceAttr(self, device_type, device_id, attr_id):
        """Internal helper function to invoke runtime.GetDeviceAttr"""
        # pylint: disable=import-outside-toplevel
        # NOTE(review): the attr_id integers used by the properties below must
        # match the runtime's device-attribute enum -- confirm on changes.
        import tvm.runtime._ffi_api

        return tvm.runtime._ffi_api.GetDeviceAttr(device_type, device_id, attr_id)

    @property
    def exist(self):
        """Whether this device exists."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 0) != 0

    @property
    def max_threads_per_block(self):
        """Maximum number of threads on each block."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 1)

    @property
    def warp_size(self):
        """Number of threads that execute concurrently."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 2)

    @property
    def max_shared_memory_per_block(self):
        """Total amount of shared memory per block in bytes."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 3)

    @property
    def compute_version(self):
        """Get the compute version number as a string.

        Currently used to get compute capability of CUDA device.

        Returns
        -------
        version : str
            The version string in `major.minor` format.
        """
        return self._GetDeviceAttr(self.device_type, self.device_id, 4)

    @property
    def device_name(self):
        """Return the string name of device."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 5)

    @property
    def max_clock_rate(self):
        """Return the max clock frequency of device."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 6)

    @property
    def multi_processor_count(self):
        """Return the number of compute units of device."""
        return self._GetDeviceAttr(self.device_type, self.device_id, 7)

    @property
    def max_thread_dimensions(self):
        """Return the maximum size of each thread axis

        Returns
        -------
        dims: List of int
            The maximum length of threadIdx.x, threadIdx.y, threadIdx.z
        """
        # The runtime reports this attribute as a JSON-encoded string.
        return json.loads(self._GetDeviceAttr(self.device_type, self.device_id, 8))

    def sync(self):
        """Synchronize until jobs finished at the context."""
        check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))

    def __eq__(self, other):
        return (
            isinstance(other, TVMContext)
            and self.device_id == other.device_id
            and self.device_type == other.device_type
        )

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        # Hash exactly the fields __eq__ compares. Unlike hashing str(self),
        # this cannot fail for device types missing from MASK2STR.
        return hash((self.device_type, self.device_id))

    @staticmethod
    def _device_type_name(dev_type):
        """Return the MASK2STR name of ``dev_type``, or a generic fallback."""
        return TVMContext.MASK2STR.get(dev_type, "device_type_%d" % dev_type)

    def __repr__(self):
        if self.device_type >= RPC_SESS_MASK:
            # Floor division: the session table index is an integer.
            tbl_id = self.device_type // RPC_SESS_MASK - 1
            dev_type = self.device_type % RPC_SESS_MASK
            return "remote[%d]:%s(%d)" % (
                tbl_id,
                TVMContext._device_type_name(dev_type),
                self.device_id,
            )
        return "%s(%d)" % (TVMContext._device_type_name(self.device_type), self.device_id)
288
289
class TVMArray(ctypes.Structure):
    """ctypes mirror of the C-side array descriptor (DLPack ``DLTensor`` layout).

    The field order below defines the struct's memory layout and must stay in
    sync with the C definition.
    """

    _fields_ = [
        # Opaque pointer to the element buffer.
        ("data", ctypes.c_void_p),
        # Device the buffer belongs to.
        ("ctx", TVMContext),
        # Number of dimensions.
        ("ndim", ctypes.c_int),
        # Element type.
        ("dtype", DataType),
        # Pointers to shape and stride entries.
        ("shape", ctypes.POINTER(tvm_shape_index_t)),
        ("strides", ctypes.POINTER(tvm_shape_index_t)),
        # Offset in bytes from `data`.
        ("byte_offset", ctypes.c_uint64),
    ]
302
303
class ObjectRValueRef:
    """Wrapper marking an object as an rvalue reference that may be moved from.

    Parameters
    ----------
    obj : tvm.runtime.Object
        The wrapped object.
    """

    # A single slot: instances carry no __dict__ and accept no other attributes.
    __slots__ = ["obj"]

    def __init__(self, obj):
        self.obj = obj
317
318
# ctypes pointer type to a TVMArray struct.
TVMArrayHandle = ctypes.POINTER(TVMArray)
320