1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""TVM Runtime NDArray API.
18
tvm.ndarray provides a minimal runtime array API for testing
program correctness.
21"""
22# pylint: disable=invalid-name,unused-import
23from __future__ import absolute_import as _abs
24import numpy as _np
25
26from ._ffi.ndarray import TVMContext, TVMType, NDArrayBase
27from ._ffi.ndarray import context, empty, from_dlpack
28from ._ffi.ndarray import _set_class_ndarray
29from ._ffi.ndarray import register_extension, free_extension_handle
30
class NDArray(NDArrayBase):
    """Lightweight NDArray container used by the TVM runtime.

    This class is strictly a buffer object (an array container). It
    defines no arithmetic operations of its own; any computation on the
    data is carried out by TVM functions.

    It is intentionally not a full array library. It is a minimal data
    structure that shows how TVM can be integrated into existing projects
    that may already have their own array containers.
    """
42
43
def cpu(dev_id=0):
    """Return a TVMContext that refers to a CPU device.

    Parameters
    ----------
    dev_id : int, optional
        Index of the CPU device; defaults to 0.

    Returns
    -------
    ctx : TVMContext
        Context built from device-type code 1 and ``dev_id``.
    """
    device_type = 1
    return TVMContext(device_type, dev_id)
58
59
def gpu(dev_id=0):
    """Construct a GPU device.

    Parameters
    ----------
    dev_id : int, optional
        The integer device id (defaults to 0).

    Returns
    -------
    ctx : TVMContext
        The created context, built from device-type code 2 and ``dev_id``.
    """
    # NOTE: previous docstring incorrectly described this as a CPU device;
    # cpu() uses device-type code 1, this function uses code 2.
    return TVMContext(2, dev_id)
74
def rocm(dev_id=0):
    """Return a TVMContext that refers to a ROCm device.

    Parameters
    ----------
    dev_id : int, optional
        Index of the ROCm device; defaults to 0.

    Returns
    -------
    ctx : TVMContext
        Context built from device-type code 10 and ``dev_id``.
    """
    device_type = 10
    return TVMContext(device_type, dev_id)
89
90
def opencl(dev_id=0):
    """Return a TVMContext that refers to an OpenCL device.

    Parameters
    ----------
    dev_id : int, optional
        Index of the OpenCL device; defaults to 0.

    Returns
    -------
    ctx : TVMContext
        Context built from device-type code 4 and ``dev_id``.
    """
    device_type = 4
    return TVMContext(device_type, dev_id)
105
106
def metal(dev_id=0):
    """Return a TVMContext that refers to a Metal device.

    Parameters
    ----------
    dev_id : int, optional
        Index of the Metal device; defaults to 0.

    Returns
    -------
    ctx : TVMContext
        Context built from device-type code 8 and ``dev_id``.
    """
    device_type = 8
    return TVMContext(device_type, dev_id)
121
122
def vpi(dev_id=0):
    """Return a TVMContext that refers to a VPI simulated device.

    Parameters
    ----------
    dev_id : int, optional
        Index of the VPI device; defaults to 0.

    Returns
    -------
    ctx : TVMContext
        Context built from device-type code 9 and ``dev_id``.
    """
    device_type = 9
    return TVMContext(device_type, dev_id)
137
138
def vulkan(dev_id=0):
    """Return a TVMContext that refers to a Vulkan device.

    Parameters
    ----------
    dev_id : int, optional
        Index of the Vulkan device; defaults to 0.

    Returns
    -------
    ctx : TVMContext
        Context built from device-type code 7 and ``dev_id``.
    """
    device_type = 7
    return TVMContext(device_type, dev_id)
153
154
def opengl(dev_id=0):
    """Return a TVMContext that refers to an OpenGL device.

    Parameters
    ----------
    dev_id : int, optional
        Index of the OpenGL device; defaults to 0.

    Returns
    -------
    ctx : TVMContext
        Context built from device-type code 11 and ``dev_id``.
    """
    device_type = 11
    return TVMContext(device_type, dev_id)
169
170
def ext_dev(dev_id=0):
    """Return a TVMContext that refers to an extension device.

    Parameters
    ----------
    dev_id : int, optional
        Index of the extension device; defaults to 0.

    Returns
    -------
    ctx : TVMContext
        Context built from device-type code 12 and ``dev_id``.

    Note
    ----
    This API is reserved for quickly testing a new device that is
    plugged in through the device API as ext_dev.
    """
    device_type = 12
    return TVMContext(device_type, dev_id)
190
191
def micro_dev(dev_id=0):
    """Return a TVMContext that refers to a micro device.

    Parameters
    ----------
    dev_id : int, optional
        Index of the micro device; defaults to 0.

    Returns
    -------
    ctx : TVMContext
        Context built from device-type code 13 and ``dev_id``.
    """
    device_type = 13
    return TVMContext(device_type, dev_id)
206
207
# Short aliases: cl -> opencl, mtl -> metal.
cl, mtl = opencl, metal
210
211
def array(arr, ctx=None):
    """Create an NDArray by copying data from ``arr``.

    Parameters
    ----------
    arr : numpy.ndarray, NDArray, or array_like
        The data to be copied from. Inputs that are not already a
        ``numpy.ndarray`` or ``NDArray`` are first converted with
        ``numpy.array``.

    ctx : TVMContext, optional
        The device context in which to create the array. When omitted
        (or None) it defaults to ``cpu(0)``.

    Returns
    -------
    ret : NDArray
        The created array
    """
    if ctx is None:
        # Resolve the default per call rather than sharing a single
        # TVMContext object created once at import time as a default arg.
        ctx = cpu(0)
    if not isinstance(arr, (_np.ndarray, NDArray)):
        arr = _np.array(arr)
    # Allocate an empty array with matching shape/dtype on ctx, then copy
    # the source data into it.
    return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
231
# Hand the NDArray class (defined above) to _set_class_ndarray from
# ._ffi.ndarray; this call must stay after the class definition.
# NOTE(review): presumably this lets the FFI layer wrap returned array
# handles as NDArray instances — confirm in _ffi.ndarray.
_set_class_ndarray(NDArray)
233