1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name
18"""Runtime Object api"""
19import ctypes
20from ..base import _LIB, check_call
21from .types import ArgTypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
22from .ndarray import _register_ndarray, NDArrayBase
23
24
# ctypes representation of an opaque runtime object handle.
ObjectHandle = ctypes.c_void_p

# Callable used by ObjectBase.__init_handle_by_constructor__ as
# ``__init_by_constructor__(fconstructor, args)``; it is None here and is
# expected to be installed from elsewhere before any object is constructed.
__init_by_constructor__ = None

# Maps a runtime type index to the Python class registered for it
# (populated by _register_object, consulted by _return_object).
OBJECT_TYPE = {}

# Default Python class used when a type index has no registered class
# (installed by _set_class_object).
_CLASS_OBJECT = None
32
33
def _set_class_object(object_class):
    """Install ``object_class`` as this module's default object class.

    The installed class is the fallback used by ``_return_object`` for type
    indices with no registered class, and the class instantiated to hold the
    backing object of ``PyNativeObject`` instances.

    Parameters
    ----------
    object_class : type
        The class to install.
    """
    global _CLASS_OBJECT  # rebind the module-level name, not a local
    _CLASS_OBJECT = object_class
37
38
def _register_object(index, cls):
    """Register ``cls`` as the Python class for runtime type ``index``.

    NDArray classes (subclasses of ``NDArrayBase``) are delegated to
    ``_register_ndarray`` and are not recorded in ``OBJECT_TYPE``; every other
    class is stored in ``OBJECT_TYPE[index]``.

    Parameters
    ----------
    index : int
        Runtime type index.

    cls : type
        Python class to associate with ``index``.
    """
    if not issubclass(cls, NDArrayBase):
        OBJECT_TYPE[index] = cls
    else:
        _register_ndarray(index, cls)
45
46
def _return_object(x):
    """Wrap the object handle stored in ``x.v_handle`` as a Python object.

    The runtime type index of the handle selects the Python class from
    ``OBJECT_TYPE`` (falling back to ``_CLASS_OBJECT``).

    Parameters
    ----------
    x : value union
        Holder whose ``v_handle`` field contains the object handle.

    Returns
    -------
    obj : object
        An instance of the selected class with ``handle`` set. For
        ``PyNativeObject`` subclasses, the result of the class's
        ``__from_tvm_object__`` hook instead.
    """
    raw = x.v_handle
    handle = raw if isinstance(raw, ObjectHandle) else ObjectHandle(raw)

    type_index = ctypes.c_uint()
    check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(type_index)))
    target_cls = OBJECT_TYPE.get(type_index.value, _CLASS_OBJECT)

    if not issubclass(target_cls, PyNativeObject):
        # Use __new__ directly so __init__ is skipped; this lets subclasses
        # define their own __init__ for user-facing construction.
        result = target_cls.__new__(target_cls)
        result.handle = handle
        return result

    # Builtin-subclassing types: wrap the handle in the default object class
    # first, then let the target class build itself from that object.
    backing = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
    backing.handle = handle
    return target_cls.__from_tvm_object__(target_cls, backing)
63
64
# Return values tagged as object handles are converted by _return_object.
RETURN_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _return_object

# Both plain object handles and rvalue-reference object arguments in the
# C-to-Python argument table go through _return_object, each wrapped by
# _wrap_arg_func with its own type code.
for _tcode in (ArgTypeCode.OBJECT_HANDLE, ArgTypeCode.OBJECT_RVALUE_REF_ARG):
    C_TO_PY_ARG_SWITCH[_tcode] = _wrap_arg_func(_return_object, _tcode)
del _tcode  # do not leave the loop variable behind as a module global
73
74
class PyNativeObject:
    """Base class of all TVM objects that also subclass python's builtin types.

    The backing TVM object is kept in a ``__tvm_object__`` attribute. This
    class declares no slots, so subclasses must provide storage for that
    attribute (a slot of their own or an instance ``__dict__``).
    """

    __slots__ = []

    def __init_tvm_object_by_constructor__(self, fconstructor, *args):
        """Create the backing TVM object with ``fconstructor`` and attach it.

        Parameters
        ----------
        fconstructor : Function
            Constructor function.

        args: list of objects
            The arguments to the constructor

        Note
        ----
        A bare instance of the default object class is created (skipping its
        ``__init__``). Its handle is filled in using the special constructor
        calling convention, and that instance is stored on ``self`` as
        ``__tvm_object__``.
        """
        # pylint: disable=assigning-non-slot
        backing = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
        backing.__init_handle_by_constructor__(fconstructor, *args)
        self.__tvm_object__ = backing
100
101
class ObjectBase(object):
    """Base object for all object types.

    Each instance holds a single ``handle`` slot: an ``ObjectHandle``
    (``ctypes.c_void_p``) referring to the underlying runtime object, or
    ``None`` when there is no object.
    """

    __slots__ = ["handle"]

    def __del__(self):
        """Free the runtime object referenced by ``handle``.

        Nothing is freed when the ``handle`` slot was never assigned (e.g. a
        subclass ``__init__`` raised before setting it), when the handle is
        ``None``, or when ``_LIB`` is ``None`` (presumably already torn down
        at interpreter shutdown).
        """
        # getattr: reading an unassigned slot raises AttributeError, which
        # inside __del__ surfaces as an "Exception ignored" warning.
        handle = getattr(self, "handle", None)
        if handle is None or _LIB is None:
            return
        check_call(_LIB.TVMObjectFree(handle))

    def __init_handle_by_constructor__(self, fconstructor, *args):
        """Initialize the handle by calling constructor function.

        Parameters
        ----------
        fconstructor : Function
            Constructor function.

        args: list of objects
            The arguments to the constructor

        Note
        ----
        We have a special calling convention to call constructor functions.
        So the return handle is directly set into the Node object
        instead of creating a new Node.
        """
        # Assign the handle first so that, if the constructor raises,
        # __del__ finds a None handle rather than an unset slot.
        # pylint: disable=not-callable
        self.handle = None
        handle = __init_by_constructor__(fconstructor, args)
        if not isinstance(handle, ObjectHandle):
            handle = ObjectHandle(handle)
        self.handle = handle

    def same_as(self, other):
        """Check object identity by comparing the underlying handles.

        Parameters
        ----------
        other : object
            The other object to compare against.

        Returns
        -------
        result : bool
             True if ``other`` is an ``ObjectBase`` referring to the same
             address as ``self`` (or both handles are ``None``); False
             otherwise, including when only one handle is ``None``.
        """
        if not isinstance(other, ObjectBase):
            return False
        if self.handle is None:
            return other.handle is None
        if other.handle is None:
            # Previously fell through to ``other.handle.value`` and raised
            # AttributeError.
            return False
        return self.handle.value == other.handle.value
154