1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License
17
18
19import sys as _sys
20import ctypes as _ctypes
21import numpy as _numpy
22
23from numbers import Number as _Number
24from ..name import NameManager
25from ..attribute import AttrScope
26from ..symbol_doc import _build_doc
27
28include "./base.pyi"
29
cdef class SymbolBase:
    """Symbol is symbolic graph.

    Base extension type wrapping a native ``SymbolHandle``. The instance owns
    the handle: it is released with ``NNSymbolFree`` on deallocation.
    """
    # handle for symbolic operator; NULL when no symbol is attached.
    cdef SymbolHandle chandle

    cdef _set_handle(self, handle):
        """Attach ``handle`` (a ctypes pointer object or ``None``) to this symbol.

        Both ``None`` and a NULL ctypes pointer (``handle.value is None``)
        clear the handle. A previously attached handle is not freed here.
        """
        cdef unsigned long long ptr
        # A default-constructed ctypes pointer has ``.value is None``; treat it
        # as NULL instead of failing the integer conversion below.
        if handle is None or handle.value is None:
            self.chandle = NULL
        else:
            ptr = handle.value
            self.chandle = <SymbolHandle>(ptr)

    property handle:
        """The native handle as a ``ctypes.c_void_p``, or ``None`` when NULL."""
        def __get__(self):
            if self.chandle == NULL:
                return None
            else:
                return _ctypes.cast(<unsigned long long>self.chandle, _ctypes.c_void_p)
        def __set__(self, value):
            self._set_handle(value)

    def __init__(self, handle):
        self._set_handle(handle)

    def __dealloc__(self):
        # Instances built with ``handle=None`` (e.g. by NewSymbol or while
        # unpickling) may never receive a handle; nothing to free then.
        if self.chandle != NULL:
            CALL(NNSymbolFree(self.chandle))

    def _set_attr(self, **kwargs):
        """Set the attribute of the symbol.

        Parameters
        ----------
        **kwargs
            The attributes to set; each value is converted with ``str``.
        """
        SymbolSetAttr(self.chandle, kwargs)

    def __reduce__(self):
        # Rebuild as an empty ``_symbol_cls`` and restore from the state.
        # NOTE(review): ``__getstate__`` is not defined in this class;
        # presumably provided by the Python subclass — confirm.
        return (_symbol_cls, (None,), self.__getstate__())
70
71
cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
    """Set each ``key=value`` pair of ``kwargs`` as an attribute on ``handle``.

    Keys and ``str(value)`` are encoded with ``c_str`` and passed to
    ``MXSymbolSetAttr`` one pair at a time; failures are reported via ``CALL``.
    """
    cdef string key_buf
    cdef string val_buf
    for attr_key, attr_val in kwargs.items():
        key_buf = c_str(attr_key)
        val_buf = c_str(str(attr_val))
        # The std::string buffers outlive the call, so c_str() stays valid.
        CALL(MXSymbolSetAttr(handle, key_buf.c_str(), val_buf.c_str()))
83
84
# Classes NewSymbol instantiates to wrap raw handles: the regular symbol class
# and the numpy-style symbol class (unset until _set_np_symbol_class is called).
_symbol_cls, _np_symbol_cls = SymbolBase, None
87
def _set_symbol_class(cls):
    """Register ``cls`` as the class used to wrap regular symbol handles.

    ``cls`` is called with ``None`` and must produce a ``SymbolBase`` instance.
    """
    global _symbol_cls
    _symbol_cls = cls
91
92
def _set_np_symbol_class(cls):
    """Register ``cls`` as the class used to wrap numpy-style symbol handles.

    Used by ``NewSymbol`` when ``is_np_sym`` is non-zero; ``cls`` is called
    with ``None`` and must produce a ``SymbolBase`` instance.
    """
    global _np_symbol_cls
    _np_symbol_cls = cls
96
97
cdef NewSymbol(SymbolHandle handle, int is_np_sym=0):
    """Create a new symbol given handle.

    Instantiates the registered numpy symbol class when ``is_np_sym`` is
    non-zero, otherwise the registered regular symbol class, and stores
    ``handle`` in it directly; the returned object then owns the handle.
    """
    if is_np_sym:
        wrapper = _np_symbol_cls(None)
    else:
        wrapper = _symbol_cls(None)
    # Unchecked cast: registered classes are expected to derive from SymbolBase.
    (<SymbolBase>wrapper).chandle = handle
    return wrapper
104
105
def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op=0, output_is_list=0):
    """Instantiate an operator symbol and compose it with its input symbols.

    Parameters
    ----------
    handle : int
        Address of the operator (``OpHandle``) to instantiate.
    args : sequence of SymbolBase
        Positional input symbols; mutually exclusive with ``kwargs``.
    kwargs : dict of str to SymbolBase
        Keyword input symbols; mutually exclusive with ``args``.
    keys, vals : iterables
        Operator parameter names and values; each value is converted with ``str``.
    name : str
        Name given to the composed symbol.
    is_np_op : int, optional
        If true, wrap the result in the numpy symbol class and return a list
        of outputs when the symbol has more than one output.
    output_is_list : int, optional
        For a single-output numpy operator, return ``[sym]`` instead of ``sym``.

    Returns
    -------
    Symbol or list of Symbol

    Raises
    ------
    TypeError
        If inputs are given both positionally and by keyword, or an input is
        not a ``SymbolBase``.
    """
    # Validate before any native allocation so a bad call cannot leak a handle.
    if args and kwargs:
        raise TypeError(
            'Operators with variable length input can only accept input '
            'Symbols either as positional or keyword arguments, not both')

    cdef unsigned long long ihandle = handle
    cdef OpHandle chandle = <OpHandle>ihandle
    cdef vector[string] ckeys
    cdef vector[string] cvals
    cdef vector[string] sym_keys
    cdef vector[SymbolHandle] sym_args
    cdef SymbolHandle ret_handle
    cdef string cname = c_str(name)
    cdef nn_uint nout

    for i in keys:
        ckeys.push_back(c_str(i))
    for i in vals:
        cvals.push_back(c_str(str(i)))

    # Checked casts (<SymbolBase?>) raise TypeError for non-symbol inputs
    # instead of reading the handle field out of an arbitrary object.
    if args:
        for i in args:
            sym_args.push_back((<SymbolBase?>i).chandle)
    elif kwargs:
        for k, v in kwargs.items():
            sym_keys.push_back(c_str(k))
            sym_args.push_back((<SymbolBase?>v).chandle)

    cdef vector[const char*] param_keys = SVec2Ptr(ckeys)
    cdef vector[const char*] param_vals = SVec2Ptr(cvals)
    cdef vector[const char*] csym_keys = SVec2Ptr(sym_keys)

    # Create the owning Python wrapper first (this may raise, e.g. if no numpy
    # symbol class is registered) so the native handle always has an owner.
    sym = NewSymbol(NULL, is_np_op)

    CALL(MXSymbolCreateAtomicSymbol(
        chandle,
        <nn_uint>param_keys.size(),
        CBeginPtr(param_keys),
        CBeginPtr(param_vals),
        &ret_handle))
    # From here on, if composition fails, the wrapper's __dealloc__ frees the handle.
    (<SymbolBase>sym).chandle = ret_handle

    CALL(NNSymbolCompose(
        ret_handle,
        cname.c_str(),
        <nn_uint>sym_args.size(),
        &csym_keys[0] if csym_keys.size() != 0 else NULL,
        &sym_args[0] if sym_args.size() != 0 else NULL))

    if is_np_op:
        CALL(NNSymbolGetNumOutputs(ret_handle, &nout))
        if nout > 1:
            return list(sym)
        elif output_is_list:
            return [sym]
    return sym
162